%% Wine Classification
% This example illustrates how a pattern recognition neural network can
% classify wines by winery based on their chemical characteristics.
%
% Copyright 2010-2012 The MathWorks, Inc.

%% The Problem: Classify Wines
% In this example we attempt to build a neural network that can classify
% wines from three wineries by thirteen attributes:
%
% * Alcohol
% * Malic acid
% * Ash
% * Alcalinity of ash
% * Magnesium
% * Total phenols
% * Flavanoids
% * Nonflavanoid phenols
% * Proanthocyanins
% * Color intensity
% * Hue
% * OD280/OD315 of diluted wines
% * Proline
%
% This is an example of a pattern recognition problem, where inputs are
% associated with different classes, and we would like to create a neural
% network that not only classifies the known wines properly, but can
% generalize to accurately classify wines that were not used to design
% the solution.

%% Why Neural Networks?
% Neural networks are very good at pattern recognition problems. A neural
% network with enough elements (called neurons) can classify any data
% with arbitrary accuracy. They are particularly well suited for complex
% decision boundary problems over many variables. Therefore neural
% networks are a good candidate for solving the wine classification
% problem.
%
% The thirteen wine attributes will act as inputs to a neural network,
% and the respective target for each will be a 3-element class vector
% with a 1 in the position of the associated winery, #1, #2 or #3.
%
% The network will be designed by using the attributes of the wines to
% train the network to produce the correct target classes.

%% Preparing the Data
% Data for classification problems are set up for a neural network by
% organizing the data into two matrices, the input matrix X and the
% target matrix T.
%
% Each ith column of the input matrix will have thirteen elements
% representing a wine whose winery is already known.
%
% Each corresponding column of the target matrix will have three
% elements, consisting of two zeros and a 1 in the location of the
% associated winery.
%
% Here such a dataset is loaded.

[x,t] = wine_dataset;

%%
% We can view the sizes of inputs X and targets T.
%
% Note that both X and T have 178 columns. These represent 178 wine
% sample attributes (inputs) and associated winery class vectors
% (targets).
%
% Input matrix X has thirteen rows, for the thirteen attributes. Target
% matrix T has three rows, as for each example we have three possible
% wineries.

size(x)
size(t)

%% Pattern Recognition with a Neural Network
% The next step is to create a neural network that will learn to classify
% the wines.
%
% Since the neural network starts with random initial weights, the
% results of this example will differ slightly every time it is run. The
% random seed is set to avoid this randomness. However this is not
% necessary for your own applications.

setdemorandstream(391418381)

%%
% Two-layer (i.e. one-hidden-layer) feed forward neural networks can
% learn any input-output relationship given enough neurons in the hidden
% layer. Layers which are not output layers are called hidden layers.
%
% We will try a single hidden layer of 10 neurons for this example. In
% general, more difficult problems require more neurons, and perhaps more
% layers. Simpler problems require fewer neurons.
%
% The input and output have sizes of 0 because the network has not yet
% been configured to match our input and target data. This will happen
% when the network is trained.

net = patternnet(10);
view(net)

%%
% Now the network is ready to be trained. The samples are automatically
% divided into training, validation and test sets. The training set is
% used to teach the network. Training continues as long as the network
% continues improving on the validation set. The test set provides a
% completely independent measure of network accuracy.
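%%
% The division is random by default, and its proportions can be inspected
% or changed on the network object before training. The assignments below
% are only an illustrative sketch showing the toolbox default 70/15/15
% split; they are not required for this example.

net.divideParam.trainRatio = 0.70;   % fraction of samples used for training
net.divideParam.valRatio   = 0.15;   % fraction used for early-stopping validation
net.divideParam.testRatio  = 0.15;   % fraction held out for independent testing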
%
% The NN Training Tool shows the network being trained and the algorithms
% used to train it. It also displays the training state during training,
% and the criteria which stopped training will be highlighted in green.
%
% The buttons at the bottom open useful plots which can be opened during
% and after training. Links next to the algorithm names and plot buttons
% open documentation on those subjects.

[net,tr] = train(net,x,t);
nntraintool

%%
% To see how the network's performance improved during training, either
% click the "Performance" button in the training tool, or call
% PLOTPERFORM.
%
% Performance is measured in terms of mean squared error, and shown in
% log scale. It rapidly decreased as the network was trained.
%
% Performance is shown for each of the training, validation and test
% sets. The version of the network that did best on the validation set
% is the one used after training.

plotperform(tr)

%% Testing the Neural Network
% The mean squared error of the trained neural network can now be
% measured with respect to the testing samples. This will give us a sense
% of how well the network will do when applied to data from the real
% world.
%
% The network outputs will be in the range 0 to 1, so we can use the
% *vec2ind* function to get the class indices as the position of the
% highest element in each output vector.

testX = x(:,tr.testInd);
testT = t(:,tr.testInd);
testY = net(testX);
testIndices = vec2ind(testY)

%%
% Another measure of how well the neural network has fit the data is the
% confusion plot. Here the confusion matrix is plotted across all
% samples.
%
% The confusion matrix shows the percentages of correct and incorrect
% classifications. Correct classifications are the green squares on the
% matrix's diagonal. Incorrect classifications form the red squares.
%
% If the network has learned to classify properly, the percentages in the
% red squares should be very small, indicating few misclassifications.
%
% If this is not the case then further training, or training a network
% with more hidden neurons, would be advisable.

plotconfusion(testT,testY)

%%
% Here are the overall percentages of correct and incorrect
% classification.

[c,cm] = confusion(testT,testY)

fprintf('Percentage Correct Classification   : %f%%\n', 100*(1-c));
fprintf('Percentage Incorrect Classification : %f%%\n', 100*c);

%%
% A third measure of how well the neural network has fit the data is the
% receiver operating characteristic plot. This shows how the false
% positive and true positive rates relate as the thresholding of outputs
% is varied from 0 to 1.
%
% The farther left and up the line is, the fewer false positives need to
% be accepted in order to get a high true positive rate. The best
% classifiers will have a line going from the bottom left corner, to the
% top left corner, to the top right corner, or close to that.

plotroc(testT,testY)

%%
% This example illustrated how to design a neural network that classifies
% wines into three wineries from each wine's characteristics.
%
% Explore other examples and the documentation for more insight into
% neural networks and their applications.
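%%
% A trained network object can also be applied directly to new data. The
% sample below is hypothetical: a slightly perturbed copy of the first
% training wine stands in for a previously unseen measurement, and any
% 13-element column vector of attribute values would work the same way.

newSample = x(:,1) + 0.01*randn(13,1);   % hypothetical unseen wine sample
newScores = net(newSample);              % three class scores in [0,1]
newClass  = vec2ind(newScores)           % index of the predicted winery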