%% Train a _k_-Means Clustering Algorithm
% This example shows how to use the _k_-means clustering algorithm to find
% a specified number of clusters in a data set. Consider R.A. Fisher's
% famous iris data set. The input matrix contains four features (petal and
% sepal lengths and widths) measured on 150 irises, and the response is one
% of three iris species. For illustration, ignore the sepal lengths and
% widths.
%%
% Load the data set.
load fisheriris
X = meas(:,3:4);
%%
% Plot the petal lengths and widths. Inspect the plot to determine the
% number of clusters in the data set.
figure;
plot(X(:,1),X(:,2),'.');
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
%%
% The plot suggests that there might be two or three clusters. One cluster
% contains irises that have markedly lower petal lengths and widths.
% Another cluster seems to have petal lengths and widths in the middle of
% their ranges. There is possibly one more cluster of irises with higher
% lengths and widths than all the others, and also higher variability.
%%
% To illustrate how _k_-means clustering works, train the algorithm by:
%
% * Passing in the data
% * Setting the number of clusters to 2
% * Setting the number of iterations to 1
%
% The last step implements one iteration of _k_-means clustering. Plot the
% centroids after each iteration.
k = 2;                                % Number of clusters
rng(1);                               % For reproducibility
Cinit = X(randsample(size(X,1),k),:); % Initial centroid locations
h = 1:4;                              % Preallocate plot handles
figure;
h(1) = plot(X(:,1),X(:,2),'.');
hold on;
h(2) = plot(Cinit(:,1),Cinit(:,2),'bo','MarkerSize',15,'LineWidth',2);
epsilon = 999; % Initialize the convergence criterion
while epsilon > 0.01
    warning('off','all') % Temporarily disable warnings
    % One iteration of k-means clustering
    [Idx2,C2] = kmeans(X,2,'MaxIter',1,'Start',Cinit);
    warning('on','all') % Re-enable warnings
    h(3) = plot(C2(:,1),C2(:,2),'go','MarkerSize',15,'LineWidth',2);
    for jj = 1:k
        line([Cinit(jj,1);C2(jj,1)],[Cinit(jj,2);C2(jj,2)],'Color','k',...
            'LineWidth',2)
    end
    % Squared distance between successive centroid estimates
    epsilon = sum(diag((Cinit - C2)'*(Cinit - C2)));
    Cinit = C2;
end
h(4) = plot(C2(:,1),C2(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Data','Start Centroid','Transitional Centroid',...
    'Final Centroid','Location','SouthEast')
hold off;
C2
%%
% |Idx2| is an _n_-by-1 vector of cluster assignments corresponding to the
% observations. The rows of |C2| are the coordinates of the final
% centroids. The final centroids appear in the center of both groups of
% data.
%%
% Color the cluster regions. For illustration, plot the data by
% distinguishing between the iris species (though the response is unknown
% in the context of unsupervised learning).
h = 1:6;
x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)]; % Defines a fine grid on the plot
warning('off','all')
% Assign each node in the grid to the closest centroid
idx2Region = kmeans(XGrid,2,'MaxIter',1,'Start',C2);
warning('on','all')
figure;
h(1:2) = gscatter(XGrid(:,1),XGrid(:,2),idx2Region,...
    [0,0.75,0.75;0.75,0,0.75],'..');
hold on;
h(3:5) = gscatter(X(:,1),X(:,2),species,zeros(3,3),'xd+',8);
h(6) = plot(C2(:,1),C2(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Cluster 1 Region','Cluster 2 Region',...
    '{\it L. setosa}','{\it L. versicolor}','{\it L. virginica}',...
    'Centroid','Location','SouthEast')
hold off;
%%
% The algorithm misclassifies one _L. versicolor_ iris and all of the
% _L. virginica_ irises.
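%%
% As a quick check on the _k_ = 2 solution (a sketch added here, not part
% of the standard example), you can cross-tabulate the cluster assignments
% against the known species labels. This assumes |Idx2| and |species| are
% still in the workspace.
ctbl = crosstab(Idx2,species)
% Each row is a cluster; the columns count setosa, versicolor, and
% virginica irises assigned to it, making the misclassifications visible.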
% The number of clusters (_k_) is a _nuisance parameter_, and it is best
% practice to test a few different values before deciding which is optimal.
%%
% Train the _k_-means clustering algorithm again, but set the number of
% clusters to 3. Plot the data, the centroids, and the cluster regions.
k = 3;
[Idx3,C3] = kmeans(X,3);
h = 1:7; % Preallocate plot handles
x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)];
warning('off','all')
% Assign each node in the grid to the closest centroid
idx3Region = kmeans(XGrid,3,'MaxIter',1,'Start',C3);
warning('on','all')
figure;
h(1:3) = gscatter(XGrid(:,1),XGrid(:,2),idx3Region,...
    0.75*[0,1,1;1,0,1;1,1,0],'...');
hold on;
h(4:6) = gscatter(X(:,1),X(:,2),species,zeros(3,3),'xd+',8);
h(7) = plot(C3(:,1),C3(:,2),'ro','MarkerSize',15,'LineWidth',2);
xlabel('Petal length (cm)');
ylabel('Petal width (cm)');
title('{\bf Iris Petal Lengths and Widths}');
legend(h,'Cluster 1 Region','Cluster 2 Region','Cluster 3 Region',...
    '{\it L. setosa}','{\it L. versicolor}','{\it L. virginica}',...
    'Centroid','Location','SouthEast')
hold off;
%%
% The algorithm seems to identify the three species well for _k_ = 3.
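%%
% One common way to compare candidate values of _k_ is the silhouette
% criterion. The sketch below (an addition, not part of the standard
% example) uses |evalclusters| from Statistics and Machine Learning
% Toolbox to score _k_ = 1 through 4; the choice of |KList| range is an
% assumption for illustration.
rng(1); % For reproducibility of the kmeans restarts
eva = evalclusters(X,'kmeans','silhouette','KList',1:4);
eva.OptimalK % Number of clusters with the best silhouette score
% A silhouette plot for the chosen k can also help diagnose poorly
% separated clusters: silhouette(X,Idx3)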