%% Transfer Learning Using Convolutional Neural Networks
% Fine-tune a convolutional neural network pretrained on digit images to
% learn the features of letter images. Transfer learning is the transfer of
% knowledge from one learned task to a new task in machine learning [1]. In
% the context of neural networks, it means transferring the learned features
% of a pretrained network to a new problem. Training a convolutional neural
% network from scratch is usually not effective when there is not enough
% training data. The common practice in deep learning for such cases is to
% reuse a network that is trained on a large data set for the new problem.
% While the initial layers of the pretrained network can be fixed, the last
% few layers must be fine-tuned to learn the specific features of the new
% data set. Transfer learning usually results in faster training times than
% training a new convolutional neural network, because you do not need to
% estimate all the parameters in the new network.
%
% *NOTE:* Training a convolutional neural network requires Parallel
% Computing Toolbox(TM) and a CUDA(R)-enabled NVIDIA(R) GPU with compute capability
% 3.0 or higher.

%% 
% Load the sample data as an |ImageDatastore|. 
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');
digitData = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,'LabelSource','foldernames');

%%
% The datastore contains 10,000 synthetic images of the digits 0–9. The
% images are generated by applying random transformations to digit images
% created using different fonts. Each digit image is 28-by-28 pixels.

%% 
% Display some of the images in the datastore. 
% 
figure;
for i = 1:20
    subplot(4,5,i);
    imshow(digitData.Files{i});
end

%% 
% Check the number of images in each digit category. 
digitData.countEachLabel 

%%
% The data contains an unequal number of images per category.  

%% 
% To balance the number of images for each digit in the training set, first
% find the minimum number of images in a category. 
minSetCount = min(digitData.countEachLabel{:,2})  

%% 
% Divide the dataset so that each category in the training set has 494 images
% and the testing set has the remaining images from each label. 
trainingNumFiles = round(minSetCount/2);
rng(1) % For reproducibility
[trainDigitData,testDigitData] = splitEachLabel(digitData,...
    trainingNumFiles,'randomize');

%%
% |splitEachLabel| splits the image files in |digitData| into two new datastores,
% |trainDigitData| and |testDigitData|.  
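
%%
% As an optional check, |countEachLabel| also works on the new datastores, so
% you can confirm that each label now has 494 training images.
trainDigitData.countEachLabel
testDigitData.countEachLabel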

%%
% Create the layers for the convolutional neural network. 
layers = [imageInputLayer([28 28 1])        % 28-by-28 grayscale images
          convolution2dLayer(5,20)          % 20 filters of size 5-by-5
          reluLayer()
          maxPooling2dLayer(2,'Stride',2)   % 2-by-2 pooling with stride 2
          fullyConnectedLayer(10)           % 10 classes (digits 0-9)
          softmaxLayer()
          classificationLayer()];

%% 
% Create the training options. Set the maximum number of epochs to 20, and
% start the training with an initial learning rate of 0.001.
options = trainingOptions('sgdm','MaxEpochs',20,...
    'InitialLearnRate',0.001);

%% 
% Train the network using the training set and the options you defined in
% the previous step. 
convnet = trainNetwork(trainDigitData,layers,options);  

%% 
% Test the network using the testing set and compute the accuracy. 
YTest = classify(convnet,testDigitData);
TTest = testDigitData.Labels;
accuracy = sum(YTest == TTest)/numel(YTest) 

%%
% Accuracy is the fraction of test images that |classify| labels correctly,
% that is, the number of predicted labels that match the true labels divided
% by the number of images in the test data. In this case, 99.78% of the digit
% estimations match the true digit values in the test set.
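
%%
% To see which digits are confused with each other, you can optionally compute
% a confusion matrix from the true and predicted labels. |confusionmat|
% requires Statistics and Machine Learning Toolbox(TM).
confusionmat(TTest,YTest)   % rows: true digits, columns: predicted digits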

%% 
% Now, suppose you would like to use the trained network |convnet| to predict
% classes on a new set of data. Load the letters training data.
load lettersTrainSet.mat 

%%
% |XTrain| contains 1500 28-by-28 grayscale images of the letters A, B,
% and C in a 4-D array. |TTrain| contains the categorical array of the letter
% labels.  
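
%%
% As an optional check, inspect the dimensions: the fourth dimension of
% |XTrain| indexes the 1500 individual 28-by-28-by-1 images, and |summary|
% shows how many labels belong to each letter.
size(XTrain)
summary(TTrain)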

%%
% Display some of the letter images.
figure;
for j = 1:20
    subplot(4,5,j);
    selectImage = datasample(XTrain,1,4);
    imshow(selectImage,[]);
end

%% 
% The pixel values in |XTrain| are in the range [0 1]. The digit data used to
% train the network |convnet| were in the range [0 255], so scale the letter
% data to the same [0 255] range.
XTrain = XTrain*255;  
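
%%
% Optionally, confirm the new pixel range after scaling.
[min(XTrain(:)) max(XTrain(:))]   % should be approximately [0 255]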

%% 
% The last three layers of the trained network |convnet| are tuned for the
% digit data set, which has 10 classes. The properties of these layers depend
% on the classification task. Display the fully connected layer
% (|fullyConnectedLayer|).
convnet.Layers(end-2)  

%% 
% Display the last layer (|classificationLayer|). 
convnet.Layers(end)  

%% 
% These three layers must be fine-tuned for the new classification problem.
% Extract all the layers except the last three from the trained network,
% |convnet|.
layersTransfer = convnet.Layers(1:end-3);  

%% 
% The letters data set has three classes. Add a new fully connected layer
% for three classes, and increase the learning rate factors for this layer so
% that it learns faster than the transferred layers.
layersTransfer(end+1) = fullyConnectedLayer(3,...
    'WeightLearnRateFactor',10,...
    'BiasLearnRateFactor',20);

%%
% |WeightLearnRateFactor| and |BiasLearnRateFactor| are multipliers of the
% global learning rate for the fully connected layer. 
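
%%
% For example, with the global |InitialLearnRate| of 0.000005 that is set
% later in this example, the new layer updates its weights with an effective
% rate of 10*0.000005 = 5e-5 and its biases with 20*0.000005 = 1e-4, so it
% learns faster than the transferred layers.
weightRate = 10*0.000005   % effective learning rate for the weights
biasRate = 20*0.000005     % effective learning rate for the biases
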
%% 
% Add a softmax layer and a classification output layer. 
layersTransfer(end+1) = softmaxLayer();
layersTransfer(end+1) = classificationLayer();  

%% 
% Create the options for transfer learning. You do not have to train for as
% many epochs (|MaxEpochs| can be lower than before). Set |InitialLearnRate|
% to a lower value than the one used for training |convnet| to improve
% convergence by taking smaller steps.
optionsTransfer = trainingOptions('sgdm',...
         'MaxEpochs',5,...
         'InitialLearnRate',0.000005,...
         'Verbose',true);  

%% 
% Perform transfer learning. 
convnetTransfer = trainNetwork(XTrain,TTrain,...
    layersTransfer,optionsTransfer);

%% 
% Load the letters test data. As with the letters training data, scale the
% test data to the range [0 255] to match the range of the data used to train
% the network.
load lettersTestSet.mat
XTest = XTest*255;  

%% 
% Test the accuracy. 
YTest = classify(convnetTransfer,XTest);
accuracy = sum(YTest == TTest)/numel(TTest)    

%% 
% *References* 
%
% [1] Pan, Sinno Jialin, and Qiang Yang. "A Survey on Transfer Learning."
% _IEEE Transactions on Knowledge and Data Engineering_, Vol. 22, No. 10,
% October 2010.