
%% Train a Nearest Neighbors Algorithm
% This example shows how to use the _k_-nearest neighbors algorithm as a
% binary classifier. For illustration, consider R.A. Fisher's famous iris
% data set.  The input matrix contains four features measured on 150
% irises (50 from each of three species), and the response is the
% corresponding iris species.  Suppose that you want to train a nearest
% neighbors algorithm to predict an iris species given a new set of
% features.
%%
% Load the iris data set.  There are three iris species and four features
% (the lengths and widths of the sepals and petals).  For illustration,
% remove the species _L. setosa_ and the sepal lengths and widths, and do
% not partition the data.

% Copyright 2015 The MathWorks, Inc.

load fisheriris
X = meas(~strcmp(species,'setosa'),3:4);
species = nominal(species(~strcmp(species,'setosa')));
n = length(species);
%%
% Create a binary variable |y|: |y(i) = 1| if sampled iris _i_ is _L.
% virginica,_ and |y(i) = 0| otherwise.
y = zeros(n,1);
y(species == 'virginica') = 1;
%%
% For illustration, choose _k_ = 8, and find the inputs closest to the
% third sample iris.  Plot the data and circle the 8 nearest neighbors of
% the third iris.
iris = 3;
idx = knnsearch(X,X(iris,:),'k',8);

figure;
gscatter(X(:,1),X(:,2),species,[1,0,0;0,0.5,1]);
hold on
plot(X(iris,1),X(iris,2),'kx','MarkerSize',15,'LineWidth',2);
plot(X(idx,1),X(idx,2),'ko','MarkerSize',10);
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 8-Nearest Neighbors of Iris 3}')
legend('{\it L. versicolor}','{\it L. virginica}','Third Iris',...
    'Nearest Neighbors','Location','Best')
hold off
%%
% In the plot, the cross corresponds to the third iris, and the circled
% points correspond to the third iris's nearest neighbors.  There are 7
% (rather than 8) circled points because two of the irises have the same
% petal length and width, so their markers overlap.
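%%
% To confirm the overlap (a quick check beyond the original example),
% compare the neighbor coordinates with their unique rows.  |unique|
% should return only 7 distinct points.
X(idx,:)                 % coordinates of the 8 nearest neighbors
unique(X(idx,:),'rows')  % distinct coordinate pairs among them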
%%
% Find the average of the responses corresponding to the 8 nearest
% neighbors of the third iris.
y3Bar = mean(y(idx))
%%
% |y3Bar = 0.25| indicates that 2 of the 8 irises in the vicinity of the
% third iris (in terms of petal dimensions) are _L. virginica._  This
% result also represents the posterior probability that the third iris is
% an _L. virginica._
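%%
% As a quick check (beyond the original example), count the _L. virginica_
% irises among the 8 neighbors directly; given |y3Bar| = 0.25, this count
% is 2.
nVirginica = sum(species(idx) == 'virginica')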
%%
% The _k_-nearest neighbors algorithm calculates this mean response for
% each iris.  |fitcknn| in Statistics Toolbox(TM) performs _k_-nearest
% neighbors classification automatically, and supports more than two
% classes.
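%%
% For instance, here is a minimal sketch of a three-class fit on the full
% iris data set (reloading into a structure so that the variables above
% are not overwritten):
S = load('fisheriris');
Mdl3 = fitcknn(S.meas(:,3:4),S.species,'NumNeighbors',8)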
%%
% Train the _k_-nearest neighbors algorithm using the data.  It is good
% practice to standardize the predictor data.
EstMdl = fitcknn(X,species,'NumNeighbors',8,'Standardize',1)
%%
% |EstMdl| is the trained _k_-nearest neighbors algorithm.  The Command
% Window displays some of its properties for verification.
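%%
% You can also query individual properties of the trained model, for
% example:
EstMdl.NumNeighbors   % number of neighbors, k
EstMdl.Distance       % distance metric
resubLoss(EstMdl)     % resubstitution misclassification rate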
%%
% Use the trained algorithm |EstMdl| to predict the proportion of _L.
% virginica_ irises neighboring the third sampled iris.
[~,yHat3] = predict(EstMdl,X(3,:))
%%
% |yHat3| indicates that the posterior probability that the third iris is
% an _L. versicolor_ is 0.75, and the posterior probability that it is an
% _L. virginica_ is 0.25.  This result matches the hand calculation above.
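%%
% The columns of the posterior probability output follow the order of the
% class names stored in the model:
EstMdl.ClassNames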
%%
% Plot the decision boundary, the line that distinguishes between the two
% iris species based on their features.
x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)];
pred = predict(EstMdl,XGrid); % Predicted irises

figure
gscatter(XGrid(:,1),XGrid(:,2),pred,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 8-Nearest Neighbors Algorithm Decision Boundary}')
legend('{\it L. versicolor} Region','{\it L. virginica} Region',...
    'Sampled {\it L. versicolor}','Sampled {\it L. virginica}','Location','Best')
axis tight
hold off
%%
% The partition between the red and blue regions is the decision boundary.
% If you change the number of neighbors, _k,_ then the boundary changes.
%%
% Retrain the algorithm using _k_ = 1 and _k_ = 20.
EstMdl1 = fitcknn(X,species); % k = 1 by default
pred1 = predict(EstMdl1,XGrid);

EstMdl20 = fitcknn(X,species,...
    'NumNeighbors',20); % k = 20
pred20 = predict(EstMdl20,XGrid);

figure
subplot(2,2,1)
gscatter(XGrid(:,1),XGrid(:,2),pred1,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 1-Nearest Neighbors}')
legend('{\it L. versicolor} Region','{\it L. virginica} Region',...
    'Sampled {\it L. versicolor}',...
    'Sampled {\it L. virginica}','Location','EastOutside')
axis tight
hold off
subplot(2,2,4)
gscatter(XGrid(:,1),XGrid(:,2),pred20,[1,0,0;0,0.5,1]);
hold on
plot(X(y == 0,1),X(y == 0,2),'ko',X(y == 1,1),X(y == 1,2),...
    'kx','MarkerSize',8,'LineWidth',1.5)
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 20-Nearest Neighbors}')
legend('hide')
axis tight
hold off
%%
% The decision boundary appears to smooth out and linearize as _k_
% increases.  This is because, with larger _k_, each prediction averages
% over more neighbors, so any single observation has less influence.  When
% _k_ = 1, notice that the algorithm correctly predicts the species of
% almost all of the irises.  When _k_ = 20, the algorithm does not do as
% well: it has a higher misclassification rate on the training data.
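%%
% To quantify this comparison (a quick check beyond the original example),
% compare the resubstitution misclassification rates of the two models:
resubLoss(EstMdl1)    % k = 1
resubLoss(EstMdl20)   % k = 20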
%%
% When you train a nearest neighbors algorithm, be sure to partition the
% data into training and validation sets.  Check that the
% misclassification rates for both sets are satisfactory.  You might have
% to adjust _k_ to find the right model.
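%%
% For example, here is a minimal sketch of such a check using 10-fold
% cross-validation (the default partition for |crossval|):
CVMdl = crossval(EstMdl);   % 10-fold cross-validated k-NN model
kfoldLoss(CVMdl)            % average out-of-fold misclassification rate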