%% Optimize a Cross-Validated SVM Classifier Using Bayesian Optimization
% This example shows how to optimize an SVM classifier. The
% classification works on locations of points from a Gaussian mixture
% model. Hastie, Tibshirani, and Friedman (2009) describe the model in
% _The Elements of Statistical Learning_, page 17. The model begins with
% generating 10 base points for a "green" class, distributed as 2-D
% independent normals with mean (1,0) and unit variance. It also generates
% 10 base points for a "red" class, distributed as 2-D independent normals
% with mean (0,1) and unit variance. For each class (green and red),
% generate 100 random points as follows:
%
% # Choose a base point _m_ of the appropriate color uniformly at random.
% # Generate an independent random point with 2-D normal distribution
% with mean _m_ and variance I/5, where I is the 2-by-2 identity matrix. In
% this example, use a variance I/50 to show the advantage of optimization
% more clearly.
%
% After generating 100 green and 100 red points, classify them using
% |fitcsvm|. Then use |bayesopt| to optimize the parameters of the resulting
% SVM model with respect to cross-validation.
%% Generate the Points and Classifier
%
% Generate the 10 base points for each class.
rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);
%%
% View the base points.
plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off
%%
% Since some red base points are close to green base points, it can be
% difficult to classify the data points based on location alone.
%%
% Generate the 100 data points of each class.
redpts = zeros(100,2);
grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end
%%
% View the data points.
figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off
%% Prepare Data for Classification
% Put the data into one matrix, and make a vector |grp| that labels the
% class of each point.
cdata = [grnpts;redpts];
grp = ones(200,1); % Green label 1, red label -1
grp(101:200) = -1;
%% Prepare Cross-Validation
% Set up a partition for cross-validation. This step fixes the training and
% test sets that the optimization uses at each step.
c = cvpartition(200,'KFold',10);
%% Prepare Variables for Bayesian Optimization
% Set up a function that takes an input |z = [rbf_sigma,boxconstraint]|
% and returns the cross-validation loss value of |z|. Take the components
% of |z| as positive, log-transformed variables between |1e-5| and |1e5|.
% Choose a wide range, because you don't know in advance which values are
% likely to be good.
sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');
%% Objective Function
% This function handle computes the cross-validation loss at parameters
% |[sigma,box]|. For details, see <docid:stats_ug.bsu1r2a-1>.
%
% |bayesopt| passes the variable |z| to the objective function as a one-row
% table.
minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));
%% Optimize Classifier
% Search for the best parameters |[sigma,box]| using |bayesopt|. For
% reproducibility, choose the |'expected-improvement-plus'| acquisition
% function. The default acquisition function depends on run time, and so
% can give varying results.
results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
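%%
% Optionally, inspect the optimization result before retraining. This short
% check is an addition to the original example; it uses only documented
% members of the |BayesianOptimization| object that |bayesopt| returns: the
% |MinObjective| property and the |bestPoint| function.
results.MinObjective % smallest observed cross-validation loss
bestPoint(results)   % best feasible point, as a one-row table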
%%
% Use the results to train a new, optimized SVM classifier.
z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));
%%
% Plot the classification boundaries. To visualize the support vector
% classifier, predict scores over a grid.
d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off
%% Evaluate Accuracy on New Data
% Generate and classify some new data points.
grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1); % green = 1
grpData(11:20) = -1;  % red = -1

v = predict(SVMModel,newData);

h = nan(7,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off
%%
% See which new data points are correctly classified. Mark the correctly
% classified points with red squares and the misclassified points with
% black squares.
mydiff = (v == grpData); % Classified correctly

figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = find(mydiff)' % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = find(~mydiff)' % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off
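%%
% To summarize the accuracy numerically, note that |mydiff| already flags
% the correctly classified new points, so its mean is the fraction
% classified correctly. This summary step is an addition to the original
% example.
accuracy = mean(mydiff) % fraction of the 20 new points classified correctly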