www.gusucode.com > 超全的模式识别matlab源码程序 > code/Stumps.m
function [test_targets, w] = Stumps(train_patterns, train_targets, test_patterns, params) % Classify using simple stumps algorithm % Inputs: % train_patterns - Train patterns % train_targets - Train targets % test_patterns - Test patterns % weights - Unused (Except if weighted stumps is needed) % % Outputs % test_targets - Predicted labels % w - Decision surface parameters % % NOTE: Works for only two classes! train_one = find(train_targets == max(train_targets)); train_zero = find(train_targets == min(train_targets)); if (length(params)-1 == length(train_targets)), p = params(1:end-1); else p = ones(size(train_targets)); end dim = size(train_patterns,1); w = zeros(1,dim); err = zeros(1,dim); direction = zeros(1,dim); for i = 1:dim, %For each dimension, find the point where a stump gives the minimal error %First, sort the working dimension [data(i,:), indices] = sort(train_patterns(i,:)); temp_targets = train_targets(indices); temp_p = p(indices); decision = cumsum(temp_p .* temp_targets)/length(train_one) - cumsum(temp_p .* (~temp_targets))/length(train_zero); [err(i),W] = max(abs(decision)); w(i) = data(i,W); direction(i) = sign(decision(W)); end [m, min_dim] = max(err); indices = find(~ismember(1:dim,min_dim)); w(indices) = 0; if (direction(min_dim) > 0) indices = find(test_patterns(min_dim,:) < w(min_dim)); else indices = find(test_patterns(min_dim,:) > w(min_dim)); end test_targets = zeros(1, size(test_patterns,2)); test_targets(indices) = 1;