%% Create Predictive Ensemble Using Cross-Validation
% One way to create an ensemble of boosted regression trees that has
% satisfactory predictive performance is to tune the decision
% tree-complexity level using cross-validation. While searching for an
% optimal complexity level, tune the learning rate to minimize the number
% of learning cycles.
%%
% Load the |carsmall| data set.  Choose the number of cylinders, volume
% displaced by the cylinders, horsepower, and weight as predictors of fuel
% economy.
load carsmall
Tbl = table(Cylinders,Displacement,Horsepower,Weight,MPG);
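%%
% Some predictor values in |carsmall| are missing, which is why the trees
% in this example use surrogate splits.  As a quick check, count the
% missing entries in each variable:
sum(ismissing(Tbl))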
%%
% To search for the optimal tree-complexity level:
%
% # Cross-validate a set of ensembles.  Exponentially increase the
% tree-complexity level for subsequent ensembles from a decision stump
% (one split) to at most _n_ - 1 splits, where _n_ is the sample size.
% Also, vary the learning rate for each ensemble between 0.1 and 1.
% # Estimate the cross-validated mean squared error (MSE) of each
% ensemble.
% # For each tree-complexity level $j$, $j = 1,...,J$, compare the
% cumulative, cross-validated MSE of the ensembles by plotting it against
% the number of learning cycles.  Plot separate curves for each learning
% rate on the same figure.
% # Choose the curve that achieves the minimal MSE, and note the
% corresponding number of learning cycles and learning rate.
%
%%
% Cross-validate a deep regression tree and a stump. Because the data
% contain missing values, use surrogate splits. These regression trees
% serve as benchmarks.
rng(1); % For reproducibility
MdlDeep = fitrtree(Tbl,'MPG','CrossVal','on','MergeLeaves','off',...
    'MinParentSize',1,'Surrogate','on');
MdlStump = fitrtree(Tbl,'MPG','MaxNumSplits',1,'CrossVal','on',...
    'Surrogate','on');
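%%
% Each cross-validated model stores its per-fold learners in the |Trained|
% property.  For example, one optional way to inspect the benchmarks is to
% display the first fold's stump as a graph:
view(MdlStump.Trained{1},'Mode','graph')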
%% 
% Cross-validate a set of ensembles, each containing 150 boosted regression
% trees, using 5-fold cross-validation.  Using a tree template:
%
% * Vary the maximum number of splits using the values in the sequence
% $\{2^0, 2^1,...,2^m\}$, where _m_ is such that $2^m$ is no greater than
% _n_ - 1.
% * Turn on surrogate splits.
%
% For each tree-complexity variant, adjust the learning rate using each
% value in the set $\{0.1, 0.25, 0.5, 1\}$.
n = size(Tbl,1);
m = floor(log2(n - 1));
learnRate = [0.1 0.25 0.5 1];
numLR = numel(learnRate);
maxNumSplits = 2.^(0:m);
numMNS = numel(maxNumSplits);
numTrees = 150;
Mdl = cell(numMNS,numLR);

for k = 1:numLR
    for j = 1:numMNS
        t = templateTree('MaxNumSplits',maxNumSplits(j),'Surrogate','on');
        Mdl{j,k} = fitrensemble(Tbl,'MPG','NumLearningCycles',numTrees,...
            'Learners',t,'KFold',5,'LearnRate',learnRate(k));
    end
end
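%%
% Because each call to |fitrensemble| with |'KFold'| draws its own random
% partition, the ensembles in |Mdl| are validated on slightly different
% folds.  If you want every ensemble to share one partition, a sketch of a
% variant is to create a |cvpartition| object before the loop and pass it
% with |'CVPartition'| instead of |'KFold'|:
%
%   cvp = cvpartition(n,'KFold',5);
%   Mdl{j,k} = fitrensemble(Tbl,'MPG','NumLearningCycles',numTrees,...
%       'Learners',t,'CVPartition',cvp,'LearnRate',learnRate(k));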
%%
% Estimate the cumulative, cross-validated MSE of each ensemble.
kflAll = @(x)kfoldLoss(x,'Mode','cumulative');
errorCell = cellfun(kflAll,Mdl,'Uniform',false);
cvError = reshape(cell2mat(errorCell),[numTrees numMNS numLR]); % learning cycles x complexity levels x learning rates
errorDeep = kfoldLoss(MdlDeep);
errorStump = kfoldLoss(MdlStump);
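%%
% For |carsmall| (_n_ = 100, so _m_ = 6), |cvError| is a 150-by-7-by-4
% array.  As a quick check, you can tabulate the final-cycle MSE for every
% combination (the variable name |finalMSE| here is just illustrative):
finalMSE = squeeze(cvError(end,:,:)) % rows index MaxNumSplits, columns index LearnRate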
%%
% Plot how the cross-validated MSE behaves as the number of trees in the
% ensemble increases.  Plot the curves for each learning rate on the same
% axes, and create a separate subplot for each tree-complexity level.  To
% keep the figure legible, choose a subset of the tree-complexity levels to
% plot.
mnsPlot = [1 round(numel(maxNumSplits)/2) numel(maxNumSplits)];
figure;
for k = 1:3
    subplot(2,2,k);
    plot(squeeze(cvError(:,mnsPlot(k),:)),'LineWidth',2);
    axis tight;
    hold on;
    h = gca;
    plot(h.XLim,[errorDeep errorDeep],'-.b','LineWidth',2);
    plot(h.XLim,[errorStump errorStump],'-.r','LineWidth',2);
    plot(h.XLim,min(min(cvError(:,mnsPlot(k),:))).*[1 1],'--k');
    h.YLim = [10 50];
    xlabel 'Number of trees';
    ylabel 'Cross-validated MSE';
    title(sprintf('MaxNumSplits = %0.3g', maxNumSplits(mnsPlot(k))));
    hold off;
end
hL = legend([cellstr(num2str(learnRate','Learning Rate = %0.2f'));...
        'Deep Tree';'Stump';'Min. MSE']);
hL.Position(1) = 0.6;  
%%
% Each curve has a minimum cross-validated MSE, which occurs at the optimal
% number of trees in the ensemble for that combination of learning rate and
% tree-complexity level.
%%
% Identify the maximum number of splits, number of trees, and learning rate
% that yield the lowest MSE overall.
[minErr,minErrIdxLin] = min(cvError(:));
[idxNumTrees,idxMNS,idxLR] = ind2sub(size(cvError),minErrIdxLin);

fprintf('\nMin. MSE = %0.5f',minErr)
fprintf('\nOptimal Parameter Values:\nNum. Trees = %d',idxNumTrees);
fprintf('\nMaxNumSplits = %d\nLearning Rate = %0.2f\n',...
    maxNumSplits(idxMNS),learnRate(idxLR))
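%%
% As a complementary view, minimizing over the number of trees gives the
% best MSE achieved by each tree-complexity/learning-rate pair (the
% variable name |bestPerCombo| is illustrative):
bestPerCombo = squeeze(min(cvError,[],1)) % numMNS-by-numLR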
%%
% Create a predictive ensemble based on the optimal hyperparameters and the
% entire training set.
tFinal = templateTree('MaxNumSplits',maxNumSplits(idxMNS),'Surrogate','on');
MdlFinal = fitrensemble(Tbl,'MPG','NumLearningCycles',idxNumTrees,...
    'Learners',tFinal,'LearnRate',learnRate(idxLR))
%%
% |MdlFinal| is a |RegressionEnsemble|.  To predict the fuel economy of a
% car given its number of cylinders, volume displaced by the cylinders,
% horsepower, and weight, pass the predictor data and |MdlFinal| to
% |predict|.
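%%
% For example, here is a sketch that predicts the response for the first
% five rows of the training table (the extra |MPG| column is ignored at
% prediction time):
predMPG = predict(MdlFinal,Tbl(1:5,:))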