% www.gusucode.com > IPCV_Eval_Kit_R2019a_0ce6858工具箱matlab程序源码 > IPCV_Eval_Kit_R2019a_0ce6858/code/demo_files/I5_06_5_1_videoClassification.m
%% 5.6.5.1 Deep learning: video classification
% Extracts per-frame GoogLeNet activations from the HMDB51 videos, trains a
% BiLSTM classifier on those feature sequences, and finally assembles the
% convolutional layers and the trained LSTM layers into one network that
% classifies a raw video directly.

%% Load the pretrained model
netCNN = googlenet;

%% Load the data
% Download the RAR archive from
% <http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/
% HMDB: a large human motion database> and extract it (|hmdb51_org|).
% The data set holds about 7000 video sequences in 51 classes such as
% "drink", "run" and "wave".
%
% A support function is used to obtain the file names and labels.
% hmdb51Files is not defined in this file; it has to be on the MATLAB path.
dataFolder = "hmdb51_org";
if ~exist(fullfile(pwd,dataFolder),'dir')
    % NOTE(review): the message text below appears to be mis-encoded
    % (mojibake). It is runtime text, so it is left byte-for-byte as found.
    error("僨乕僞僙僢僩hmdb51_org傪僟僂儞儘乕僪偟偰夝搥偟偰偔偩偝偄" +...
        "http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/");
end
[files,labels] = hmdb51Files(dataFolder);

%% Read one video
% readVideo returns an H-by-W-by-C-by-S array: height, width, number of
% channels and number of frames, in that order.
idx = 1;
filename = files(idx);
video = readVideo(filename);
size(video)

%% Check the corresponding label
labels(idx)

%% Display with imshow
% For double data imshow expects values in the range [0 1], so every frame
% is divided by 255.
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    imshow(frame/255);
    drawnow
end

%% Extract feature vectors from the videos
% Note: according to the original author this step takes more than
% 30 minutes. The result is cached in a MAT-file under tempdir and reused
% on later runs.
inputSize = netCNN.Layers(1).InputSize(1:2);
layerName = "pool5-7x7_s1";

tempFile = fullfile(tempdir,"hmdb51_org.mat");

if exist(tempFile,'file')
    load(tempFile,"sequences")
else
    numFiles = numel(files);
    sequences = cell(numFiles,1);

    for i = 1:numFiles
        fprintf("Reading file %d of %d...\n", i, numFiles)

        video = readVideo(files(i));
        video = centerCrop(video,inputSize);

        % One column of activations per frame.
        sequences{i,1} = activations(netCNN,video,layerName,'OutputAs','columns');
    end

    save(tempFile,"sequences","-v7.3");
end

%% Check the size of the feature vectors
% Each sequence is a D-by-S array, where D is the feature vector size and
% S is the number of frames of the video.
sequences(1:10)

%% Prepare the training data
% Split the data set 9:1 into a training part and a validation part.
numObservations = numel(sequences);
idx = randperm(numObservations);
N = floor(0.9 * numObservations);

idxTrain = idx(1:N);
sequencesTrain = sequences(idxTrain);
labelsTrain = labels(idxTrain);

idxValidation = idx(N+1:end);
sequencesValidation = sequences(idxValidation);
labelsValidation = labels(idxValidation);

%% Remove the long videos
% Overly long sequences are removed from the training set to avoid the
% negative effect of padding.
numObservationsTrain = numel(sequencesTrain);
sequenceLengths = zeros(1,numObservationsTrain);

for i = 1:numObservationsTrain
    sequence = sequencesTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

figure
histogram(sequenceLengths)
title("Sequence Lengths")
xlabel("Sequence Length")
ylabel("Frequency")

%% Sequences longer than 400 frames are a small minority, so drop them.
maxLength = 400;
idx = sequenceLengths > maxLength;
sequencesTrain(idx) = [];
labelsTrain(idx) = [];

%% Create the LSTM network
% The BiLSTM layer uses 2000 hidden units.
% The output is a single label per sequence, so 'OutputMode' is 'last'.
% The fully connected layer size equals the number of classes.
numFeatures = size(sequencesTrain{1},1);
numClasses = numel(categories(labelsTrain));

layers = [
    sequenceInputLayer(numFeatures,'Name','sequence')
    bilstmLayer(2000,'OutputMode','last','Name','bilstm')
    dropoutLayer(0.5,'Name','drop')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];

%% Set the training options
% The data is shuffled every epoch and validation runs once per epoch.
% NOTE(review): the original comment said that each mini-batch is truncated
% to its shortest sequence, but no 'SequenceLength' option is passed below,
% so the trainingOptions default applies - confirm which one is intended.
miniBatchSize = 16;
numObservations = numel(sequencesTrain);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'InitialLearnRate',1e-4, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{sequencesValidation,labelsValidation}, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

%% Train the LSTM network
[netLSTM,info] = trainNetwork(sequencesTrain,labelsTrain,layers,options);

