www.gusucode.com > classification_matlab_toolbox分类方法工具箱源码程序 > code/Classification_toolbox/Competitive_learning.m

    function [features, targets, label, W] = Competitive_learning(train_features, train_targets, params, region, plot_on)

% Perform preprocessing using a competitive learning network
%
% Every pattern is augmented with a constant 1 and scaled to unit length.
% N weight vectors are initialised from N randomly chosen (normalised)
% patterns. For each presented pattern the unit with the largest inner
% product wins, is moved towards the pattern by eta and is renormalised.
% eta is multiplied by 0.99 after every epoch; training stops when the mean
% absolute weight change per pattern drops below 1e-4 or after 1000 epochs.
%
% Inputs:
% 	train_features	- Train features (one pattern per column)
%	train_targets	- Train targets (one per pattern)
%	params	        - [Number of partitions, learning rate]
%	region	        - Decision region vector: [-x x -y y number_of_points]
%                     (not used by this function)
%   plot_on         - Plot while performing processing? (optional, default 0)
%
% Outputs
%	features		- New features: for each unit, the mean of the normalised
%                     patterns assigned to it (NaN column for an empty unit)
%	targets			- New targets: majority vote (> 0.5) of the train targets
%                     assigned to each unit (0 for an empty unit)
%	label			- The labels given for each of the original features
%   W               - Weights matrix (one unit-length column per unit)

if (nargin < 5),
    %plot_on is optional; default to no plotting
    plot_on = 0;
end

max_iter       = 1000;
[c, r]		   = size(train_features);
[N, eta]	   = process_params(params);
decay          = 0.99;

if (N > r),
    error('Competitive_learning: the number of partitions cannot exceed the number of training patterns');
end

%Preprocessing:
% x_i <- {x_i, 1}
x              = [train_features ; ones(1,r)];
%x_i <- x_i./||x_i||
x              = x ./ (ones(c+1,1) * sqrt(sum(x.^2)));

%Initialize the W's from N distinct randomly chosen patterns
init           = randperm(r);
W              = x(:,init(1:N));

converged      = 0;
n_iter         = 0;

for iter = 1:max_iter,
    n_iter = iter;

    %Randomly order the patterns
    order  = randperm(r);
    change = 0;

    for k = 1:r,
        xk        = x(:,order(k));

        %Winner is the unit with the largest projection. Taking the index
        %output of max guarantees a single winner even when scores tie.
        [best, j] = max(W'*xk);

        old_W     = W(:,j);

        %W_j <- W_j + eta*x
        new_W     = old_W + eta*xk;

        %W_j <- W_j/||W_j||
        W(:,j)    = new_W / sqrt(sum(new_W.^2));

        change    = change + sum(abs(W(:,j) - old_W));

        if (plot_on == 1),
            %Assign each of the features to a center and show the centers
            [m, label] = max(W'*x);
            plot_process(unit_means(x, label, c, N))
        end

    end

    eta = eta * decay;

    if (change/r < 1e-4),
        converged = 1;
        break
    end

end

if converged,
    disp(['Finished after ' num2str(n_iter) ' iterations.'])
else
    disp(['Maximum iteration (' num2str(max_iter) ') reached']);
end

%Assign each of the features to a center
[m, label]  = max(W'*x);
features    = unit_means(x, label, c, N);

targets = zeros(1,N);
if (N > 1),
    for n = 1:N,
        in = find(label == n);
        if ~isempty(in),
            %Majority vote of the targets assigned to this unit
            targets(n) = (sum(train_targets(in))/length(in) > .5);
        end
    end
else
    %There is only one center
    targets = (sum(train_targets)/length(train_targets) > .5);
end


function centers = unit_means(x, label, c, N)
% For each of the N units, average the first c rows of x over the patterns
% labelled with that unit. Units with no assigned pattern get a NaN column.
centers = zeros(c,N);
for n = 1:N,
    in = find(label == n);
    if ~isempty(in)
        %Average along the pattern dimension explicitly so that a unit with
        %a single pattern still yields a c-by-1 column
        centers(:,n) = mean(x(1:c,in), 2);
    else
        centers(:,n) = nan;
    end
end