% Source: www.gusucode.com > classification_matlab_toolbox分类方法工具箱源码程序 (classification methods toolbox source code) > code/Classification_toolbox/multialgorithms_commands.m
function multialgorithms_commands(command)
%MULTIALGORITHMS_COMMANDS Process events from the multi-algorithm GUI screen.
%
%   MULTIALGORITHMS_COMMANDS(COMMAND) dispatches on the string COMMAND:
%     'MoveLeft'  - move the selected rows of the 'lstAllAlgorithms' list
%                   box to the 'lstChosenAlgorithms' list box.
%     'MoveRight' - move the selected rows of 'lstChosenAlgorithms' back
%                   to 'lstAllAlgorithms'.
%     'Compare'   - run every chosen algorithm on the base-workspace
%                   variables FEATURES and TARGETS with the error-estimation
%                   method selected in the GUI, bar-plot the errors and
%                   store them (as a row vector) in the base-workspace
%                   variable FINAL_ERRORS.
%     'Predict'   - call PREDICT_PERFORMANCE for every chosen algorithm,
%                   bar-plot the results and store them in the
%                   base-workspace variable FINAL_PREDICTIONS.
%   Any other COMMAND raises an error. Invalid input is reported in the
%   GUI object tagged 'Messages' and the function returns early.

switch command
    case {'MoveLeft', 'MoveRight'}
        move_algorithms(strcmp(command, 'MoveLeft'));
    case {'Compare', 'Predict'}
        run_algorithms(command);
    otherwise
        error('Unknown commands')
end


function move_algorithms(to_chosen)
%MOVE_ALGORITHMS Move the selected list-box rows between the two lists.
%   TO_CHOSEN true moves from 'lstAllAlgorithms' to 'lstChosenAlgorithms';
%   false moves in the opposite direction.

if to_chosen
    hFrom = findobj('Tag', 'lstAllAlgorithms');
    hTo   = findobj('Tag', 'lstChosenAlgorithms');
else
    hFrom = findobj('Tag', 'lstChosenAlgorithms');
    hTo   = findobj('Tag', 'lstAllAlgorithms');
end

% Take the selected rows out of the source list
val        = get(hFrom, 'Value');
algorithms = get(hFrom, 'String');
algorithm  = algorithms(val, :);
if isempty(deblank(algorithm))
    return   % nothing selected, or only the blank placeholder row
end

keep = true(1, size(algorithms, 1));
keep(val) = false;
set(hFrom, 'Value', 1);
if any(keep)
    set(hFrom, 'String', algorithms(keep, :));
else
    set(hFrom, 'String', ' ');   % leave a blank placeholder row
end

% Append them to the destination list. CHAR pads rows of different widths
% with blanks, so names shorter or longer than the existing rows fit.
algorithms = get(hTo, 'String');
if is_blank_list(algorithms)
    set(hTo, 'String', algorithm);
else
    set(hTo, 'String', char(algorithms, algorithm));
end


function blank = is_blank_list(list)
%IS_BLANK_LIST True when a list-box String is empty or is a single blank row.
blank = isempty(list) || (size(list, 1) == 1 && isempty(deblank(list(1, :))));


function run_algorithms(command)
%RUN_ALGORITHMS Validate the inputs shared by 'Compare'/'Predict', then run it.

Npoints = 100;   % grid resolution passed to CALCULATE_REGION
hFigure = gcf;
hm = findobj('Tag', 'Messages');
set(hm, 'String', '');

% FEATURES and TARGETS must already be in the base workspace. The 'var'
% flag keeps EXIST from matching files or functions of the same name.
if ~evalin('base', 'exist(''targets'', ''var'')')
    set(hm, 'String', 'No targets on workspace. Please load targets.')
    return
end
if ~evalin('base', 'exist(''features'', ''var'')')
    set(hm, 'String', 'No features on workspace. Please load features.')
    return
end
features = evalin('base', 'features');
targets  = evalin('base', 'targets');

% Optional; only used for the Bayes-error bar of 'Compare'
distribution_parameters = [];
if evalin('base', 'exist(''distribution_parameters'', ''var'')')
    distribution_parameters = evalin('base', 'distribution_parameters');
