%% Test Ensemble Quality
% This example uses a bagged classification ensemble, so all three methods
% of evaluating ensemble quality apply: an independent test set,
% cross-validation, and out-of-bag estimates.
%
%% 
% Generate an artificial dataset of 2000 observations with 20 predictors.
% Each entry is a uniform random number on $[0,1]$. The initial
% classification is $Y = 1$ if $X_1 + X_2 + X_3 + X_4 + X_5 > 2.5$ and
% $Y = 0$ otherwise.

% Copyright 2015 The MathWorks, Inc.

rng(1,'twister') % for reproducibility
X = rand(2000,20);
Y = sum(X(:,1:5),2) > 2.5;
%%
% To add noise to the responses, randomly flip 10% of the classifications
% (200 of the 2000 labels):
idx = randsample(2000,200);
Y(idx) = ~Y(idx);
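%%
% As a quick sanity check, the fraction of positive labels should remain
% close to 0.5 after the flips: the threshold rule splits the data roughly
% in half by symmetry, and random flips preserve that balance on average.
fprintf('Fraction of positive labels: %.3f\n',mean(Y))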
%% Independent Test Set
% Create independent training and test sets of data. Reserve 30% of the
% data for testing, leaving 70% for training, by calling |cvpartition|
% with the |holdout| option:
cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);
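%%
% The |cvpartition| object records the split sizes; with a 0.3 holdout on
% 2000 observations, expect about 1400 training and 600 test rows:
fprintf('Training size: %d, test size: %d\n',cvpart.TrainSize,cvpart.TestSize)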
%%
% Create a bagged classification ensemble of 200 trees from the training
% data:
bag = fitensemble(Xtrain,Ytrain,'Bag',200,'Tree',...
    'Type','Classification')
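%%
% The fitted ensemble can classify new observations with |predict|. As a
% small illustrative check, compare predictions on a few test rows against
% the true labels:
Ypred = predict(bag,Xtest(1:5,:));
disp([double(Ypred) double(Ytest(1:5))]) % predicted vs. true labels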
%%
% Plot the misclassification loss on the test data as a function of the
% number of trained trees in the ensemble:
figure;
plot(loss(bag,Xtest,Ytest,'mode','cumulative'));
xlabel('Number of trees');
ylabel('Test classification error');
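%%
% The curve typically flattens as trees accumulate. For a single summary
% number, evaluate the loss with the full 200-tree ensemble:
testLoss = loss(bag,Xtest,Ytest)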
%% Cross Validation
% Generate a five-fold cross-validated bagged ensemble:
cv = fitensemble(X,Y,'Bag',200,'Tree',...
    'Type','Classification','KFold',5)
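%%
% Called without options, |kfoldLoss| averages the misclassification rate
% over the five folds using all trees:
cvLoss = kfoldLoss(cv)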
%% 
% Examine the cross-validation loss as a function of the number of trees in
% the ensemble:
figure;
plot(loss(bag,Xtest,Ytest,'mode','cumulative'));
hold on;
plot(kfoldLoss(cv,'mode','cumulative'),'r.');
hold off;
xlabel('Number of trees');
ylabel('Classification error');
legend('Test','Cross-validation','Location','NE');
%%
% Cross-validation gives estimates comparable to those of the independent
% test set.
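%%
% To compare the two estimates numerically rather than visually, display
% the loss values computed earlier for the full ensemble:
fprintf('Test loss: %.4f, cross-validation loss: %.4f\n',testLoss,cvLoss)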
%% Out-of-Bag Estimates
% Generate the loss curve for out-of-bag estimates, and plot it along with
% the other curves:
figure;
plot(loss(bag,Xtest,Ytest,'mode','cumulative'));
hold on;
plot(kfoldLoss(cv,'mode','cumulative'),'r.');
plot(oobLoss(bag,'mode','cumulative'),'k--');
hold off;
xlabel('Number of trees');
ylabel('Classification error');
legend('Test','Cross-validation','Out of bag','Location','NE');
%%
% The out-of-bag estimates are again comparable to those of the other
% methods.
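%%
% For a final numeric comparison of all three estimates with the full
% ensemble, add the out-of-bag loss:
fprintf('Test: %.4f, cross-validation: %.4f, out-of-bag: %.4f\n',...
    testLoss,cvLoss,oobLoss(bag))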