www.gusucode.com > 超全的模式识别matlab源码程序 > code/RBF_Network.m
function [test_targets, mu, Wo] = RBF_Network(train_patterns, train_targets, test_patterns, Nh)

% Classify using a radial basis function (RBF) network algorithm.
%
% Hidden-unit centers are obtained from k_means, a single shared Gaussian
% width is derived from the largest distance between centers, and the
% hidden-to-output weights are fitted by least squares (pseudo-inverse).
%
% Inputs:
%   train_patterns  - Train patterns, one pattern per column (Ni x M)
%   train_targets   - Train targets, one per pattern (1 x M). They are
%                     remapped with 2*t-1 before fitting, so 0/1 labels
%                     are presumably expected - TODO confirm with callers
%   test_patterns   - Test patterns, one pattern per column (Ni x N)
%   Nh              - Number of hidden units requested
%
% Outputs:
%   test_targets    - Predicted targets (1 x N). Logical when the training
%                     targets hold exactly two distinct values, otherwise
%                     the raw network output
%   mu              - Hidden unit locations (Ni x Nh, after dropping
%                     non-finite centers)
%   Wo              - Output unit weights (1 x Nh)

%First, find locations for the hidden unit centers using k-means.
%center_targets and label are not used below; they are still requested so
%that k_means sees the same number of output arguments as before.
%NOTE(review): the meaning of the trailing 0 argument is defined by k_means
%and is not visible here.
[mu, center_targets, label] = k_means(train_patterns, train_targets, Nh, 0);

%Remove bad centers (any column containing NaN/Inf). The explicit
%dimension keeps this a per-column test even when there is one feature row.
ok      = find(isfinite(mean(mu, 1)));
mu      = mu(:, ok);
Nh      = length(ok);

%Variance of the gaussians: based on the largest pairwise center distance
dist    = zeros(Nh);
for i = 1:Nh,
    dist(i,:) = sqrt(sum((mu(:,i)*ones(1,Nh) - mu).^2, 1));
end
max_dist = max(max(dist));
sigma    = max_dist/sqrt(2*Nh);

%Compute the activation for each training pattern at each center
Phi     = rbf_gaussian_activations(train_patterns, mu, sigma);

%Now, find the hidden to output weights (least squares on -1/+1 targets)
Wo      = (pinv(Phi)'*(train_targets*2-1)')';

%Classify test patterns
Phi          = rbf_gaussian_activations(test_patterns, mu, sigma);
test_targets = Wo * Phi;

%If there are only two classes, collapse them. The weights were fitted to
%targets mapped by 2*t-1, so the decision boundary is at 0, not 0.5.
if (length(unique(train_targets)) == 2)
    test_targets = test_targets > 0;
end


function Phi = rbf_gaussian_activations(patterns, mu, sigma)

% Evaluate every Gaussian hidden unit on every pattern.
%
% Inputs:
%   patterns - Patterns, one per column (Ni x N)
%   mu       - Hidden unit centers, one per column (Ni x Nh)
%   sigma    - Common width of the Gaussians
%
% Outputs:
%   Phi      - Activations, Phi(i,j) is unit i evaluated on pattern j (Nh x N)

[Ni, N]    = size(patterns);
Nh         = size(mu, 2);
norm_const = 1/(2*pi*sigma^2)^(Ni/2);

Phi = zeros(Nh, N);
for i = 1:Nh,
    %Sum over features (dimension 1) so a single-feature input still
    %yields one squared distance per pattern
    Phi(i,:) = norm_const*exp(-sum((patterns - mu(:,i)*ones(1,N)).^2, 1)/(2*sigma^2));
end