%% Examine the MSE for Each Subtree

%%
% Unpruned decision trees tend to overfit. One way to balance model
% complexity and out-of-sample performance is to prune a tree (or restrict
% its growth) so that in-sample and out-of-sample performance are
% satisfactory.

%%
% Load the |carsmall| data set. Consider |Displacement|,
% |Horsepower|, and |Weight| as predictors of the response |MPG|.

% Copyright 2015 The MathWorks, Inc.

load carsmall
X = [Displacement Horsepower Weight];
Y = MPG;

%%
% Partition the data into training (50%) and validation (50%) sets.
n = size(X,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false;                  % Validation set logical indices

%%
% Grow a regression tree using the training set.
Mdl = fitrtree(X(idxTrn,:),Y(idxTrn));

%%
% View the regression tree.
view(Mdl,'Mode','graph');

%%
% The regression tree has seven pruning levels. Level 0 is the full,
% unpruned tree (as displayed). Level 7 is just the root node (i.e.,
% no splits).

%%
% Examine the training sample MSE for each subtree (or pruning level)
% excluding the highest level.
m = max(Mdl.PruneList) - 1;
trnLoss = resubLoss(Mdl,'SubTrees',0:m)

%%
%
% * The MSE for the full, unpruned tree is about 6 units.
% * The MSE for the tree pruned to level 1 is about 6.3 units.
% * The MSE for the tree pruned to level 6 (i.e., a stump) is about 14.8
% units.
%

%%
% Examine the validation sample MSE at each level
% excluding the highest level.
valLoss = loss(Mdl,X(idxVal,:),Y(idxVal),'SubTrees',0:m)

%%
%
% * The MSE for the full, unpruned tree (level 0) is about 32.1 units.
% * The MSE for the tree pruned to level 4 is about 26.4 units.
% * The MSE for the tree pruned to level 5 is about 30.0 units.
% * The MSE for the tree pruned to level 6 (i.e., a stump) is about 38.5
% units.
%

%%
% To balance model complexity and out-of-sample performance, consider
% pruning |Mdl| to level 4.
pruneMdl = prune(Mdl,'Level',4);
view(pruneMdl,'Mode','graph')
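
%%
% A minimal sketch, not part of the original example, of one way to pick
% the pruning level programmatically instead of reading it off the printed
% losses. It assumes |m|, |trnLoss|, and |valLoss| are still in the
% workspace from the steps above. The names |idxBest|, |bestLevel|, and
% |minValLoss| are illustrative only.
[minValLoss,idxBest] = min(valLoss); % Position of the smallest validation MSE
bestLevel = idxBest - 1;             % Positions are 1-based; pruning levels start at 0
fprintf('Smallest validation MSE %.1f at pruning level %d\n', ...
    minValLoss,bestLevel)

%%
% Plotting both loss curves against the pruning level can make the
% trade-off between in-sample and out-of-sample error easier to see.
figure
plot(0:m,trnLoss,'-o',0:m,valLoss,'-x')
xlabel('Pruning level')
ylabel('MSE')
legend('Training','Validation','Location','best')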