% Source: www.gusucode.com > IPCV_Eval_Kit_R2019a_0ce6858工具箱matlab程序源码 > IPCV_Eval_Kit_R2019a_0ce6858/code/demo_files/I5_06_4_1_semanticSeg_3DUnet.m

%% Deep Learning: 3-D U-Net Segmentation

%% Initialization
% Start from an empty workspace, no open figures and a cleared Command
% Window, and reset the random number generator to its default settings so
% that later random operations are repeatable.
clear
close all
clc
rng('default')

%% Put the supporting functions on the path
% The helper functions called later in this script (preprocessBraTSdataset,
% matRead, augment3dPatch, ...) are presumably provided by this shipped
% examples folder.
helperFolder = fullfile(matlabroot,'examples','deeplearning_shared','main');
addpath(helperFolder);

%% BraTS dataset
% Download "Task01_BrainTumour.tar" via "Download Data" on the site below,
% extract it, and place the result in a folder named BraTS under the current
% folder (the next section expects BraTS/Task01_BrainTumour).
% http://medicaldecathlon.com/
imageDir = fullfile(pwd,'BraTS');
% exist(...,'dir') returns 7 for an existing folder and 0 otherwise.
datasetFolderFound = exist(imageDir,'dir') == 7;
if ~datasetFolderFound
    % NOTE(review): this message text appears mis-encoded (Japanese text
    % asking the user to download and extract Task01_BrainTumour.tar).
    error('http://medicaldecathlon.com/ 偐傜 Task01_BrainTumour.tar傪僟僂儞儘乕僪偟丄夝搥偟偰偔偩偝偄丅');
end

%% Preprocessing (takes about 30 minutes)
% preprocessBraTSdataset is a helper from the example support folder. It is
% given the output location first and the raw dataset location second; the
% sections below read imagesTr/labelsTr/imagesVal/labelsVal (and optionally
% imagesTest/labelsTest) from preprocessDataLoc.
% NOTE(review): this step reruns on every execution; consider skipping it
% when preprocessDataLoc already holds a complete preprocessed dataset.
sourceDataLoc = fullfile(imageDir,'Task01_BrainTumour');
preprocessDataLoc = fullfile(tempdir,'BraTS','preprocessedDataset');
preprocessBraTSdataset(preprocessDataLoc,sourceDataLoc);

%% Create the training volume and label datastores
% Volumes and labels are both stored as .mat files and read through the
% matRead helper. The random-patch datastores are built on top of these in
% a later section.
volReader = @matRead;
labelReader = @matRead;

volLoc = fullfile(preprocessDataLoc,'imagesTr');
volds = imageDatastore(volLoc,'FileExtensions','.mat','ReadFcn',volReader);

