% Source: www.gusucode.com > IPCV_Eval_Kit_R2019a_0ce6858工具箱matlab程序源码 > IPCV_Eval_Kit_R2019a_0ce6858/code/demo_files/I5_06_3_2_semanticSeg_SegNet.m

%% Semantic Segmentation with SegNet

%% Initialization
% Start from a clean workspace, close all figures, clear the command
% window, and restore the default random-number generator state so that
% runs are reproducible.
clear;
close all;
clc;
rng('default');

%% Prepare the training images
% The triangle example dataset lives under the Computer Vision Toolbox's
% visiondata folder.
dataSetDir = fullfile(toolboxdir('vision'), 'visiondata', 'triangleImages');
imds = imageDatastore(fullfile(dataSetDir, 'trainingImages'));

%% Prepare the pixel labels
% In the label images, pixel value 255 maps to "triangle" and 0 maps to
% "background".
classNames = ["triangle", "background"];
labelIDs   = [255 0];
pxds = pixelLabelDatastore(fullfile(dataSetDir, 'trainingLabels'), classNames, labelIDs);

%% Visualize a training image next to its label map
% read() advances each datastore's read position; rewind both afterwards
% so anything built from them later starts at the first file again.
I = read(imds);
C = read(pxds);
reset(imds);
reset(pxds);

% Depending on the MATLAB release, read(pxds) returns either the
% categorical label matrix itself or a 1x1 cell wrapping it; unwrap so the
% uint8 conversion below works in both cases.
if iscell(C)
    C = C{1};
end

% Enlarge both 5x for display. uint8 of the categorical gives per-pixel
% category indices; resize those with nearest-neighbour interpolation so
% that no non-existent "in-between" class indices are invented at edges.
I = imresize(I, 5);
L = imresize(uint8(C), 5, 'nearest');
figure;
imshowpair(I, L, 'montage')

%% Prepare the augmented training datastore
% Random rotation in the range [-10, 10] (degrees, per imageDataAugmenter)
% plus random left-right reflection. The augmenter is attached to a
% datastore that pairs the images (imds) with their pixel labels (pxds).
augmenter = imageDataAugmenter( ...
    'RandRotation',    [-10 10], ...
    'RandXReflection', true)
trainingData = pixelLabelImageDatastore(imds, pxds, ...
    'DataAugmentation', augmenter)

%% Build the SegNet layer graph
% NOTE(review): the input size assumes 32x32 training images - confirm
% against the dataset if it changes.
imageSize = [32 32];
% Derive the class count from classNames so the network's output width
% always matches the label definitions used by pixelLabelDatastore.
numClasses = numel(classNames);
encoderDepth = 2;
lgraph = segnetLayers(imageSize, numClasses, encoderDepth)

%% Training options
% SGD with momentum: learning rate 1e-3, 20 epochs, mini-batches of 64,
% with the live training-progress plot enabled.
sgdmSettings = { ...
    'InitialLearnRate', 1e-3, ...
    'MaxEpochs',        20, ...
    'MiniBatchSize',    64, ...
    'Plots',            'training-progress'};
opts = trainingOptions('sgdm', sgdmSettings{:});

%% Compute class weights from label frequency
% Inverse-frequency weighting: the rarer a class is among the training
% pixels, the larger its weight in the pixel classification layer.
tbl = countEachLabel(trainingData)
totalNumberOfPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / totalNumberOfPixels;

% A class with no labelled pixels would receive an infinite (or NaN)
% weight and silently corrupt training; stop with a clear message instead.
emptyClasses = tbl.PixelCount == 0;
if any(emptyClasses)
    error('No labelled pixels found for class(es): %s', ...
        strjoin(cellstr(tbl.Name(emptyClasses)), ', '));
end
classWeights = 1./frequency

% Replace the layer named 'pixelLabels' with a weighted pixel
% classification layer and wire it to the 'softmax' layer. ClassNames come
% from the same table as the weights, so their order matches.
pxLayer = pixelClassificationLayer('Name','labels','ClassNames',tbl.Name,'ClassWeights',classWeights)
lgraph = removeLayers(lgraph,'pixelLabels');
lgraph = addLayers(lgraph, pxLayer);
lgraph = connectLayers(lgraph,'softmax','labels');
analyzeNetwork(lgraph);

%% Train the network
net = trainNetwork(trainingData, lgraph, opts);

%% Evaluate on a test image
% Segment the test image and show the input and the label overlay side by
% side in a fresh figure, so neither display overwrites the other or an
% unrelated existing figure.
testImage = imread('triangleTest.jpg');
C = semanticseg(testImage, net);
B = labeloverlay(testImage, C);
figure;
subplot(1, 2, 1);
imshow(testImage);
title('Test image');
subplot(1, 2, 2);
imshow(B);
title('Predicted label overlay');

%%
% Copyright 2018 The MathWorks, Inc.