% stats/CrossValidateLinearClassificationModel1Example.m
%% Find Good Lasso Penalty Using Cross-Validation
% To determine a good lasso-penalty strength for a linear classification
% model that uses a logistic regression learner, implement 5-fold
% cross-validation.
%
%%
% Load the NLP data set.

load nlpdata
%%
% |X| is a sparse matrix of predictor data, and |Y| is a categorical vector
% of class labels. There are more than two classes in the data.
%%
% The models should identify whether the word counts in a web page are
% from the Statistics and Machine Learning Toolbox(TM) documentation. So,
% identify the labels that correspond to the Statistics and Machine
% Learning Toolbox(TM) documentation web pages.

Ystats = Y == 'stats';
%%
% Create a set of 11 logarithmically spaced regularization strengths from
% $10^{-6}$ through $10^{-0.5}$.

Lambda = logspace(-6,-0.5,11);
%%
% Cross-validate the models. To increase execution speed, transpose the
% predictor data and specify that the observations are in columns. Estimate
% the coefficients using SpaRSA. Lower the tolerance on the gradient of
% the objective function to |1e-8|.

X = X';
rng(10); % For reproducibility
CVMdl = fitclinear(X,Ystats,'ObservationsIn','columns','KFold',5,...
    'Learner','logistic','Solver','sparsa','Regularization','lasso',...
    'Lambda',Lambda,'GradientTolerance',1e-8)

numCLModels = numel(CVMdl.Trained)
%%
% |CVMdl| is a |ClassificationPartitionedLinear| model. Because |fitclinear|
% implements 5-fold cross-validation, |CVMdl| contains 5
% |ClassificationLinear| models that the software trains on each fold.
%%
% Display the first trained linear classification model.

Mdl1 = CVMdl.Trained{1}
%%
% |Mdl1| is a |ClassificationLinear| model object. |fitclinear| constructed
% |Mdl1| by training on the first four folds. Because |Lambda| is a
% sequence of regularization strengths, you can think of |Mdl1| as 11
% models, one for each regularization strength in |Lambda|.
%%
% Estimate the cross-validated classification error.
ce = kfoldLoss(CVMdl);
%%
% Because there are 11 regularization strengths, |ce| is a 1-by-11 vector
% of classification error rates.
%%
% Higher values of |Lambda| lead to predictor variable sparsity, which is a
% good quality of a classifier. For each regularization strength, train a
% linear classification model using the entire data set and the same
% options as when you cross-validated the models. Determine the number of
% nonzero coefficients per model.

Mdl = fitclinear(X,Ystats,'ObservationsIn','columns',...
    'Learner','logistic','Solver','sparsa','Regularization','lasso',...
    'Lambda',Lambda,'GradientTolerance',1e-8);
numNZCoeff = sum(Mdl.Beta~=0);
%%
% In the same figure, plot the cross-validated classification error rates
% and frequency of nonzero coefficients for each regularization strength.
% Plot all variables on the log scale.

figure;
[h,hL1,hL2] = plotyy(log10(Lambda),log10(ce),...
    log10(Lambda),log10(numNZCoeff));
hL1.Marker = 'o';
hL2.Marker = 'o';
ylabel(h(1),'log_{10} classification error')
ylabel(h(2),'log_{10} nonzero-coefficient frequency')
xlabel('log_{10} Lambda')
title('Test-Sample Statistics')
hold off
%%
% Choose the index of the regularization strength that balances predictor
% variable sparsity and low classification error. In this case, a value
% between $10^{-4}$ and $10^{-1}$ should suffice.

idxFinal = 7;
%%
% Select the model from |Mdl| with the chosen regularization strength.

MdlFinal = selectModels(Mdl,idxFinal);
%%
% |MdlFinal| is a |ClassificationLinear| model containing one
% regularization strength. To estimate labels for new observations, pass
% |MdlFinal| and the new data to |predict|.
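%%
% For example, as a quick sketch of prediction, classify the first five
% observations in |X| as stand-ins for new data. (Recall that |X| was
% transposed earlier, so observations are in columns; actual new data would
% need the same orientation and the same predictors.)

[labels,scores] = predict(MdlFinal,X(:,1:5),'ObservationsIn','columns')
%%
% |labels| contains the predicted class (|true| for pages from the
% Statistics and Machine Learning Toolbox(TM) documentation), and |scores|
% contains the posterior class probabilities from the logistic model.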