end

% Grid over the feature space on which the decision regions are evaluated
region = calculate_region(features, [zeros(1, 4) Npoints]);

% Collect the algorithms (and their parameters) to run
algorithms = get(findobj('Tag', 'lstChosenAlgorithms'), 'String');
if is_blank_list(algorithms)
    set(hm, 'String', 'Please select at least one algorithm.')
    return
end
chosen = choose_parameters(algorithms);
if isempty(chosen)
    set(hm, 'String', 'None of the selected algorithms is listed in Classification.txt.')
    return
end

if strcmp(command, 'Compare')
    compare_algorithms(chosen, features, targets, region, distribution_parameters, hm, hFigure);
else
    predict_algorithms(chosen, features, targets, region, hm);
end


function chosen = choose_parameters(algorithms)
%CHOOSE_PARAMETERS Look up the chosen algorithms and ask for their parameters.
%   ALGORITHMS is the char matrix of names from the 'lstChosenAlgorithms'
%   list box. Returns a struct array with fields Name and Parameter (a
%   string). Names not found in 'Classification.txt' are skipped.
%   NOTE(review): READ_ALGORITHMS is assumed to return a struct array with
%   fields Name, Field, Caption and Default, as used below - confirm.

All_algorithms = read_algorithms('Classification.txt');
known  = cellstr(char(All_algorithms(:).Name));   % cellstr drops the padding
chosen = struct('Name', {}, 'Parameter', {});
for i = 1:size(algorithms, 1)
    name  = deblank(algorithms(i, :));
    index = find(strcmp(name, known), 1);
    if isempty(index)
        continue
    end
    entry = All_algorithms(index);
    if any(strncmp(cellstr(entry.Field), 'N', 1))
        % A Field starting with 'N' means no parameter is requested
        parameter = '';
    else
        parameter = char(inputdlg(['Enter ' entry.Caption], entry.Name, 1, cellstr(entry.Default)));
    end
    chosen(end + 1) = struct('Name', name, 'Parameter', parameter);
end


function param = parse_parameter(str)
%PARSE_PARAMETER Numeric value of a parameter string, or the string itself.
%   STR2NUM accepts vectors such as '[1 2 3]'; text that is not numeric is
%   passed on unchanged. NOTE: STR2NUM evaluates its argument with EVAL and
%   the text here is typed by the GUI user.
param = str2num(str);
if isempty(param)
    param = str;
end


function compare_algorithms(chosen, features, targets, region, distribution_parameters, hm, hFigure)
%COMPARE_ALGORITHMS Estimate, plot and export the error of each chosen algorithm.
%   The error-estimation method, the number of redraws and the holdout
%   percentage are read from the GUI. Returns early, with a message in HM,
%   when they are invalid.

L = length(targets);   % number of samples

hMethod = findobj('Tag', 'popErrorEstimation');
method_names = cellstr(get(hMethod, 'String'));   % works for cell or char-matrix String
error_method = method_names{get(hMethod, 'Value')};

redraws = str2num(get(findobj('Tag', 'txtRedraws'), 'String'));
if numel(redraws) ~= 1 || ~isfinite(redraws) || redraws ~= floor(redraws)
    set(hm, 'String', 'Please select how many redraws are needed.')
    return
end
if strcmp(error_method, 'Cross-Validation')
    if redraws < 2
        set(hm, 'String', 'Number of redraws must be larger than 1.')
        return
    end
    if redraws > L
        % Every fold needs at least one test sample
        set(hm, 'String', 'Number of redraws must not exceed the number of samples.')
        return
    end
elseif redraws < 1
    set(hm, 'String', 'Number of redraws must be larger than 0.')
    return
end

percent = str2num(get(findobj('Tag', 'txtPrecentage'), 'String'));
if strcmp(error_method, 'Holdout')
    if isempty(percent)
        set(hm, 'String', 'Please select the percentage of training vectors.')
        return
    elseif floor(percent/100*L) < 1
        set(hm, 'String', 'Number training vectors must be larger than 0.')
        return
    end
end

