www.gusucode.com > IPCV_Eval_Kit_R2019a_0ce6858工具箱源码程序matlab > IPCV_Eval_Kit_R2019a_0ce6858/code/demo_files/I5_06_1_myAutoencoderDigits_R2016b.m

    %% 懡憌僯儏乕儔儖僱僢僩儚乕僋丗Deep Neural Network(NN)偺妛廗偵傛傞悢帤幆暿
%  Autoencoder偺庤朄偱妛廗偟偨Encoder傪懡抜廳偹偨幆暿婍丗 Stacked Autoencoder
%  GPU傗暲楍張棟偺僆僾僔儑儞傪巊偆偵偼丄Parallel Computing Toolbox 偺儔僀僙儞僗偑昁梫
clc;clear;close all;imtool close all;rng('default')

%% 妛廗梡夋憸偺撉崬傒  乮悢帤傪儔儞僟儉側傾僼傿儞曄姺偱曄宍偟偨傕偺傪巊梡乯
%     xTrainImages丗妛廗梡夋憸丗28x28僺僋僙儖偺夋憸偑5003枃    (僙儖攝楍)
%     tTrain      丗儔儀儖乮嫵巘僨乕僞乯10x5003
[xTrainImages, tTrain] = digitTrainCellArrayData;

%% 妛廗梡夋憸偺堦晹(80枃)傪昞帵
figure; montage(reshape([xTrainImages{1:80}], [28 28 1 80]), 'Size', [8,10]);

%% 儔儀儖(嫵巘僨乕僞)偺妋擣
openvar('tTrain')     % 10峴栚偼丄暥帤0偵懳墳

%% [戞堦塀傟憌]: Autoencoder偵傛傞1夞栚偺妛廗 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Autoencoder偵傛傞NN偺妛廗丗弌椡抣偑擖椡抣偲摨偠偵側傞傛偆偵丄嫵巘側偟妛廗
%                            塀傟憌偺僒僀僘偑擖椡傛傝彮側偔側傞偙偲偱丄擖椡忣曬傪埑弅
% Autoencoder僋儔僗傪巊梡
% GPU偵傛傞崅懍壔傕壜擻
hiddenSize1 = 100;    % Encoder偺悢(僯儏乕儘儞偺悢)
autoenc1 = trainAutoencoder(xTrainImages, hiddenSize1, ...
                                'MaxEpochs',400, ...      % 妛廗夞悢乮悽戙乯
                                'L2WeightRegularization',0.004, ...        % impact of an L2 regularizer for the weights of the network (and not the biases). This should typically be quite small.
                                'SparsityRegularization',4, ...            % impact of a sparsity regularizer, which attempts to enforce a constraint on the sparsity of the output from the hidden layer
                                'SparsityProportion',0.15, ...
                                'ScaleData', false);
%% 僂僃僀僩偺壜帇壔丗Encoder偑妛廗偟偨摿挜
%    僒僀僘丗擖椡28x28=784僺僋僙儖(僲乕僪)丄戞堦憌(塀傟/拞娫憌)100丄戞堦憌偺弌椡100丄 弌椡28x28=784僺僋僙儖丄塀傟(拞娫)憌100
%    僲乕僪枅偵丄784師偺僂僃僀僩w偲掕悢僶僀傾僗b偑偁傞丅
%      (100屄偺僯儏乕儘儞偑偦傟偧傟妶惈壔偡傞擖椡僷僞乕儞丗偦傟偧傟嬋傝傗捈慄僷僞乕儞傪昞尰)
b1 = autoenc1.EncoderBiases             % 戞堦憌偺僶僀傾僗
w1 = autoenc1.EncoderWeights;           % 戞堦憌偺僂僃僀僩
figure; plotWeights(autoenc1);          % 僂僃僀僩偺壜帇壔


%% [戞擇塀傟憌]丗1夞栚偺妛廗偱嶌偭偨Encoder偺弌椡傪梡偄丄2偮栚偺Autoencoder妛廗 %%%%%%%%%
% 堦抜栚偺Encoder偺丄妛廗夋憸偵懳偡傞弌椡(摿挜)傪寁嶼 => 妛廗2偺妛廗梡夋憸
feat1 = encode(autoenc1, xTrainImages);
%% 塀傟憌偺僒僀僘50偱妛廗
hiddenSize2 = 50;
autoenc2 = trainAutoencoder(feat1, hiddenSize2, ...
                              'MaxEpochs',100, ...
                              'L2WeightRegularization',0.002, ...
                              'SparsityRegularization',4, ...
                              'SparsityProportion',0.1, ...
                              'ScaleData', false);

