
%% Find the Optimal Number of Splits and Trees for an Ensemble
% You can control the depth of the trees in an ensemble of decision trees,
% or of the decision tree binary learners in an ECOC model, using the
% |MaxNumSplits|, |MinLeafSize|, or |MinParentSize| name-value pair
% arguments.
%
% * When bagging decision trees, |fitensemble| grows deep
% decision trees by default.  You can grow shallower trees to
% reduce model complexity or computation time.
% * When boosting decision trees, |fitensemble| grows stumps (a tree with
% one split) by default. You can grow deeper trees for better accuracy, as
% the sketch after this list shows.
%
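%
% As a minimal sketch of both adjustments, create a tree template with an
% explicit |MaxNumSplits| and pass it to |fitensemble| as the learner
% argument.  The split counts below are illustrative assumptions, not the
% actual defaults.
tShallow = templateTree('MaxNumSplits',10); % shallower trees when bagging
tDeep = templateTree('MaxNumSplits',16);    % deeper trees when boosting
% For example: fitensemble(X,Y,'Bag',100,tShallow,'Type','regression')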
%%
% Load the |carsmall| data set.  Specify the variables |Acceleration|,
% |Displacement|, |Horsepower|, and |Weight| as predictors, and |MPG| as
% the response.
load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;
%%
% The default values of the tree depth controllers for boosting 
% regression trees are:
%
% * |1| for |MaxNumSplits|.  This option grows stumps.
% * |5| for |MinLeafSize|.
% * |10| for |MinParentSize|.
%
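%
% As a minimal sketch, assuming the default values listed above, an
% explicitly specified equivalent template is:
tDefault = templateTree('MaxNumSplits',1,'MinLeafSize',5,'MinParentSize',10);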
%%
% To search for the optimal number of splits:
%
% # Train a set of ensembles.  Exponentially increase the maximum number
% of splits for subsequent ensembles from stump to at most _n_ - 1 splits,
% where _n_ is the sample size.  Also, vary the learning rate for each
% ensemble between 0.1 and 1.
% # Cross-validate the ensembles.
% # Estimate the cross-validated mean-squared error (MSE) for each ensemble.
% # Compare the cross-validated MSEs.  The ensemble with the lowest MSE
% performs best, and its settings indicate the optimal maximum number of
% splits, number of trees, and learning rate for the data set.
%
%%
% Grow and cross-validate a deep regression tree and a stump to serve as
% benchmarks.  Specify surrogate splits because the data contain missing
% values.
MdlDeep = fitrtree(X,Y,'CrossVal','on','MergeLeaves','off',...
    'MinParentSize',1,'Surrogate','on');
MdlStump = fitrtree(X,Y,'MaxNumSplits',1,'CrossVal','on','Surrogate','on');
%% 
% Train the boosting ensembles using 150 regression trees.  Cross-validate
% each ensemble using 5-fold cross-validation.  Vary the maximum number of
% splits using the values in the sequence $\{2^0, 2^1, \ldots, 2^m\}$, where
% _m_ is such that $2^m$ is no greater than _n_ - 1.  For each variant,
% adjust the learning rate to each value in the set $\{0.1, 0.25, 0.5, 1\}$.
n = size(X,1);
m = floor(log2(n - 1));
lr = [0.1 0.25 0.5 1];
maxNumSplits = 2.^(0:m);
numTrees = 150;
Mdl = cell(numel(maxNumSplits),numel(lr));
rng(1); % For reproducibility
for k = 1:numel(lr)
    for j = 1:numel(maxNumSplits)
        t = templateTree('MaxNumSplits',maxNumSplits(j),'Surrogate','on');
        Mdl{j,k} = fitensemble(X,Y,'LSBoost',numTrees,t,...
            'Type','regression','KFold',5,'LearnRate',lr(k));
    end
end
%%
% Compute the cross-validated MSE for each ensemble.
kflAll = @(x)kfoldLoss(x,'Mode','cumulative');
errorCell = cellfun(kflAll,Mdl,'Uniform',false);
error = reshape(cell2mat(errorCell),[numTrees numel(maxNumSplits) numel(lr)]);
errorDeep = kfoldLoss(MdlDeep);
errorStump = kfoldLoss(MdlStump);
%%
% Plot how the cross-validated MSE behaves as the number of trees in the
% ensemble increases, and compare against the deep tree and the stump.
% Choose a subset of the tree-complexity levels, and for each one, plot
% the curves for all learning rates on the same axes.
mnsPlot = [1 round(numel(maxNumSplits)/2) numel(maxNumSplits)];
figure;
for k = 1:3
    subplot(2,2,k);
    plot(squeeze(error(:,mnsPlot(k),:)),'LineWidth',2);
    axis tight;
    hold on;
    h = gca;
    plot(h.XLim,[errorDeep errorDeep],'-.b','LineWidth',2);
    plot(h.XLim,[errorStump errorStump],'-.r','LineWidth',2);
    plot(h.XLim,min(min(error(:,mnsPlot(k),:))).*[1 1],'--k');
    h.YLim = [10 50];    
    xlabel 'Number of trees';
    ylabel 'Cross-validated MSE';
    title(sprintf('MaxNumSplits = %0.3g', maxNumSplits(mnsPlot(k))));
    hold off;
end
hL = legend([cellstr(num2str(lr','Learning Rate = %0.2f'));...
        'Deep Tree';'Stump';'Min. MSE']);
hL.Position(1) = 0.6;  
%%
% Each curve contains a minimum cross-validated MSE that occurs at the
% optimal number of trees in the ensemble.
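%
% As a minimal sketch (using the |error| array computed above), you can
% extract each curve's minimum MSE and the number of trees at which it
% occurs:
[minMSEPerCurve,treesPerCurve] = min(error,[],1); % minimize over tree count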
%%
% Identify the maximum number of splits, number of trees, and learning rate
% that yield the lowest MSE overall.
[minErr,minErrIdxLin] = min(error(:));
[idxNumTrees,idxMNS,idxLR] = ind2sub(size(error),minErrIdxLin);

fprintf('\nMin. MSE = %0.5f',minErr)
fprintf('\nOptimal Parameter Values:\nNum. Trees = %d',idxNumTrees)
fprintf('\nMaxNumSplits = %d\nLearning Rate = %0.2f\n',...
    maxNumSplits(idxMNS),lr(idxLR))
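%%
% As a minimal sketch of how you can use these values (the variable names
% below are illustrative), retrain a final ensemble on all of the data
% with the identified settings:
tFinal = templateTree('MaxNumSplits',maxNumSplits(idxMNS),'Surrogate','on');
MdlFinal = fitensemble(X,Y,'LSBoost',idxNumTrees,tFinal,...
    'Type','regression','LearnRate',lr(idxLR));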
%%
% For a different approach to optimizing this ensemble, see
% <docid:stats_ug.bvdx7il-1>.
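%%
% As a minimal sketch of that alternative (assuming the
% |'OptimizeHyperparameters'| option of |fitrensemble|; see the linked
% example for the full workflow), the following invokes Bayesian
% optimization over comparable settings and can take some time to run:
MdlOpt = fitrensemble(X,Y,'OptimizeHyperparameters',...
    {'NumLearningCycles','LearnRate','MaxNumSplits'});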