% Source: www.gusucode.com > 数据挖掘工具箱 (Data Mining Toolbox) - mitmatlab源码程序 (mitmatlab source code) > mitmatlab\Interactive_Learning.m

    function D = Interactive_Learning(train_features, train_targets, params, region)

% Classify using nearest neighbors and interactive learning
% Inputs:
%	train_features	- Train features, one sample per column (rows 1 and 2 are used as the x and y coordinates)
%	train_targets	- Train targets, one per column of train_features (presumably 0/1 labels - confirm with callers)
%	params			- [Number of query points, Weight (Relative weight of the new point in relation to the old data points)]
%	region			- [x_min x_max y_min y_max N]; the decision surface is evaluated on an N-by-N grid
%
% Outputs
%	D			- Decision surface: N-by-N logical, rows follow the y grid and columns the x grid
%
% In this implementation, we train a nearest neighbor classifier, and ask for interactive assistance in
% problematic areas. At each query the grid point whose soft score is closest to 0.5 is marked in red.
% A right click (button 3) labels it class 1; any other button labels it class 0. Pressing Return
% without clicking stops querying early. The labeled point is added Weight times (with tiny jitter)
% to the training set and the classifier is rebuilt.

[Npoints , Weight] = process_params(params);

% Weight is given relative to the training-set size; convert it to a number of copies of each labeled
% point. Clamp to at least one copy: round() alone yields 0 for small weights, which would silently
% discard every label the user supplies.
Weight  = max(1, round(Weight*size(train_features,2)));
NN      = 3;    % neighbors used by the classifier (NearestNeighbor rejects fewer than 3)

hm      = findobj('Tag', 'Messages');
N       = region(5);
x       = linspace(region(1),region(2),N);
y       = linspace(region(3),region(4),N);
D       = NearestNeighbor(train_features, train_targets, NN, region);

for i = 1:Npoints
    % Find the most ambiguous grid point: the one whose soft score is closest to 0.5.
    % D is indexed (y, x), so find() returns the row (y) index first.
    ambig   = abs(D-0.5);
    [iy,ix] = find(ambig == min(ambig(:)));
    ix      = ix(1); iy = iy(1);

    % Query the user for the label of this point. Hold is forced on while the marker is drawn so that
    % plot() does not wipe the existing axes; the caller's hold state is restored right after.
    set(hm, 'String', 'Press the left mouse button to label the red point as class 0 (Blue) or right button as 1 (Green)')
    was_held = ishold;
    hold on
    h       = plot(x(ix), y(iy), 'rd', 'LineWidth', 2);
    if ~was_held
        hold off
    end
    [t1, t2, button] = ginput(1);
    delete(h)
    drawnow

    % ginput returns nothing when the user presses Return instead of clicking; stop querying then
    % (otherwise the empty label would break the concatenation below).
    if isempty(button)
        break
    end
    new_label = button == 3;    % right button -> class 1, anything else -> class 0

    % Add this point to the data with the supplied label, replicated Weight times with tiny Gaussian
    % jitter (presumably so the copies are not exact duplicates). Assumes 2-D features.
    new_features   = [x(ix); y(iy)]*ones(1,Weight) + randn(2,Weight)*0.001;
    train_features = [train_features, new_features];
    train_targets  = [train_targets, new_label*ones(1,Weight)];

    % Build a new decision region
    D       = NearestNeighbor(train_features, train_targets, NN, region);
end

% Threshold the soft scores into a class-1 indicator and clear the status message
D = D > .5;
set(hm, 'String', '')

function D = NearestNeighbor(features, targets, NN, region)
%Distance-weighted k-nearest-neighbor soft classifier evaluated on a grid.
% Inputs:
%	features	- Training points, one per column; row 1 is the x coordinate, row 2 the y coordinate
%	targets		- Training labels, one per column of features (presumably 0/1 - confirm with callers)
%	NN			- Number of nearest neighbors to use (must be at least 3 and at most the number of points)
%	region		- [x_min x_max y_min y_max N]; scores are computed on an N-by-N grid
%
% Outputs
%	D			- N-by-N soft score. D(j,i) is the weighted mean label of the NN training points nearest to
%				  grid point (x(i), y(j)). Each neighbor is weighted by the inverse of its squared distance, so
%				  closer points count more. If a training point coincides exactly with the grid point, only the
%				  coincident points are averaged (this avoids dividing by zero).

M       = size(features,2);
N       = region(5);
x       = linspace(region(1),region(2),N);
y       = linspace(region(3),region(4),N);
D       = zeros(N);

if (M < NN)
   error('You specified more neighbors than there are points.')
end

if (NN < 3)
    error('Number of nearest neighbors must be at least 3 for this function to work')
end

% Squared y-offset from every grid row to every training point (N-by-M). It does not depend on the
% grid column, so it is computed once outside the loop.
y_dist	= (ones(N,1) * features(2,:) - y'*ones(1,M)).^2;

for i = 1:N
    if (mod(i,50) == 0)
        disp(['Finished ' num2str(i) ' lines out of ' num2str(N) ' lines.'])
    end

    % Squared Euclidean distance from each grid point in column i to every training point (N-by-M)
    x_dist = ones(N,1) * (features(1,:)-x(i)).^2;
    dist   = x_dist + y_dist;

    % After transposing, each column holds one grid point's distances; sort them ascending
    [sorted_dist, indices] = sort(dist');
    % reshape keeps Tnearest NN-by-N even when N == 1 (vector indexing would otherwise follow the
    % orientation of targets instead of the index matrix)
    Tnearest = reshape(targets(indices(1:NN,:)), NN, N);
    Tdist    = sorted_dist(1:NN,:);

    % Inverse squared-distance weights: the nearest neighbors dominate the vote. Columns whose nearest
    % neighbor is at distance zero would get infinite weights, so there only the exact matches count.
    W          = 1 ./ Tdist;
    exact      = Tdist(1,:) == 0;
    W(:,exact) = double(Tdist(:,exact) == 0);
    D(:,i)     = (sum(Tnearest.*W, 1)./sum(W, 1))';
end