% Two classes: label value 0 maps to "background", label value 1 to "tumor".
classNames = ["background","tumor"];
pixelLabelID = [0 1];
lblLoc = fullfile(preprocessDataLoc,'labelsTr');
pxds = pixelLabelDatastore(lblLoc,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

%% Inspect one volume and its label
sampleVolume = preview(volds);
sampleLabel = preview(pxds);
figure
% Render the label over the first channel of the volume and hide label 1
% ("background", the first entry of classNames) so only the tumor shows.
h = labelvolshow(sampleLabel,sampleVolume(:,:,:,1));
h.LabelVisibility(1) = 0;

%% Training settings
% Every training volume yields patchPerImage random patches of size
% patchSize, which are served in mini-batches of miniBatchSize.
patchSize = [64 64 64];
patchPerImage = 16;
miniBatchSize = 8;
trainPatchds = randomPatchExtractionDatastore(volds,pxds,patchSize,'PatchesPerImage',patchPerImage);
trainPatchds.MiniBatchSize = miniBatchSize;

%% Data augmentation (random rotation and reflection)
% augment3dPatch is a helper from the example support folder; transform
% applies it to the data read from the patch datastore.
augmentFcn = @augment3dPatch;
dsTrain = transform(trainPatchds,augmentFcn);

%% Create the validation randomPatchExtractionDatastore
% Same readers, classes, patch size and patches per image as training, but
% read from the validation split and without augmentation.
valVolumeFolder = fullfile(preprocessDataLoc,'imagesVal');
valLabelFolder = fullfile(preprocessDataLoc,'labelsVal');

voldsVal = imageDatastore(valVolumeFolder,'FileExtensions','.mat','ReadFcn',volReader);
pxdsVal = pixelLabelDatastore(valLabelFolder,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

dsVal = randomPatchExtractionDatastore(voldsVal,pxdsVal,patchSize,'PatchesPerImage',patchPerImage);
dsVal.MiniBatchSize = miniBatchSize;

%% Define the 3-D U-Net layers
% Input: 64x64x64 patches with 4 channels, no input normalization.
inputSize = [64 64 64 4];
inputLayer = image3dInputLayer(inputSize,'Normalization','none','Name','input');

% Encoder: one stage per row of numFiltersEncoder. Each stage is
% conv-BN-ReLU, conv-ReLU, then 2x2x2 max pooling with stride 2; the row
% holds the filter counts of the stage's two 3x3x3 convolutions.
numFiltersEncoder = [32 64; 64 128; 128 256];
layers = inputLayer;
for stage = 1:size(numFiltersEncoder,1)
    tag = sprintf('en%d',stage);
    stageLayers = [
        convolution3dLayer(3,numFiltersEncoder(stage,1), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name',[tag '_conv1']);
        batchNormalizationLayer('Name',[tag '_bn']);
        reluLayer('Name',[tag '_relu1']);
        convolution3dLayer(3,numFiltersEncoder(stage,2), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name',[tag '_conv2']);
        reluLayer('Name',[tag '_relu2']);
        maxPooling3dLayer(2,'Stride',2,'Padding','same', ...
            'Name',[tag '_maxpool']);
    ];
    layers = [layers; stageLayers]; %#ok<AGROW>
end

% Decoder: each row of numFiltersDecoder holds the filter counts of the two
% 3x3x3 convolutions of one decoder stage (row 1 = deepest stage "de4",
% row 4 = final stage "de1").
numFiltersDecoder = [256 512; 256 256; 128 128; 64 64];

% Builds one up-sampling decoder stage: conv-ReLU, conv-ReLU, then a 2x2x2
% transposed convolution with stride 2 using the second conv's filter count.
makeUpStage = @(tag,nFirst,nSecond) [ ...
    convolution3dLayer(3,nFirst,'Padding','same', ...
        'WeightsInitializer','narrow-normal','Name',[tag '_conv1']); ...
    reluLayer('Name',[tag '_relu1']); ...
    convolution3dLayer(3,nSecond,'Padding','same', ...
        'WeightsInitializer','narrow-normal','Name',[tag '_conv2']); ...
    reluLayer('Name',[tag '_relu2']); ...
    transposedConv3dLayer(2,nSecond,'Stride',2,'Name',[tag '_transconv'])];

decoderModule4 = makeUpStage('de4',numFiltersDecoder(1,1),numFiltersDecoder(1,2));
decoderModule3 = makeUpStage('de3',numFiltersDecoder(2,1),numFiltersDecoder(2,2));
decoderModule2 = makeUpStage('de2',numFiltersDecoder(3,1),numFiltersDecoder(3,2));

% Final decoder stage: two conv-ReLU pairs, a 1x1x1 convolution down to
% numLabels channels, softmax, and the output layer
% dicePixelClassification3dLayer (a helper from the example folder,
% presumably a Dice-loss pixel classification layer).
numLabels = 2;
decoderModuleFinal = [
    convolution3dLayer(3,numFiltersDecoder(4,1),'Padding','same', ...
        'WeightsInitializer','narrow-normal','Name','de1_conv1');
    reluLayer('Name','de1_relu1');
    convolution3dLayer(3,numFiltersDecoder(4,2),'Padding','same', ...
        'WeightsInitializer','narrow-normal','Name','de1_conv2');
    reluLayer('Name','de1_relu2');
    convolution3dLayer(1,numLabels,'Name','convLast');
    softmaxLayer('Name','softmax');
    dicePixelClassification3dLayer('output');
];

% Assemble the layer graph. The encoder plus the deepest decoder stage form
% one chain; each remaining decoder stage is added as a separate chain.
layers = [layers; decoderModule4];
lgraph = layerGraph(layers);
for extraChain = {decoderModule3, decoderModule2, decoderModuleFinal}
    lgraph = addLayers(lgraph,extraChain{1});
end

% Skip connections. Each row gives: concatenation layer name, encoder
% output wired to input 1, decoder up-sampling output wired to input 2, and
% the layer that consumes the concatenation. Concatenation is along
% dimension 4, the channel dimension of the 3-D feature maps.
skipLinks = {
    'concat1', 'en1_relu2', 'de2_transconv', 'de1_conv1'
    'concat2', 'en2_relu2', 'de3_transconv', 'de2_conv1'
    'concat3', 'en3_relu2', 'de4_transconv', 'de3_conv1'};
for row = 1:size(skipLinks,1)
    [catName, skipFrom, upFrom, catTo] = skipLinks{row,:};
    lgraph = addLayers(lgraph,concatenationLayer(4,2,'Name',catName));
    lgraph = connectLayers(lgraph,skipFrom,[catName '/in1']);
    lgraph = connectLayers(lgraph,upFrom,[catName '/in2']);
    lgraph = connectLayers(lgraph,[catName '/out'],catTo);
end

% Optionally replace the hand-built graph with the one returned by the
% createUnet3d helper (presumably from the example support folder added to
% the path at the top of the script). Overwriting lgraph unconditionally
% would silently discard the network defined above, so this is opt-in.
useCreateUnet3d = false;
if useCreateUnet3d
    lgraph = createUnet3d(inputSize);
end

%% Visualize the layers
analyzeNetwork(lgraph)

%% Specify training options
% Adam, 100 epochs, initial learning rate 5e-4 multiplied by 0.95 every 5
% epochs; validation on dsVal every 400 iterations.
lrSchedule = {'InitialLearnRate',5e-4, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',5, ...
    'LearnRateDropFactor',0.95};
validationSettings = {'ValidationData',dsVal, 'ValidationFrequency',400};
options = trainingOptions('adam', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    lrSchedule{:}, ...
    validationSettings{:}, ...
    'Plots','training-progress', ...
    'Verbose',false);
%% Download the pretrained model and the sample test set
% downloadTrained3DUnetSampleData is a helper from the example folder; it is
% given both URLs and the destination folder.
% Note: imageDir is reassigned here; from now on it refers to tempdir/BraTS,
% not the BraTS folder under the current folder.
trained3DUnet_url = 'https://www.mathworks.com/supportfiles/vision/data/brainTumor3DUNet.mat';
sampleData_url = 'https://www.mathworks.com/supportfiles/vision/data/sampleBraTSTestSet.tar.gz';

imageDir = fullfile(tempdir,'BraTS');
if exist(imageDir,'dir') ~= 7
    mkdir(imageDir);
end
downloadTrained3DUnetSampleData(trained3DUnet_url,sampleData_url,imageDir);
%% Train the 3-D U-Net
% Training takes more than 60 hours on an NVIDIA(R) Titan X, so by default
% the pretrained network downloaded above is loaded instead.
doTraining = false;
if doTraining
    modelDateTime = datestr(now,'dd-mmm-yyyy-HH-MM-SS');
    [net,info] = trainNetwork(dsTrain,lgraph,options);
    % Take the epoch count from the training options; no maxEpochs variable
    % is defined in this script, so referencing one would error here.
    save(['trained3DUNet-' modelDateTime '-Epoch-' num2str(options.MaxEpochs) '.mat'],'net');
else
    % Load only the variable the rest of the script uses, explicitly,
    % instead of injecting every variable in the MAT-file into the workspace.
    pretrained = load(fullfile(imageDir,'trained3DUNet','brainTumor3DUNet.mat'),'net');
    net = pretrained.net;
end

%% Run segmentation on test data
% Choose the test data: the small downloaded sample set (default) or the
% full preprocessed test split.
useFullTestSet = false;
if ~useFullTestSet
    sampleTestRoot = fullfile(imageDir,'sampleBraTSTestSet');
    volLocTest = fullfile(sampleTestRoot,'imagesTest');
    lblLocTest = fullfile(sampleTestRoot,'labelsTest');
    classNames = ["background","tumor"];
    pixelLabelID = [0 1];
else
    volLocTest = fullfile(preprocessDataLoc,'imagesTest');
    lblLocTest = fullfile(preprocessDataLoc,'labelsTest');
end
%% Crop the central 128x128x128 region
% The same centerCropMatReader-based reader (example helper) is used for
% both test volumes and test labels.
windowSize = [128 128 128];
cropReader = @(x) centerCropMatReader(x,windowSize);
volReader = cropReader;
labelReader = cropReader;
voldsTest = imageDatastore(volLocTest,'FileExtensions','.mat','ReadFcn',volReader);
pxdsTest = pixelLabelDatastore(lblLocTest,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

%% Run segmentation
% Store the test volumes and ground truth in cell arrays, segment each
% volume, and post-process the predicted labels (median filter).
% Rewind both datastores so re-running this section processes every volume
% again instead of finding the datastores already exhausted.
reset(voldsTest);
reset(pxdsTest);
numTest = numel(voldsTest.Files);
% Preallocate so results from an earlier run cannot linger in the cells.
groundTruthLabels = cell(1,numTest);
vol = cell(1,numTest);
predictedLabels = cell(1,numTest);
id = 1;
while hasdata(voldsTest)
    disp(['Processing test volume ' num2str(id)])

    groundTruthLabels{id} = read(pxdsTest);

    vol{id} = read(voldsTest);
    tempSeg = semanticseg(vol{id},net);

    % Get the non-brain region mask from the test image (voxels whose first
    % channel is 0).
    volMask = vol{id}(:,:,:,1)==0;
    % Set the non-brain region of the predicted label as background.
    tempSeg(volMask) = classNames(1);
    % Median-filter the predicted label. uint8 of a categorical yields the
    % 1-based category index, so subtracting 1 gives values that match
    % pixelLabelID = [0 1] (assuming the network's categories are ordered
    % like classNames).
    tempSeg = medfilt3(uint8(tempSeg)-1);
    % Cast the filtered label back to categorical.
    tempSeg = categorical(tempSeg,pixelLabelID,classNames);
    predictedLabels{id} = tempSeg;
    id = id+1;
end

%% Compare the ground truth with the network prediction
% Show the middle slice of one test volume with each labeling overlaid,
% then render both labelings in 3-D.
% Use volume 2 when available; fall back to the last one for tiny test sets.
volId = min(2,numel(vol));
vol3d = vol{volId}(:,:,:,1);
% ceil keeps the slice index an integer for odd depths (gives 64 for 128).
zID = ceil(size(vol3d,3)/2);
zSliceGT = labeloverlay(vol3d(:,:,zID),groundTruthLabels{volId}(:,:,zID));
zSlicePred = labeloverlay(vol3d(:,:,zID),predictedLabels{volId}(:,:,zID));

figure
montage({zSliceGT;zSlicePred},'Size',[1 2],'BorderSize',5)
% Set the title after montage so that drawing the montage into the axes
% cannot clear a title set beforehand.
title('Labeled Ground Truth (Left) vs. Network Prediction (Right)')

% 3-D renderings; label 1 (background) is hidden so only the tumor shows.
figure
h1 = labelvolshow(groundTruthLabels{volId},vol3d);
h1.LabelVisibility(1) = 0;
h1.VolumeThreshold = 0.68;

figure
h2 = labelvolshow(predictedLabels{volId},vol3d);
h2.LabelVisibility(1) = 0;
h2.VolumeThreshold = 0.68;

%% Quantify segmentation accuracy with the Dice coefficient
% One row per segmented test volume; dice returns one score per category,
% so column 1 is background and column 2 is tumor. The matrix is sized by
% the number of volumes actually segmented so no all-zero rows skew the
% means, and that same count is reported (rather than the loop variable,
% which is empty when there are no volumes).
numTest = numel(vol);
diceResult = zeros(numTest,2);
for j = 1:numTest
    diceResult(j,:) = dice(groundTruthLabels{j},predictedLabels{j});
end

meanDiceBackground = mean(diceResult(:,1));
disp(['Average Dice score of background across ',num2str(numTest), ...
    ' test volumes = ',num2str(meanDiceBackground)])
meanDiceTumor = mean(diceResult(:,2));
disp(['Average Dice score of tumor across ',num2str(numTest), ...
    ' test volumes = ',num2str(meanDiceTumor)])
%% Box plot of the per-volume Dice scores
% Disabled by default; set createBoxplot = true to draw one box per class.
createBoxplot = false;
if createBoxplot
    figure
    boxplot(diceResult)
    ylabel('Dice Coefficient')
    xticklabels(classNames)
    title('Test Set Dice Accuracy')
end


%%
% _Copyright 2019 The MathWorks, Inc._