%% Train a Nearest Neighbors Algorithm
% This example shows how to use the _k_-nearest neighbors algorithm as a
% binary classifier. For illustration, consider R. A. Fisher's famous iris
% data set. The input matrix contains four features measured on 150
% irises, and the response is the corresponding iris species. Suppose that
% you want to train a nearest neighbors algorithm to predict an iris
% species given a new set of features.
%%
% Load the iris data set. There are three iris species and four features
% (the lengths and widths of the sepals and petals). For illustration,
% remove species _L. setosa_ and the sepal lengths and widths, and do not
% partition the data.

% Copyright 2015 The MathWorks, Inc.

load fisheriris
X = meas(~strcmp(species,'setosa'),3:4);
species = nominal(species(~strcmp(species,'setosa')));
n = length(species);
%%
% Create a binary variable |y|. If sampled iris _i_ is _L. virginica,_
% then |y = 1|; otherwise, |y = 0|.

y = zeros(n,1);
y(species == 'virginica') = 1;
%%
% For illustration, choose _k_ = 8, and find the inputs closest to the
% third sampled iris. Plot the data and circle the 8 nearest neighbors of
% the third iris.

iris = 3;
idx = knnsearch(X,X(iris,:),'k',8);

figure;
gscatter(X(:,1),X(:,2),species,[1,0,0;0,0.5,1]);
hold on
plot(X(iris,1),X(iris,2),'kx','MarkerSize',15,'LineWidth',2);
plot(X(idx,1),X(idx,2),'ko','MarkerSize',10);
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 8-Nearest Neighbors of Iris 3}')
legend('{\it L. versicolor}','{\it L. virginica}','Third Iris',...
    'Nearest Neighbors','Location','Best')
hold off
%%
% In the plot, the cross corresponds to the third iris, and the circled
% points correspond to the third iris's nearest neighbors. There are 7
% (rather than 8) circled points because two irises have the same petal
% length and width.
%%
% Find the average of the responses corresponding to the 8 nearest
% neighbors of the third iris.
y3Bar = mean(y(idx))
%%
% |y3Bar = 0.25| indicates that 2 of the 8 irises in the vicinity of the
% third iris (in terms of petal dimensions) are _L. virginica._ This
% result also represents the posterior probability that the third iris is
% an _L. virginica._
%%
% The _k_-nearest neighbors algorithm calculates the mean response for
% each iris. |fitcknn| in Statistics Toolbox(TM) performs _k_-nearest
% neighbors classification automatically, and supports multiple
% categories.
%%
% Train the _k_-nearest neighbors algorithm using the data. It is good
% practice to standardize the predictor data.

EstMdl = fitcknn(X,species,'NumNeighbors',8,'Standardize',1)
%%
% |EstMdl| is the trained _k_-nearest neighbors algorithm. The Command
% Window displays some of its properties for verification.
%%
% Use the trained algorithm |EstMdl| to predict the proportion of _L.
% virginica_ irises neighboring the third sampled iris.

[~,yHat3] = predict(EstMdl,X(3,:))
%%
% |yHat3| indicates that the posterior probability that the third iris is
% an _L. versicolor_ is 0.75, and the posterior probability that it is an
% _L. virginica_ is 0.25. This result matches the hand calculation above.
%%
% Plot the decision boundary, the line that distinguishes between the two
% iris species based on their features.

x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)];
pred = predict(EstMdl,XGrid); % Predicted irises

figure
gscatter(XGrid(:,1),XGrid(:,2),pred,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 8-Nearest Neighbors Algorithm Decision Boundary}')
legend('{\it L. versicolor} Region','{\it L. virginica} Region',...
    'Sampled {\it L. versicolor}','Sampled {\it L. virginica}',...
    'Location','Best')
axis tight
hold off
%%
% The partition between the red and blue regions is the decision boundary.
% If you change the number of neighbors, _k,_ then the boundary changes.
%%
% Retrain the algorithm using _k_ = 1 and _k_ = 20.

EstMdl1 = fitcknn(X,species); % k = 1 by default
pred1 = predict(EstMdl1,XGrid);
EstMdl20 = fitcknn(X,species,...
    'NumNeighbors',20);       % k = 20
pred20 = predict(EstMdl20,XGrid);

figure
subplot(2,2,1)
gscatter(XGrid(:,1),XGrid(:,2),pred1,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 1-Nearest Neighbor}')
legend('{\it L. versicolor} Region','{\it L. virginica} Region',...
    'Sampled {\it L. versicolor}',...
    'Sampled {\it L. virginica}','Location','EastOutside')
axis tight
hold off

subplot(2,2,4)
gscatter(XGrid(:,1),XGrid(:,2),pred20,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 20-Nearest Neighbors}')
legend('hide')
axis tight
hold off
%%
% The decision boundary appears to become more linear as _k_ increases,
% because the algorithm down-weights the influence of any single input as
% _k_ grows. When _k_ = 1, the algorithm correctly predicts the species of
% almost every sampled iris. When _k_ = 20, the algorithm does not fit the
% sample as well; it has a higher misclassification rate on the training
% data.
%%
% When you train a nearest neighbors algorithm, it is good practice to
% partition the data into training and test sets. Check that the
% misclassification rates for both sets are satisfactory. You might have
% to adjust _k_ to find the right algorithm.
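%%
% One way to sketch that practice (this cell is an illustrative addition,
% not part of the original example; |rng(1)| and the 30% holdout fraction
% are arbitrary choices) is to hold out part of the data with
% |cvpartition|, train on the remainder, and compare the
% misclassification rates on the two sets.

rng(1);                                   % For reproducibility
cvp = cvpartition(species,'HoldOut',0.3); % 70% training, 30% test
XTrain = X(training(cvp),:);
sTrain = species(training(cvp));
XTest = X(test(cvp),:);
sTest = species(test(cvp));

MdlHO = fitcknn(XTrain,sTrain,'NumNeighbors',8,'Standardize',1);
trainErr = resubLoss(MdlHO)       % Misclassification rate, training set
testErr = loss(MdlHO,XTest,sTest) % Misclassification rate, test set
%%
% If |testErr| is much larger than |trainErr|, the model is likely
% overfitting; try a larger _k_.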