% Source: www.gusucode.com > IPCV_Eval_Kit_R2019a_0ce6858 toolbox MATLAB source code (工具箱matlab程序源码) > IPCV_Eval_Kit_R2019a_0ce6858/code/demo_files/I5_06_4_1_semanticSeg_3DUnet.m
%% Deep Learning: 3-D U-Net Segmentation
% Semantic segmentation of brain tumors in 3-D volumes from the BraTS task
% of the Medical Segmentation Decathlon, using a 3-D U-Net.
%
% Workflow:
%   1. Preprocess the raw BraTS data (preprocessBraTSdataset).
%   2. Build random-patch datastores for training and validation.
%   3. Define the 3-D U-Net (createUnet3d helper by default, or the explicit
%      layer graph below when useHelperUnet = false).
%   4. Train the network (only when doTraining = true) or load the
%      downloaded pretrained network.
%   5. Segment center-cropped test volumes, post-process the labels with a
%      3-D median filter, visualize, and report per-class Dice scores.
%
% Functions used but not defined in this script: preprocessBraTSdataset,
% matRead, augment3dPatch, createUnet3d, dicePixelClassification3dLayer,
% downloadTrained3DUnetSampleData, centerCropMatReader.
% NOTE(review): these are presumably supplied by the examples folder added
% to the path below -- confirm for your MATLAB release.

%% Initialization
clear; close all; clc;
rng('default');

%% Add the example support functions to the path
addpath(fullfile(matlabroot,'examples','deeplearning_shared','main'));

%% BraTS dataset
% Download "Task01_BrainTumour.tar" via "Download Data" on the site below,
% extract it, and place it in a folder named BraTS in the current folder.
% http://medicaldecathlon.com/
imageDir = fullfile(pwd,'BraTS');
if ~exist(imageDir,'dir')
    error('http://medicaldecathlon.com/ から Task01_BrainTumour.tarをダウンロードし、解凍してください。');
end

%% Preprocessing (takes about 30 minutes)
sourceDataLoc = fullfile(imageDir,'Task01_BrainTumour');
preprocessDataLoc = fullfile(tempdir,'BraTS','preprocessedDataset');
preprocessBraTSdataset(preprocessDataLoc,sourceDataLoc);

%% Create the training datastores (volumes and pixel labels stored as .mat)
volReader = @(x) matRead(x);
volLoc = fullfile(preprocessDataLoc,'imagesTr');
volds = imageDatastore(volLoc, ...
    'FileExtensions','.mat','ReadFcn',volReader);

