%% Find Good Lasso Penalty Using Margins
% To determine a good lasso-penalty strength for a linear classification
% model that uses a logistic regression learner, compare distributions of
% test-sample margins.
%
%%
% Load the NLP data set. Preprocess the data as in
% <docid:stats_ug.bu620gc>.

load nlpdata
Ystats = Y == 'stats';
X = X';

Partition = cvpartition(Ystats,'Holdout',0.30);
testIdx = test(Partition);
XTest = X(:,testIdx);
YTest = Ystats(testIdx);
%%
% Create a set of 11 logarithmically spaced regularization strengths from
% $10^{-8}$ through $10^{1}$.

Lambda = logspace(-8,1,11);
%%
% Train binary, linear classification models that use each of the
% regularization strengths. Solve the objective function using SpaRSA.
% Lower the tolerance on the gradient of the objective function to |1e-8|.

rng(10); % For reproducibility
CVMdl = fitclinear(X,Ystats,'ObservationsIn','columns',...
    'CVPartition',Partition,'Learner','logistic','Solver','sparsa',...
    'Regularization','lasso','Lambda',Lambda,'GradientTolerance',1e-8)
%%
% Extract the trained linear classification model.

Mdl = CVMdl.Trained{1}
%%
% |Mdl| is a |ClassificationLinear| model object. Because |Lambda| is a
% sequence of regularization strengths, you can think of |Mdl| as 11
% models, one for each regularization strength in |Lambda|.
%%
% Estimate the test-sample margins.

m = margin(Mdl,XTest,YTest,'ObservationsIn','columns');
size(m)
%%
% Because there are 11 regularization strengths, |m| has 11 columns.
%%
% Plot the test-sample margins for each regularization strength. Because
% logistic regression scores are in [0,1], margins are in [-1,1]. Rescale
% the margins to help identify the regularization strength that maximizes
% the margins over the grid.

figure;
boxplot(10000.^m)
ylabel('Exponentiated test-sample margins')
xlabel('Lambda indices')
%%
% Several values of |Lambda| yield margin distributions that are compacted
% near $10000^1$.
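%%
% As a quick check on this (a sketch: it assumes the fitted coefficients
% are available in |Mdl.Beta|, a p-by-11 matrix with one column per value
% of |Lambda|), count the nonzero coefficients in each of the 11 models.

numNZCoeff = sum(Mdl.Beta ~= 0); % Nonzero coefficients per regularization strength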
% Higher values of |Lambda| lead to predictor variable sparsity, which is
% a good quality of a classifier.
%%
% Choose the regularization strength that occurs just before the centers
% of the margin distributions start decreasing.

LambdaFinal = Lambda(5);
%%
% Train a linear classification model using the entire data set and
% specify the chosen regularization strength.

MdlFinal = fitclinear(X,Ystats,'ObservationsIn','columns',...
    'Learner','logistic','Solver','sparsa','Regularization','lasso',...
    'Lambda',LambdaFinal);
%%
% To estimate labels for new observations, pass |MdlFinal| and the new data
% to |predict|.
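%%
% For example (a sketch that simply reuses the held-out observations
% defined above in place of genuinely new data), predict labels and
% classification scores for the test-set columns of |X|.

[labels,scores] = predict(MdlFinal,XTest,'ObservationsIn','columns');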