%% Find the Optimal Number of Splits and Trees for an Ensemble
% You can control the depth of the trees in an ensemble of decision trees.
% You can also control the tree depth in an ECOC model containing decision
% tree binary learners using the |MaxNumSplits|, |MinLeafSize|, or
% |MinParentSize| name-value pair arguments.
%
% * When bagging decision trees, |fitensemble| grows deep decision trees
% by default. You can grow shallower trees to reduce model complexity or
% computation time.
% * When boosting decision trees, |fitensemble| grows stumps (a tree with
% one split) by default. You can grow deeper trees for better accuracy.
%
%%
% Load the |carsmall| data set. Specify the variables |Acceleration|,
% |Displacement|, |Horsepower|, and |Weight| as predictors, and |MPG| as
% the response.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;
%%
% The default values of the tree depth controllers for boosting
% regression trees are:
%
% * |1| for |MaxNumSplits|. This option grows stumps.
% * |5| for |MinLeafSize|
% * |10| for |MinParentSize|
%
%%
% To search for the optimal number of splits:
%
% # Train a set of ensembles. Exponentially increase the maximum number of
% splits for subsequent ensembles from stump to at most _n_ - 1 splits.
% Also, vary the learning rate for each ensemble over the values 0.1
% through 1.
% # Cross-validate the ensembles.
% # Estimate the cross-validated mean squared error (MSE) for each
% ensemble.
% # Compare the cross-validated MSEs. The ensemble with the lowest one
% performs best, and indicates the optimal maximum number of splits,
% number of trees, and learning rate for the data set.
%
%%
% Grow and cross-validate a deep regression tree and a stump. Specify to
% use surrogate splits because the data contain missing values. These
% trees serve as benchmarks.

MdlDeep = fitrtree(X,Y,'CrossVal','on','MergeLeaves','off',...
    'MinParentSize',1,'Surrogate','on');
MdlStump = fitrtree(X,Y,'MaxNumSplits',1,'CrossVal','on','Surrogate','on');
%%
% Train the boosting ensembles using 150 regression trees. Cross-validate
% each ensemble using 5-fold cross-validation. Vary the maximum number of
% splits using the values in the sequence $\{2^0, 2^1,...,2^m\}$, where _m_
% is such that $2^m$ is no greater than _n_ - 1. For each variant, set the
% learning rate to each value in the set {0.1, 0.25, 0.5, 1}.

n = size(X,1);
m = floor(log2(n - 1));
lr = [0.1 0.25 0.5 1];
maxNumSplits = 2.^(0:m);
numTrees = 150;

Mdl = cell(numel(maxNumSplits),numel(lr));
rng(1); % For reproducibility
for k = 1:numel(lr)
    for j = 1:numel(maxNumSplits)
        t = templateTree('MaxNumSplits',maxNumSplits(j),'Surrogate','on');
        Mdl{j,k} = fitensemble(X,Y,'LSBoost',numTrees,t,...
            'Type','regression','KFold',5,'LearnRate',lr(k));
    end
end
%%
% Compute the cross-validated MSE for each ensemble.

kflAll = @(x)kfoldLoss(x,'Mode','cumulative');
errorCell = cellfun(kflAll,Mdl,'Uniform',false);
error = reshape(cell2mat(errorCell),...
    [numTrees numel(maxNumSplits) numel(lr)]);
errorDeep = kfoldLoss(MdlDeep);
errorStump = kfoldLoss(MdlStump);
%%
% Plot how the cross-validated MSE behaves as the number of trees in the
% ensemble increases for a few of the ensembles, the deep tree, and the
% stump. Plot the curves with respect to learning rate in the same plot,
% and plot separate plots for varying tree complexities. Choose a subset
% of tree complexity levels.
mnsPlot = [1 round(numel(maxNumSplits)/2) numel(maxNumSplits)];
figure
for k = 1:3
    subplot(2,2,k)
    plot(squeeze(error(:,mnsPlot(k),:)),'LineWidth',2)
    axis tight
    hold on
    h = gca;
    plot(h.XLim,[errorDeep errorDeep],'-.b','LineWidth',2)
    plot(h.XLim,[errorStump errorStump],'-.r','LineWidth',2)
    plot(h.XLim,min(min(error(:,mnsPlot(k),:))).*[1 1],'--k')
    h.YLim = [10 50];
    xlabel 'Number of trees'
    ylabel 'Cross-validated MSE'
    title(sprintf('MaxNumSplits = %0.3g',maxNumSplits(mnsPlot(k))))
    hold off
end
hL = legend([cellstr(num2str(lr','Learning Rate = %0.2f'));...
    'Deep Tree';'Stump';'Min. MSE']);
hL.Position(1) = 0.6;
%%
% Each curve contains a minimum cross-validated MSE occurring at the
% optimal number of trees in the ensemble.
%%
% Identify the maximum number of splits, number of trees, and learning
% rate that yield the lowest MSE overall.

[minErr,minErrIdxLin] = min(error(:));
[idxNumTrees,idxMNS,idxLR] = ind2sub(size(error),minErrIdxLin);

fprintf('\nMin. MSE = %0.5f',minErr)
fprintf('\nOptimal Parameter Values:\nNum. Trees = %d',idxNumTrees)
fprintf('\nMaxNumSplits = %d\nLearning Rate = %0.2f\n',...
    maxNumSplits(idxMNS),lr(idxLR))
%%
% For a different approach to optimizing this ensemble, see
% <docid:stats_ug.bvdx7il-1>.
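%%
% As a minimal sketch of that alternative approach, you can delegate the
% grid search above to built-in Bayesian hyperparameter optimization.
% This sketch assumes |fitrensemble| and its |OptimizeHyperparameters|
% name-value argument are available in your release; the choice of
% parameters to optimize here is illustrative, not part of this example.

% Optimize the number of trees, maximum splits, and learning rate jointly,
% scoring candidates by 5-fold cross-validated loss.
MdlOpt = fitrensemble(X,Y,'Method','LSBoost',...
    'OptimizeHyperparameters',...
    {'NumLearningCycles','MaxNumSplits','LearnRate'},...
    'HyperparameterOptimizationOptions',struct('KFold',5));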