labelReader = @(x) matRead(x);
lblLoc = fullfile(preprocessDataLoc,'labelsTr');
classNames = ["background","tumor"];
pixelLabelID = [0 1];
pxds = pixelLabelDatastore(lblLoc,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

%% Inspect one volume and its label
volume = preview(volds);
label = preview(pxds);
figure
h = labelvolshow(label,volume(:,:,:,1));
h.LabelVisibility(1) = 0;   % hide the first label (background)

%% Training patch settings
patchSize = [64 64 64];
patchPerImage = 16;
miniBatchSize = 8;
patchds = randomPatchExtractionDatastore(volds,pxds,patchSize, ...
    'PatchesPerImage',patchPerImage);
patchds.MiniBatchSize = miniBatchSize;

%% Data augmentation (random rotation and reflection)
dsTrain = transform(patchds,@augment3dPatch);

%% Create the validation randomPatchExtractionDatastore
volLocVal = fullfile(preprocessDataLoc,'imagesVal');
voldsVal = imageDatastore(volLocVal, ...
    'FileExtensions','.mat','ReadFcn',volReader);

lblLocVal = fullfile(preprocessDataLoc,'labelsVal');
pxdsVal = pixelLabelDatastore(lblLocVal,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

dsVal = randomPatchExtractionDatastore(voldsVal,pxdsVal,patchSize, ...
    'PatchesPerImage',patchPerImage);
dsVal.MiniBatchSize = miniBatchSize;

%% Define the 3-D U-Net
% 64x64x64 patches with 4 values along the fourth (channel) dimension.
inputSize = [64 64 64 4];

% true : use the createUnet3d helper's network.
% false: use the explicitly constructed layer graph in the else-branch.
useHelperUnet = true;

if useHelperUnet
    lgraph = createUnet3d(inputSize);
else
    inputLayer = image3dInputLayer(inputSize,'Normalization','none','Name','input');

    % Encoder: three modules of conv-BN-ReLU-conv-ReLU-maxpool. Row k of
    % numFiltersEncoder holds the filter counts of module k's two convolutions.
    numFiltersEncoder = [
        32 64;
        64 128;
        128 256];

    layers = inputLayer;
    for module = 1:3
        modtag = num2str(module);
        encoderModule = [
            convolution3dLayer(3,numFiltersEncoder(module,1), ...
                'Padding','same','WeightsInitializer','narrow-normal', ...
                'Name',['en',modtag,'_conv1']);
            batchNormalizationLayer('Name',['en',modtag,'_bn']);
            reluLayer('Name',['en',modtag,'_relu1']);
            convolution3dLayer(3,numFiltersEncoder(module,2), ...
                'Padding','same','WeightsInitializer','narrow-normal', ...
                'Name',['en',modtag,'_conv2']);
            reluLayer('Name',['en',modtag,'_relu2']);
            maxPooling3dLayer(2,'Stride',2,'Padding','same', ...
                'Name',['en',modtag,'_maxpool']);
            ];
        layers = [layers; encoderModule];
    end

    % Decoder: row k of numFiltersDecoder holds the two convolution filter
    % counts of decoder stage de(5-k).
    numFiltersDecoder = [
        256 512;
        256 256;
        128 128;
        64 64];

    decoderModule4 = [
        convolution3dLayer(3,numFiltersDecoder(1,1), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de4_conv1');
        reluLayer('Name','de4_relu1');
        convolution3dLayer(3,numFiltersDecoder(1,2), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de4_conv2');
        reluLayer('Name','de4_relu2');
        transposedConv3dLayer(2,numFiltersDecoder(1,2),'Stride',2, ...
            'Name','de4_transconv');
        ];

    decoderModule3 = [
        convolution3dLayer(3,numFiltersDecoder(2,1), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de3_conv1');
        reluLayer('Name','de3_relu1');
        convolution3dLayer(3,numFiltersDecoder(2,2), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de3_conv2');
        reluLayer('Name','de3_relu2');
        transposedConv3dLayer(2,numFiltersDecoder(2,2),'Stride',2, ...
            'Name','de3_transconv');
        ];

    decoderModule2 = [
        convolution3dLayer(3,numFiltersDecoder(3,1), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de2_conv1');
        reluLayer('Name','de2_relu1');
        convolution3dLayer(3,numFiltersDecoder(3,2), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de2_conv2');
        reluLayer('Name','de2_relu2');
        transposedConv3dLayer(2,numFiltersDecoder(3,2),'Stride',2, ...
            'Name','de2_transconv');
        ];

    % Final decoder stage: 1x1x1 convolution down to one channel per class,
    % then softmax and the Dice-based classification output layer.
    numLabels = 2;
    decoderModuleFinal = [
        convolution3dLayer(3,numFiltersDecoder(4,1), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de1_conv1');
        reluLayer('Name','de1_relu1');
        convolution3dLayer(3,numFiltersDecoder(4,2), ...
            'Padding','same','WeightsInitializer','narrow-normal', ...
            'Name','de1_conv2');
        reluLayer('Name','de1_relu2');
        convolution3dLayer(1,numLabels,'Name','convLast');
        softmaxLayer('Name','softmax');
        dicePixelClassification3dLayer('output');
        ];

    % Assemble the graph: encoder + bottom decoder in series, remaining
    % decoder stages added as separate branches.
    layers = [layers; decoderModule4];
    lgraph = layerGraph(layers);
    lgraph = addLayers(lgraph,decoderModule3);
    lgraph = addLayers(lgraph,decoderModule2);
    lgraph = addLayers(lgraph,decoderModuleFinal);

    % Skip connections: concatenate (along dimension 4) each encoder
    % module's second ReLU output with the upsampled decoder output, and feed
    % the result into the next decoder stage.
    concat1 = concatenationLayer(4,2,'Name','concat1');
    lgraph = addLayers(lgraph,concat1);
    lgraph = connectLayers(lgraph,'en1_relu2','concat1/in1');
    lgraph = connectLayers(lgraph,'de2_transconv','concat1/in2');
    lgraph = connectLayers(lgraph,'concat1/out','de1_conv1');

    concat2 = concatenationLayer(4,2,'Name','concat2');
    lgraph = addLayers(lgraph,concat2);
    lgraph = connectLayers(lgraph,'en2_relu2','concat2/in1');
    lgraph = connectLayers(lgraph,'de3_transconv','concat2/in2');
    lgraph = connectLayers(lgraph,'concat2/out','de2_conv1');

    concat3 = concatenationLayer(4,2,'Name','concat3');
    lgraph = addLayers(lgraph,concat3);
    lgraph = connectLayers(lgraph,'en3_relu2','concat3/in1');
    lgraph = connectLayers(lgraph,'de4_transconv','concat3/in2');
    lgraph = connectLayers(lgraph,'concat3/out','de3_conv1');
end

%% Visualize the layers
analyzeNetwork(lgraph)

%% Training options
% maxEpochs is also used to name the saved model file after training.
maxEpochs = 100;
options = trainingOptions('adam', ...
    'MaxEpochs',maxEpochs, ...
    'InitialLearnRate',5e-4, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',5, ...
    'LearnRateDropFactor',0.95, ...
    'ValidationData',dsVal, ...
    'ValidationFrequency',400, ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'MiniBatchSize',miniBatchSize);

%% Download the pretrained network and a sample test set
trained3DUnet_url = 'https://www.mathworks.com/supportfiles/vision/data/brainTumor3DUNet.mat';
sampleData_url = 'https://www.mathworks.com/supportfiles/vision/data/sampleBraTSTestSet.tar.gz';

imageDir = fullfile(tempdir,'BraTS');
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end
downloadTrained3DUnetSampleData(trained3DUnet_url,sampleData_url,imageDir);

%% Train the 3-D U-Net
% Training takes more than 60 hours on an NVIDIA(R) Titan X, so by default
% the downloaded pretrained network is loaded instead.
doTraining = false;
if doTraining
    modelDateTime = datestr(now,'dd-mmm-yyyy-HH-MM-SS');
    [net,info] = trainNetwork(dsTrain,lgraph,options);
    save(['trained3DUNet-' modelDateTime '-Epoch-' num2str(maxEpochs) '.mat'],'net');
else
    % Load into a struct so that only the network enters the workspace.
    pretrained = load(fullfile(imageDir,'trained3DUNet','brainTumor3DUNet.mat'));
    net = pretrained.net;
end

%% Select the test data
% By default only the downloaded sample test set is segmented.
useFullTestSet = false;
if useFullTestSet
    volLocTest = fullfile(preprocessDataLoc,'imagesTest');
    lblLocTest = fullfile(preprocessDataLoc,'labelsTest');
else
    volLocTest = fullfile(imageDir,'sampleBraTSTestSet','imagesTest');
    lblLocTest = fullfile(imageDir,'sampleBraTSTestSet','labelsTest');
    classNames = ["background","tumor"];
    pixelLabelID = [0 1];
end

%% Crop the central 128x128x128 region of each test volume
windowSize = [128 128 128];
volReader = @(x) centerCropMatReader(x,windowSize);
labelReader = @(x) centerCropMatReader(x,windowSize);
voldsTest = imageDatastore(volLocTest, ...
    'FileExtensions','.mat','ReadFcn',volReader);
pxdsTest = pixelLabelDatastore(lblLocTest,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',labelReader);

%% Segment the test volumes
% Store the test volumes, ground truth, and predictions in cell arrays, and
% post-process each predicted label volume with a 3-D median filter.
numTest = numel(voldsTest.Files);
groundTruthLabels = cell(1,numTest);
vol = cell(1,numTest);
predictedLabels = cell(1,numTest);

id = 1;
while hasdata(voldsTest)
    disp(['Processing test volume ' num2str(id)])

    groundTruthLabels{id} = read(pxdsTest);
    vol{id} = read(voldsTest);
    tempSeg = semanticseg(vol{id},net);

    % Non-brain mask: voxels whose first channel is exactly zero.
    volMask = vol{id}(:,:,:,1)==0;
    % Force the non-brain region of the prediction to background.
    tempSeg(volMask) = classNames(1);
    % Median-filter the labels as 0-based integers (category index - 1).
    tempSeg = medfilt3(uint8(tempSeg)-1);
    % Convert the filtered integers back to the categorical classes.
    tempSeg = categorical(tempSeg,pixelLabelID,classNames);
    predictedLabels{id} = tempSeg;
    id = id + 1;
end

%% Compare the ground truth with the network prediction
% Middle slice along the third dimension of the first channel. Uses the
% second test volume, or the only one when a single volume is available.
volId = min(2,numTest);
vol3d = vol{volId}(:,:,:,1);
zID = round(size(vol3d,3)/2);   % integer index even for odd depths
zSliceGT = labeloverlay(vol3d(:,:,zID),groundTruthLabels{volId}(:,:,zID));
zSlicePred = labeloverlay(vol3d(:,:,zID),predictedLabels{volId}(:,:,zID));

figure
montage({zSliceGT;zSlicePred},'Size',[1 2],'BorderSize',5)
% Set the title after montage, which draws into fresh axes and can clear a
% title set beforehand.
title('Labeled Ground Truth (Left) vs. Network Prediction (Right)')

% 3-D label views with the background label hidden.
figure
h1 = labelvolshow(groundTruthLabels{volId},vol3d);
h1.LabelVisibility(1) = 0;
h1.VolumeThreshold = 0.68;

figure
h2 = labelvolshow(predictedLabels{volId},vol3d);
h2.LabelVisibility(1) = 0;
h2.VolumeThreshold = 0.68;

%% Quantify segmentation accuracy with per-class Dice scores
% Column 1: background, column 2: tumor.
diceResult = zeros(numTest,2);
for j = 1:numTest
    diceResult(j,:) = dice(groundTruthLabels{j},predictedLabels{j});
end

meanDiceBackground = mean(diceResult(:,1));
disp(['Average Dice score of background across ',num2str(numTest), ...
    ' test volumes = ',num2str(meanDiceBackground)])

meanDiceTumor = mean(diceResult(:,2));
disp(['Average Dice score of tumor across ',num2str(numTest), ...
    ' test volumes = ',num2str(meanDiceTumor)])

%% Plot the Dice scores as a box plot (optional)
createBoxplot = false;
if createBoxplot
    figure
    boxplot(diceResult)
    title('Test Set Dice Accuracy')
    xticklabels(classNames)
    ylabel('Dice Coefficient')
end

%%
% _Copyright 2019 The MathWorks, Inc._