www.gusucode.com > ​用mushrooms数据对模式识别课程讲述的各种模式分类方法matlab源码程序 > pattern-recognition-simulation/KNN_function.m

    function result_class=KNN_function(traing_example_original,test_example,K)

%该函数用来进行KNN分类;
%输入参数:  traing_example_original:训练样本,带标签
%           test_example:测试样本,带标签
%           K:表示K近邻
%输出参数: result_class是一个m*1列向量,包括类别信息

m=size(test_example,1); %获得测试样本的个数
n=size(traing_example_original,1);%获得训练样本的个数

%初始化距离矩阵,oushi_distance行表示测试样本,列表示训练样本,行列交点表示两者的距离
oushi_distance=zeros(m,n);
%计算一个测试样本与每个训练样本的欧氏距离
for i=1:m
    for j=1:n
         oushi_distance(i,j)=sum(((test_example(i,2:end)-traing_example_original(j,2:end)).^2)');
    end
end

for i=1:m
    temp=oushi_distance(i,:);
    temp=sort(temp);
    Kmin_oushi_distance(1,1:K)=temp(1,1:K);%找到最小的K个近邻
    for p = 1:K
        index_column = find(oushi_distance(i,:)==Kmin_oushi_distance(1,p));%在oushi_distance(i,:)中找出最小的距离对应的列号
        if p == 1
            k_traing_example_original = traing_example_original(index_column(1,1:end),:);
        else
            k_traing_example_original = [k_traing_example_original;traing_example_original(index_column(1,1:end),:)];
        end
    end
    k_traing_example_original = k_traing_example_original(1:K,:);
    index_column_class1 = find(k_traing_example_original(:,1)==1);
    k1 = size(index_column_class1,1);
    if k1>=K-k1
        result_class(i,1)=1;
    else
        result_class(i,1)=2;
    end
end