
%% Train a _k_-Means Clustering Algorithm
% This example shows how to use the _k_-means clustering algorithm to find
% a specified number of clusters in a data set.  Consider R.A. Fisher's
% famous iris data set. The input matrix contains four features (petal and
% sepal lengths and widths) measured on 150 irises, and the response is one
% of three iris species. For illustration, ignore the sepal
% lengths and widths.
%%
% Load the data set.
load fisheriris
X = meas(:,3:4); % Use only the petal lengths and widths
%%
% Plot the petal lengths and widths.  Inspect the plot to determine the
% number of clusters in the data set.
figure;
plot(X(:,1),X(:,2),'.');
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
%%
% The plot suggests that there might be 2 or 3 clusters.  One cluster
% contains irises that have markedly lower petal lengths and widths.
% Another cluster seems to have petal lengths and widths in the middle of
% the ranges.  There is possibly one more cluster having higher lengths and
% widths than all other irises, and also higher variability.
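%%
% As a rough numerical check on this visual impression (a sketch added
% here, not part of the original example), compare silhouette criterion
% values across candidate numbers of clusters using |evalclusters|
% (available in Statistics Toolbox R2013b and later).
eva = evalclusters(X,'kmeans','silhouette','KList',2:4);
eva.OptimalK % Suggested number of clusters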
%%
% To illustrate how _k_-means clustering works, train the algorithm by:
%
% * Passing in the data
% * Setting the number of clusters to 2
% * Setting the maximum number of iterations to 1
%
% Setting |'MaxIter'| to 1 makes each call to |kmeans| perform a single
% iteration, so the loop below applies the algorithm one iteration at a
% time until the centroids stop moving. Plot the centroids after each
% iteration.
k = 2;               % Number of clusters
rng(1);              % For reproducibility
Cinit = X(randsample(size(X,1),k),:); % Initial centroid locations

h = gobjects(1,4); % Preallocate graphics handles
figure;
h(1) = plot(X(:,1),X(:,2),'.');
hold on;
h(2) = plot(Cinit(:,1),Cinit(:,2),'bo','MarkerSize',15,'LineWidth',2);
epsilon = Inf;       % Squared centroid movement (convergence criterion)
while epsilon > 0.01
    warning('off','all') % Suppress the convergence warning from MaxIter = 1
    % One iteration of k-means clustering
    [Idx2,C2] = kmeans(X,2,'MaxIter',1,'Start',Cinit);
    warning('on','all')  % Reenable warnings
    h(3) = plot(C2(:,1),C2(:,2),'go','MarkerSize',15,'LineWidth',2);
    for jj = 1:k
        line([Cinit(jj,1);C2(jj,1)],[Cinit(jj,2);C2(jj,2)],'Color','k',...
            'LineWidth',2)
    end
    epsilon = sum(sum((Cinit - C2).^2)); % Total squared centroid movement
    Cinit = C2;
end
h(4) = plot(C2(:,1),C2(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Data','Start Centroid','Transitional Centroid',...
    'Final Centroid','Location','SouthEast')
hold off;
C2
%%
% |Idx2| is an _n_-by-1 vector of cluster assignments corresponding to the
% observations.  The rows of |C2| are the coordinates of the final
% centroids. The final centroids appear in the center of both groups of
% data.
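%%
% For comparison, one _k_-means update can be written out by hand (an
% illustrative sketch using |pdist2|; |assign| and |Cmanual| are names
% introduced here): assign each point to its nearest centroid, then move
% each centroid to the mean of its assigned points. Starting from the
% converged |C2|, the centroids should barely move.
[~,assign] = min(pdist2(X,C2),[],2);                       % Nearest-centroid assignment
Cmanual = [mean(X(assign == 1,:)); mean(X(assign == 2,:))] % Should be close to C2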
%%
% Color the cluster regions. For illustration, plot the data,
% distinguishing between the iris species (even though the species labels
% are unknown in the context of unsupervised learning).
h = gobjects(1,6); % Preallocate graphics handles
x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)]; % Defines a fine grid on the plot
warning('off','all') % Suppress the convergence warning from MaxIter = 1
% Assign each node in the grid to the closest centroid
idx2Region = kmeans(XGrid,2,'MaxIter',1,'Start',C2);
warning('on','all')
figure;
h(1:2) = gscatter(XGrid(:,1),XGrid(:,2),idx2Region,...
    [0,0.75,0.75;0.75,0,0.75],'..');
hold on;
h(3:5) = gscatter(X(:,1),X(:,2),species,zeros(3,3),'xd+',8);
h(6) = plot(C2(:,1),C2(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Cluster 1 Region','Cluster 2 Region',...
    '{\it L. setosa}','{\it L. versicolor}','{\it L. virginica}',...
    'Centroid','Location','SouthEast')
hold off;
%%
% The algorithm misclassifies one _L. versicolor_ iris and all of the
% _L. virginica_ irises.  The number of clusters (_k_) is a _nuisance
% parameter_, and it is good practice to try several values of _k_ before
% deciding which one is best.
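%%
% One common heuristic for choosing _k_ (a minimal sketch added here;
% |wcss| and |kk| are names introduced for illustration) is to compare the
% total within-cluster sum of squares, available as the third output of
% |kmeans|, across several values of _k_ and look for an "elbow" where the
% improvement levels off.
wcss = zeros(1,4);
for kk = 1:4
    [~,~,sumd] = kmeans(X,kk,'Replicates',5);
    wcss(kk) = sum(sumd); % Total within-cluster sum of squares
end
wcss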
%%
% Train the _k_-means clustering algorithm again, but set the number of
% clusters to 3. Plot the data, the centroids, and the cluster regions.
k = 3;
[Idx3,C3] = kmeans(X,k); % Assignments and centroids for k = 3

h = gobjects(1,7); % Preallocate graphics handles (seven are used below)
x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)];
warning('off','all') % Suppress the convergence warning from MaxIter = 1
% Assign each node in the grid to the closest centroid
idx3Region = kmeans(XGrid,3,'MaxIter',1,'Start',C3);
warning('on','all')

figure;
h(1:3) = gscatter(XGrid(:,1),XGrid(:,2),idx3Region,...
    0.75*[0,1,1;1,0,1;1,1,0],'...');
hold on;
h(4:6) = gscatter(X(:,1),X(:,2),species,zeros(3,3),'xd+',8);
h(7) = plot(C3(:,1),C3(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Cluster 1 Region','Cluster 2 Region','Cluster 3 Region',...
    '{\it L. setosa}','{\it L. versicolor}','{\it L. virginica}',...
    'Centroid','Location','SouthEast')
hold off;
%%
% For _k_ = 3, the algorithm seems to identify the three species well.
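%%
% To quantify this impression (a sketch added here, not part of the
% original example), tabulate the cluster assignments against the species
% labels with |crosstab|. Cluster numbers are arbitrary, so look for each
% species concentrating in a single row.
crosstab(Idx3,species)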