%% Check the classification accuracy
YPred = classify(netLSTM,sequencesValidation,'MiniBatchSize',miniBatchSize);
YValidation = labelsValidation;
accuracy = mean(YPred == YValidation)

%% Assemble the video classification network
% Add the convolutional layers.
cnnLayers = layerGraph(netCNN);

% Remove the input layer and the layers after the activation layer.
layerNames = ["data" "pool5-drop_7x7_s1" "loss3-classifier" "prob" "output"];
cnnLayers = removeLayers(cnnLayers,layerNames);

% Add a sequence input layer at the front.
% A sequence input layer is defined so that image sequences can be handled.
% 'Normalization' is set to 'zerocenter' and 'Mean' is set to the
% AverageImage of the GoogLeNet input layer.
inputSize = netCNN.Layers(1).InputSize(1:2);
averageImage = netCNN.Layers(1).AverageImage;

inputLayer = sequenceInputLayer([inputSize 3], ...
    'Normalization','zerocenter', ...
    'Mean',averageImage, ...
    'Name','input');

%% Use a sequence folding layer so the convolutions are applied to every image of the sequence.
layers = [
    inputLayer
    sequenceFoldingLayer('Name','fold')];

lgraph = addLayers(cnnLayers,layers);
lgraph = connectLayers(lgraph,"fold/out","conv1-7x7_s2");

% Add the LSTM layers.
% Drop the sequence input layer of the trained LSTM network.
lstmLayers = netLSTM.Layers;
lstmLayers(1) = [];

% Add the sequence unfolding layer, the flatten layer and the LSTM layers.
layers = [
    sequenceUnfoldingLayer('Name','unfold')
    flattenLayer('Name','flatten')
    lstmLayers];

lgraph = addLayers(lgraph,layers);

% Connect the last convolutional layer ("pool5-7x7_s1") to the input of the
% sequence unfolding layer ("unfold/in").
lgraph = connectLayers(lgraph,"pool5-7x7_s1","unfold/in");

% So that the unfolding layer can restore the sequence structure, connect
% the "miniBatchSize" output of the sequence folding layer to the
% sequence unfolding layer.
lgraph = connectLayers(lgraph,"fold/miniBatchSize","unfold/miniBatchSize");

%% Check the consistency of the network with analyzeNetwork.
analyzeNetwork(lgraph)

%% Assemble the network with assembleNetwork
net = assembleNetwork(lgraph)

%% Classify a new video
% Read the video "pushup.mp4"; it is center-cropped further below.
filename = "pushup.mp4";
video = readVideo(filename);

%% Visualize
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    imshow(frame/255);
    drawnow
end

%% Run the classification
% The input video has to be passed to classify wrapped in a cell array.
video = centerCrop(video,inputSize);
YPred = classify(net,{video})

%% Support functions
% Read the video data.
function video = readVideo(filename)
% readVideo  Read all frames of a video file into one numeric array.
%   video = readVideo(filename) opens filename with VideoReader and returns
%   an H-by-W-by-3-by-S array created with zeros (so of class double),
%   where H and W are the reader's Height and Width and S is the number of
%   frames actually read.

vr = VideoReader(filename);
H = vr.Height;
W = vr.Width;
C = 3;

% Preallocate video array. Duration * FrameRate is only an estimate of the
% frame count, hence the trimming after the read loop.
numFrames = floor(vr.Duration * vr.FrameRate);
video = zeros(H,W,C,numFrames);

% Read frames
i = 0;
while hasFrame(vr)
    i = i + 1;
    video(:,:,:,i) = readFrame(vr);
end

% Remove unallocated frames
if size(video,4) > i
    video(:,:,:,i+1:end) = [];
end

end

%%
% Center crop and resize to the network input image size.
function videoResized = centerCrop(video,inputSize)
% centerCrop  Crop a video to a centered square and resize it.
%   videoResized = centerCrop(video,inputSize) crops the longer of the first
%   two dimensions of video (rows, columns) to the length of the shorter
%   one, taking the window from the middle, and then resizes the result to
%   inputSize(1:2) with imresize. Trailing dimensions are kept.
%
%   The previous implementation deleted 1:(offset-1) and then kept the first
%   sz(1) (or sz(2)) entries, which shifted the window one pixel towards the
%   start and trimmed the two sides unevenly. The window is now
%   offset+1 : offset+side, which leaves floor(excess/2) pixels before it
%   and ceil(excess/2) after it. The output size is unchanged.

sz = size(video);

if sz(1) < sz(2)
    % Video is landscape: keep the middle sz(1) columns.
    offset = floor((sz(2) - sz(1))/2);
    video = video(:,offset + (1:sz(1)),:,:);

elseif sz(2) < sz(1)
    % Video is portrait: keep the middle sz(2) rows.
    offset = floor((sz(1) - sz(2))/2);
    video = video(offset + (1:sz(2)),:,:,:);
end

videoResized = imresize(video,inputSize(1:2));

end

%%
% _Copyright 2019 The MathWorks, Inc._