% Source: www.gusucode.com > stats source code (MATLAB example code) > stats/ExamineTheMSEForEachSubTreeExample.m

%% Examine the MSE for Each Subtree
%%
% Unpruned decision trees tend to overfit.  One way to balance model
% complexity and out-of-sample performance is to prune a tree (or restrict
% its growth) so that in-sample and out-of-sample performance are
% satisfactory.
%%
% Load the |carsmall| data set.  Consider |Displacement|,
% |Horsepower|, and |Weight| as predictors of the response |MPG|.

% Copyright 2015 The MathWorks, Inc.

% Bring the carsmall variables into the workspace, then assemble the
% predictor matrix (one column per predictor) and the response vector.
load('carsmall');
X = [Displacement, Horsepower, Weight];   % predictors
Y = MPG;                                  % response
%%
% Partition the data into training (50%) and validation (50%) sets.
n = size(X, 1);                             % number of observations
rng(1) % Fix the random seed so the split is reproducible
idxTrn = false(n, 1);
idxTrn(randsample(n, round(n/2))) = true;   % logical mask of training rows
idxVal = ~idxTrn;                           % validation rows: every row not used for training
%%
% Grow a regression tree using the training set.
Mdl = fitrtree(X(idxTrn,:), Y(idxTrn));   % fit using only the training rows
%%
% Display the fitted regression tree as a graph.
view(Mdl, 'Mode', 'graph');
%%
% The regression tree has eight pruning levels, numbered 0 through 7.
% Level 0 is the full, unpruned tree (as displayed). Level 7 is just the
% root node (i.e., no splits).
%%
% Examine the training-sample MSE for each subtree (or pruning level),
% excluding the highest level (the root-only tree).
% The deepest pruning level is max(Mdl.PruneList); stop one level short of it
% so that the root-only tree is left out of the comparison.
m = max(Mdl.PruneList) - 1;
% No trailing semicolon: print the resubstitution MSE for levels 0..m.
trnLoss = resubLoss(Mdl, 'SubTrees', 0:m)
%%
%
% * The MSE for the full, unpruned tree is about 6 units.
% * The MSE for the tree pruned to level 1 is about 6.3 units.
% * The MSE for the tree pruned to level 6 (i.e., a stump) is about 14.8
% units.
%
%%
% Examine the validation-sample MSE at each level, excluding the highest
% level (the root-only tree).
valLoss = loss(Mdl, X(idxVal,:), Y(idxVal), 'SubTrees', 0:m)   % held-out MSE for levels 0..m (displayed)
%%
%
% * The MSE for the full, unpruned tree (level 0) is about 32.1 units.
% * The MSE for the tree pruned to level 4 is about 26.4 units.
% * The MSE for the tree pruned to level 5 is about 30.0 units.
% * The MSE for the tree pruned to level 6 (i.e., a stump) is about 38.5 
% units.
%
%%
% To balance model complexity and out-of-sample performance, consider
% pruning |Mdl| to level 4, which has the smallest validation MSE among the
% levels listed above.
pruneMdl = prune(Mdl, 'Level', 4);   % level 4 picked from the validation MSE comparison
view(pruneMdl, 'Mode', 'graph')