%% [嵟廔憌]丗 擇抜栚偺弌椡50屄偐傜丄10僋儔僗傊幆暿偡傞嵟廔抜偺Softmax憌傪妛廗 %%%%%%%%
% 擇抜栚偺Encoder偺弌椡傪寁嶼
%   擖椡偼丄妛廗帪偵梡偄偨丄"妛廗夋憸(784x5000屄)偵懳偡傞堦抜栚偺Encoder弌椡"傪巊梡(100x5000屄)
feat2 = encode(autoenc2, feat1);     % 寢壥偼 50x5000
%% 嵟廔憌偺妛廗丗5000屄偺50師妛廗僨乕僞偵懳墳偡傞嫵巘僨乕僞(tTrain)傪巊梡
softnet = trainSoftmaxLayer(feat2, tTrain, 'MaxEpochs',400);

%% [寢崌]丗妛廗偟偨3偮偺憌傪寢崌\帵
deepnet = stack(autoenc1, autoenc2, softnet)           % network object
view(deepnet);

%% [僥僗僩梡夋憸傪暘椶]
% 僥僗僩夋憸~5000枃偺撉崬傒丗 僥僗僩夋憸:xTestImages丄儔儀儖:tTest(10x4997)
[xTestImages, tTest] = digitTestCellArrayData;
% 僙儖宍幃偺~5000屄偺僥僗僩梡夋憸傪丄奺楍偑堦偮偺夋憸僨乕僞(28x28=784)偺峴楍偵曄姺
xTestMatrix = reshape([xTestImages{:}], [28*28 4997]);   % 784峴4997楍偺峴楍

%% 暘椶
result1 = deepnet(xTestMatrix);        % 10x4997

%% ~5000夋憸偺擣幆寢壥偺偆偪丄嵟弶偺100屄傪暘椶媺蕚饡\帵(愒偑岆擣幆)
Ir = zeros([28,28,3,100]);      % 寢壥傪奿擺偡傞攝楍
for k = 1:100
  img = xTestImages{k};
  [~, maxI] = max(result1(:,k));
  if maxI == find(tTest(:,k))
      colorN = 'green';
  else
      colorN = 'red';
  end
  img = insertText(img, [0 0], mod(maxI,10), 'TextColor',colorN, 'FontSize',14, 'BoxOpacity',0, 'Font','Lucida Sans Typewriter Bold');
  Ir(:,:,:,k)=img;
end
figure;montage(Ir);

%% 崿崌峴楍偺昞帵
figure;plotconfusion(tTest, result1);

%% [旝挷惍] 岆嵎媡揱斃朄偵傛傝僂僃僀僩偺旝挷惍偟丄僥僗僩梡夋憸傪嵞暘椶
% 僙儖宍幃偺~5000屄偺妛廗梡夋憸傪丄奺楍偑堦偮偺夋憸僨乕僞(28x28=784)偺峴楍偵曄姺
xTrain = reshape([xTrainImages{:}], [28*28 5003]);   % 784峴5003楍偺峴楍
% 旝挷惍
deepnet = train(deepnet, xTrain, tTrain);

%% 嵞暘椶
result2 = deepnet(xTestMatrix);        % 10x4997
Ir = zeros([28,28,3,100]);      % 寢壥傪奿擺偡傞攝楍
for k = 1:100
  img = xTestImages{k};
  [~, maxI] = max(result2(:,k));
  if maxI == find(tTest(:,k))
      colorN = 'green';
  else
      colorN = 'red';
  end
  img = insertText(img, [0 0], mod(maxI,10), 'TextColor',colorN, 'FontSize',14, 'BoxOpacity',0, 'Font','Lucida Sans Typewriter Bold');
  Ir(:,:,:,k)=img;
end
figure;montage(Ir);

%% 崿崌峴楍偺昞帵
figure;plotconfusion(tTest, result2);

%% Copyright 2015 The MathWorks, Inc.