%% Classification with Many Categorical Levels
% This example shows how to train an ensemble of classification trees using
% data containing predictors with many categorical levels.
%%
% Generally, you cannot use classification with more than 31 levels in any
% categorical predictor. However, two boosting algorithms can classify data
% with many categorical predictor levels and binary responses: |LogitBoost|
% and |GentleBoost|. For details, see <docid:stats_ug.bsw8au3 LogitBoost>
% and <docid:stats_ug.bsw8avb GentleBoost>.
%%
% This example uses demographic data from the U.S. Census Bureau, available
% at the <http://archive.ics.uci.edu/ml/datasets/Adult UCI
% Machine Learning Data Repository>.
%%
% The objective of the researchers who posted the data is predicting
% whether an individual makes over $50,000 a year, based on a set of
% characteristics. You can see details of the data, including predictor
% names, in the |adult.names| file at the site.
%%
% Load the |'adult.data'| file from the UCI Machine Learning Data Repository.
% Specify a cell array of character vectors containing the variable names.
adult = urlread(['http://archive.ics.uci.edu/ml/'...
    'machine-learning-databases/adult/adult.data']);
VarNames = {'age' 'workclass' 'fnlwgt' 'education' 'educationNum'...
    'maritalStatus' 'occupation' 'relationship' 'race'...
    'sex' 'capitalGain' 'capitalLoss'...
    'hoursPerWeek' 'nativeCountry' 'income'};
%%
% |adult.data| represents missing data as |'?'|. Replace instances of
% missing data with an empty character vector. Use |textscan| to put the
% data into a cell array of character vectors.
adult = strrep(adult,'?','');
adult = textscan(adult,'%f%s%f%s%f%s%s%s%s%s%f%f%f%s%s',...
    'Delimiter',',','TreatAsEmpty','');
%%
% The name-value pair argument |TreatAsEmpty| converts all observations
% corresponding to numeric variables to |NaN| if the observation is an
% empty character vector.
%%
% Since the variables are heterogeneous, put the set into a tabular array.
adult = table(adult{:},'VariableNames',VarNames);
%%
% Some categorical variables have many levels. Plot the number of levels of
% each categorical predictor.
cat = varfun(@iscellstr,adult(:,1:end - 1),...
    'OutputFormat','uniform'); % Logical flag for categorical variables
catVars = find(cat); % Indices of categorical variables
countCats = @(var)numel(categories(nominal(var)));
numCat = varfun(@(var)countCats(var),adult(:,catVars),...
    'OutputFormat','uniform');
figure
barh(numCat);
h = gca;
h.YTickLabel = VarNames(catVars);
ylabel 'Predictor'
xlabel 'Number of categories'
%%
% The anonymous function |countCats| converts a predictor to a nominal
% array, then counts the unique, nonempty categories of the predictor.
% Predictor 14 (|'nativeCountry'|) has more than 40 categorical levels. For
% binary classification, <docid:stats_ug.bt6cr7t> uses a computational
% shortcut to find an optimal split for categorical predictors with many
% categories. For classification with more than two classes, you can choose
% a heuristic algorithm to find a good split. For details, see
% <docid:stats_ug.btttehe>.
%%
% Specify the predictor matrix using
% |classreg.regr.modelutils.predictormatrix| and the response vector.
X = classreg.regr.modelutils.predictormatrix(adult,'ResponseVar',...
    size(adult,2));
Y = nominal(adult.income);
%%
% |X| is a numeric matrix; |predictormatrix| converts all categorical
% variables into group indices. The name-value pair argument |ResponseVar|
% indicates that the last column is the response variable, and excludes it
% from the predictor matrix. |Y| is a nominal categorical array.
%%
% Train classification ensembles using both |LogitBoost| and |GentleBoost|.
rng(1); % For reproducibility
LBEnsemble = fitensemble(X,Y,'LogitBoost',300,'Tree',...
    'CategoricalPredictors',cat,'PredictorNames',VarNames(1:end-1),...
    'ResponseName','income');
GBEnsemble = fitensemble(X,Y,'GentleBoost',300,'Tree',...
    'CategoricalPredictors',cat,'PredictorNames',VarNames(1:end-1),...
    'ResponseName','income');
%%
% Examine the resubstitution error for both ensembles.
figure
plot(resubLoss(LBEnsemble,'Mode','cumulative'))
hold on
plot(resubLoss(GBEnsemble,'Mode','cumulative'),'r--')
hold off
xlabel('Number of trees')
ylabel('Resubstitution error')
legend('LogitBoost','GentleBoost','Location','NE')
%%
% The |GentleBoost| algorithm has a slightly smaller resubstitution error.
%%
% Estimate the generalization error for both algorithms by cross-validation.
CVLBEnsemble = crossval(LBEnsemble,'KFold',5);
CVGBEnsemble = crossval(GBEnsemble,'KFold',5);
figure
plot(kfoldLoss(CVLBEnsemble,'Mode','cumulative'))
hold on
plot(kfoldLoss(CVGBEnsemble,'Mode','cumulative'),'r--')
hold off
xlabel('Number of trees')
ylabel('Cross-validated error')
legend('LogitBoost','GentleBoost','Location','NE')
%%
% The cross-validated loss is nearly the same as the resubstitution error.
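%%
% As a quick sketch (not part of the original example), you can use a
% trained ensemble to label observations with |predict|. Here the first
% five rows of |X| stand in for unseen data, so the results are purely
% illustrative; in practice you would predict on a held-out set.
[labelLB,scoreLB] = predict(LBEnsemble,X(1:5,:)); % Predicted classes and classification scores
[labelGB,scoreGB] = predict(GBEnsemble,X(1:5,:));
% Compare the labels assigned by the two ensembles side by side
disp([cellstr(labelLB) cellstr(labelGB)])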