% Source: www.gusucode.com > 超全的模式识别matlab源码程序 (comprehensive pattern-recognition MATLAB source collection) > code/Store_Grabbag.m

    function test_targets = Store_Grabbag(train_patterns, train_targets, test_patterns, Knn)

% Classify using the store-grabbag algorithm (an improvement on the nearest neighbor)
% The training set is condensed into a STORE: a sample is kept in STORE only
% if the current STORE misclassifies it; correctly classified samples go to
% the GRABBAG. The GRABBAG is then re-scanned repeatedly, moving any sample
% that STORE now misclassifies into STORE, until a full pass moves nothing
% (or the GRABBAG is empty). The test patterns are finally classified by the
% nearest-neighbor algorithm using only the STORE samples.
%
% Inputs:
% 	train_patterns	- Train patterns, one sample per column
%	train_targets	- Train targets, one label per column of train_patterns
%   test_patterns   - Test  patterns
%	Knn		        - Number of nearest neighbors
%
% Outputs
%	test_targets	- Predicted targets

% Number of samples = number of columns (length() would return the largest
% dimension, which is wrong when there are fewer samples than features).
L = size(train_patterns, 2);

%Placing first sample in STORE
Store_patterns   = train_patterns(:,1);
Store_targets    = train_targets(1);
Grabbag_patterns = [];
Grabbag_targets  = [];

%First pass: each sample goes to GRABBAG if the current STORE classifies it
%correctly, otherwise to STORE (STORE grows during the pass).
for i = 2:L
   target = Knn_Rule(train_patterns(:,i), Store_patterns, Store_targets, Knn);
   if target == train_targets(i)
      Grabbag_patterns = [Grabbag_patterns, train_patterns(:,i)];
      Grabbag_targets  = [Grabbag_targets, train_targets(i)];
   else
      Store_patterns = [Store_patterns, train_patterns(:,i)];
      Store_targets  = [Store_targets, train_targets(i)];
   end
end

%Repeated passes over GRABBAG: move every sample that STORE misclassifies
%into STORE. Stop when a full pass transfers nothing or GRABBAG is empty.
transferred = true;
while transferred && ~isempty(Grabbag_targets)
   transferred = false;
   keep = true(1, numel(Grabbag_targets));   % samples that stay in GRABBAG
   for i = 1:numel(Grabbag_targets)
      target = Knn_Rule(Grabbag_patterns(:,i), Store_patterns, Store_targets, Knn);
      if target ~= Grabbag_targets(i)
         Store_patterns = [Store_patterns, Grabbag_patterns(:,i)];
         Store_targets  = [Store_targets, Grabbag_targets(i)];
         keep(i)        = false;
         transferred    = true;
      end
   end
   Grabbag_patterns = Grabbag_patterns(:, keep);
   Grabbag_targets  = Grabbag_targets(keep);
end

disp(['Calling Nearest Neighbor algorithm']);
test_targets = Nearest_Neighbor(Store_patterns, Store_targets, test_patterns, Knn);

%END

function target = Knn_Rule(Sample, Store_patterns, Store_targets, Knn)
%Classify a sample by majority vote among its Knn nearest STORE samples.
% Inputs:
%   Sample          - One sample (vector with one entry per row of Store_patterns)
%   Store_patterns  - STORE samples, one per column
%   Store_targets   - Labels of the STORE samples; the vote below assumes
%                     0/1 labels (NOTE(review): confirm with callers)
%   Knn             - Number of nearest neighbors
% Output:
%   target          - true when more than half of the neighbors used have
%                     label 1; ties give false
%
% If STORE holds no more than Knn samples, all of them are used.

%Euclidean distance from Sample to every STORE sample, over all features
dist = sqrt(sum(bsxfun(@minus, Store_patterns, Sample(:)).^2, 1));
[sorted_dist, indices] = sort(dist);

if length(Store_targets) <= Knn
   k_nearest = Store_targets;
else
   k_nearest = Store_targets(indices(1:Knn));
end

%Majority vote over the neighbors actually used (fewer than Knn when STORE is small)
target = (sum(k_nearest) > numel(k_nearest)/2);