% www.gusucode.com > classification_matlab_toolbox分类方法工具箱源码程序 (classification methods toolbox source code) > code/Classification_toolbox/Stumps.m

    function [D, w] = Stumps(train_features, train_targets, params, region)

% Classify using a decision stump (a single axis-aligned threshold)
% Inputs:
% 	train_features	- Train features, one row per dimension, one column per sample
%	train_targets	- Train targets (expected to be 0/1 labels)
%	params			- Optional sample weights: if length(params) equals the number of
%					  samples plus one, params(1:end-1) are used as per-sample weights
%					  (the last entry is not used here); otherwise all weights are 1
%	region			- Decision region vector: [-x x -y y number_of_points]
%
% Outputs
%	D			- Decision surface: an N-by-N matrix (N = region(5)) holding 1 on the
%				  class-1 side of the stump and 0 elsewhere; all zeros when the data
%				  is not two-dimensional
%	w			- Decision surface parameters: the threshold in the chosen dimension
%				  and zero in every other dimension (the chosen threshold may itself
%				  be 0, so w alone does not identify the chosen dimension)

n_samples = length(train_targets);

% Class sizes normalize the cumulative sums below; clamping to 1 avoids 0/0
% when a class is absent (its weighted cumulative sum is then all zeros anyway)
n_one  = max(sum(train_targets == 1), 1);
n_zero = max(sum(train_targets == 0), 1);

if (length(params)-1 == n_samples),
    % Give the weights the same orientation as the targets so the element-wise
    % products below stay vectors even for a column params with row targets
    p = reshape(params(1:end-1), size(train_targets));
else
    p = ones(size(train_targets));
end

dim        = size(train_features,1);
w          = zeros(1,dim);
err        = zeros(1,dim);
direction  = zeros(1,dim);
data       = zeros(dim, size(train_features,2));

for i = 1:dim,
    % Sort the samples along dimension i, carrying targets and weights along
    [data(i,:), indices] = sort(train_features(i,:));
    temp_targets = train_targets(indices);
    temp_p       = p(indices);

    % decision(k): normalized class-1 weight minus normalized class-0 weight
    % among the k smallest values; its largest magnitude marks the best split
    decision     = cumsum(temp_p .* temp_targets)/n_one - cumsum(temp_p .* (~temp_targets))/n_zero;
    [err(i), W]  = max(abs(decision));
    w(i)         = data(i,W);
    % Positive: relatively more class-1 weight at or below the threshold
    direction(i) = sign(decision(W));
end

% Keep only the dimension whose stump separates the classes best
[m, best_dim] = max(err);
w((1:dim) ~= best_dim) = 0;

N = region(5);

if (dim == 2),
    %Find decision region (For 2-D data)
    x = linspace(region(1),region(2),N);
    y = linspace(region(3),region(4),N);
    D = zeros(N);

    % Use the chosen dimension index directly: testing w(1)~=0 misroutes
    % the case where the chosen threshold is exactly 0
    if (best_dim == 1),
        pts   = x;
        upper = region(2);
    else
        pts   = y;
        upper = region(4);
    end

    % Place the boundary midway between the threshold and the next sorted
    % value; the last occurrence keeps tied values on the same side
    sorted_vals = data(best_dim,:);
    ix = find(sorted_vals == w(best_dim), 1, 'last');
    if ix == length(sorted_vals),
        t = upper;
    else
        t = (sorted_vals(ix+1) + sorted_vals(ix)) / 2;
    end
    [m, indice] = min(abs(pts - t));

    if (direction(best_dim) < 0),
        span = indice+1:N;
    else
        span = 1:indice;
    end

    % x varies along columns of D, y along rows
    if (best_dim == 1),
        D(:,span) = 1;
    else
        D(span,:) = 1;
    end
else
    D = zeros(N);
    disp('No decision region calculated because the data has more than two dimensions')
end