function result = knn(trainX, trainclassY, testZ, k, type)
% Classify using the k-nearest-neighbor algorithm
% Inputs:
%   trainX      - Training sample matrix, n*d (n points, each with d dimensions)
%   trainclassY - Class labels of trainX, n*1 (numeric)
%   testZ       - Test sample matrix, N*d
%   k           - Number of nearest neighbors
%   type        - Distance measure: '2norm' (Euclidean, default) or '1norm' (Manhattan)
% Outputs:
%   result      - Predicted class labels of testZ, N*1

% Check that trainX and testZ have the same number of columns (dimension d)
if size(trainX,2) ~= size(testZ,2)
    error('trainX and testZ must have same column dimensions !')
end

% Check that k is feasible: it must not exceed the number of training points
n = length(trainclassY);   % number of training samples
if n < k
    error('You specified more neighbors than existing points.')
end

% classes holds the distinct labels in ascending order;
% labelIdx(j) is the position of trainclassY(j) within classes.
% (The name "classes" avoids shadowing the built-in function "class".)
[classes, ~, labelIdx] = unique(trainclassY);
labelIdx = labelIdx(:);
numClasses = numel(classes);

N = size(testZ, 1);      % number of test samples (rows of testZ)
result = zeros(N, 1);    % N*1 column vector of predicted labels

% Choose the distance measure; default to '2norm' when type is not supplied
if nargin < 5            % nargin is the number of input arguments actually passed
    type = '2norm';      % L2 norm, i.e. Euclidean distance
end

% Classify each of the N test points with the selected distance measure
switch type
    case '2norm'   % L2 norm, Euclidean distance
        for i = 1:N
            % Squared Euclidean distance from test point i to each training point
            % (the square root is omitted because it does not change the ordering)
            dist = sum((trainX - ones(n,1)*testZ(i,:)).^2, 2);
            [~, indices] = sort(dist);   % ascending: nearest first
            % Count how often each class occurs among the k nearest neighbors.
            % accumarray replaces hist(labels, classes), which misreads a scalar
            % second argument as a bin count when only one class is present.
            counts = accumarray(labelIdx(indices(1:k)), 1, [numClasses 1]);
            [~, best] = max(counts);     % most frequent class; ties go to the smallest label
            result(i) = classes(best);
        end
    case '1norm'   % L1 norm, Manhattan distance
        for i = 1:N
            % L1 distance from test point i to each training point
            dist = sum(abs(trainX - ones(n,1)*testZ(i,:)), 2);
            [~, indices] = sort(dist);
            counts = accumarray(labelIdx(indices(1:k)), 1, [numClasses 1]);
            [~, best] = max(counts);
            result(i) = classes(best);
        end
    % case 'match'   % Matching score: the more equal elements, the better the match
    %     for i = 1:N
    %         dist = sum(trainX == ones(n,1)*testZ(i,:), 2);   % 1 where elements are equal, 0 otherwise
    %         [~, indices] = sort(dist, 'descend');            % descending: largest score is the best match
    %         counts = accumarray(labelIdx(indices(1:k)), 1, [numClasses 1]);
    %         [~, best] = max(counts);
    %         result(i) = classes(best);
    %     end
    otherwise
        error('Unknown measure function');
end
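
% Example usage (a minimal sketch; the data below are made up for
% illustration and are not part of the original code):
%   trainX      = [0 0; 0 1; 1 0; 5 5; 5 6; 6 5];   % six 2-D training points
%   trainclassY = [1; 1; 1; 2; 2; 2];               % their class labels
%   testZ       = [0.5 0.5; 5.5 5.5];               % two test points
%   result = knn(trainX, trainclassY, testZ, 3, '2norm');   % Euclidean distance
%   % result should be [1; 2], since the three nearest neighbors of each
%   % test point all belong to a single class
%   result = knn(trainX, trainclassY, testZ, 3, '1norm');   % Manhattan distance, same labels here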