% Inputs are valid: train and test every algorithm REDRAWS times
set(hFigure, 'pointer', 'watch');
Nalgorithms = numel(chosen);
test_err  = zeros(Nalgorithms, redraws);
train_err = zeros(Nalgorithms, redraws);
for k = 1:Nalgorithms
    param = parse_parameter(chosen(k).Parameter);   % identical for every redraw
    for i = 1:redraws
        set(hm, 'String', [chosen(k).Name ' algorithm: Processing iteration ' num2str(i) ' of ' num2str(redraws) ' iterations...']);

        % Split the samples according to the selected error method
        switch error_method
            case 'Resubstitution'
                test_indices  = 1:L;
                train_indices = 1:L;
            case 'Holdout'
                [test_indices, train_indices] = make_a_draw(floor(percent/100*L), L);
            case 'Cross-Validation'
                % Fold i tests on the i-th block of CHUNK consecutive samples
                chunk = floor(L/redraws);
                test_indices  = (i-1)*chunk + 1 : i*chunk;
                train_indices = [1:(i-1)*chunk, i*chunk + 1:L];
            otherwise
                set(hFigure, 'pointer', 'arrow');
                error('Unknown error estimation method: %s', error_method);
        end

        train_features = features(:, train_indices);
        train_targets  = targets(:, train_indices);
        test_features  = features(:, test_indices);
        test_targets   = targets(:, test_indices);

        D = feval(chosen(k).Name, train_features, train_targets, param, region);
        [classify, train_err(k, i)] = classification_error(D, train_features, train_targets, region);
        [classify, test_err(k, i)]  = classification_error(D, test_features, test_targets, region);
    end
end

% Average over the redraws. mean(...,2) gives one value per algorithm even
% for a single redraw, and .' makes the result a row vector in every case.
mean_test  = mean(test_err, 2).';
mean_train = mean(train_err, 2).';
hDisp = findobj('Tag', 'popErrorDisplay');
display_names = cellstr(get(hDisp, 'String'));
switch display_names{get(hDisp, 'Value')}
    case 'Test error'
        err = mean_test;
    case 'Train error'
        err = mean_train;
    otherwise
        % Sample-weighted combination, using the set sizes of the last draw
        Ntest  = length(test_targets);
        Ntrain = length(train_targets);
        err = (mean_test*Ntest + mean_train*Ntrain) / (Ntest + Ntrain);
end

% Optionally add the Bayes error computed from the distribution parameters
names  = {chosen.Name};
hBayes = findobj('Tag', 'chkBayes');
if ~isempty(distribution_parameters) && ~isempty(hBayes) && get(hBayes, 'Value')
    Dbayes = decision_region(distribution_parameters, region);
    [classify, Bayes_err] = classification_error(Dbayes, features, targets, region);
    err(end + 1) = Bayes_err;
    names{end + 1} = 'Bayes err';
end

plot_labeled_bars(err, names, 'Average classification errors');
ax = axis;
ax(3) = 0;
ax(4) = max(1, max(err));   % always show at least the range [0 1]
axis(ax)

set(hm, 'String', 'Finished!');
set(hFigure, 'pointer', 'arrow');
assignin('base', 'final_errors', err)


function predict_algorithms(chosen, features, targets, region, hm)
%PREDICT_ALGORITHMS Plot and export PREDICT_PERFORMANCE for each chosen algorithm.

Nalgorithms = numel(chosen);
a = zeros(1, Nalgorithms);
for k = 1:Nalgorithms
    set(hm, 'String', [chosen(k).Name ' algorithm']);
    param = parse_parameter(chosen(k).Parameter);
    a(k) = predict_performance(chosen(k).Name, param, features, targets, region);
end

plot_labeled_bars(a, {chosen.Name}, 'Prediction values');
assignin('base', 'final_predictions', a)


function plot_labeled_bars(values, names, plot_title)
%PLOT_LABELED_BARS Bar-plot VALUES in a new figure, labelling bar k with NAMES{k}.
%   Underscores in the names are displayed as spaces.

figure
bar(values)
title(plot_title)
for k = 1:numel(values)
    label = strrep(deblank(names{k}), '_', ' ');
    h = text(k, values(k) + .02, label);
    set(h, 'HorizontalAlignment', 'Center', 'FontSize', 12)
end