A two-stage 3D U-Net framework for multi-class segmentation on full-resolution images
Chengjia Wang, Tom MacGillivray, Gillian Macnaught, Guang Yang, David Newby
BHF Centre for Cardiovascular Science, University of Edinburgh, Edinburgh, UK
Edinburgh Imaging Facility QMRI, University of Edinburgh, Edinburgh, UK
National Heart & Lung Institute, Imperial College London, London, UK
[email protected]
Abstract.
Deep convolutional neural networks (CNNs) have been used intensively for multi-class segmentation of data from different modalities and have achieved state-of-the-art performance. However, a common problem when dealing with large, high-resolution 3D data is that the volumes input to the deep CNNs have to be either cropped or down-sampled due to the limited memory capacity of computing devices. These operations cause a loss of resolution and increase class imbalance in the input data batches, both of which can degrade the performance of segmentation algorithms. Inspired by the architectures of the image super-resolution CNN (SRCNN) and the self-normalizing network (SNN), we developed a two-stage modified U-Net framework that simultaneously learns to detect a region of interest (ROI) within the full volume and to classify voxels without losing the original resolution. Experiments on a variety of multi-modal volumes demonstrated that, when trained with a simply weighted Dice coefficient and our customized learning procedure, this framework shows better segmentation performance than state-of-the-art deep CNNs trained with advanced similarity metrics.
Keywords:
Image segmentation, Convolutional neural networks, High resolution, Cardiac CT/MR
1 Introduction

Segmenting the whole-heart structures from CT and MRI data is a necessary step for pre-procedural planning in cardiovascular disease. Although it is the most reliable approach, manual segmentation is labor-intensive and subject to user variability [1]. High anatomical and signal-intensity variations make automatic whole-heart segmentation a challenging task. Previous methods that separately segment specific anatomical structures [2,3] are often based on active deformation models. Others perform multi-class segmentation, among which atlas-based methods [4] play an important role.
(This work is funded by the BHF Centre of Cardiovascular Science and the MICCAI 2017 Multi-Modality Whole Heart Segmentation (MM-WHS) challenge.)

Active deformation models can suffer from a limited ability to decouple pose variation [5], and the main disadvantage of atlas-based methods is that they require complex procedures to construct the atlas or perform non-rigid registration [6]. Recently, owing to the development of deep learning, deep convolutional neural networks (DCNNs) and probabilistic graphical models (PGMs), especially U-Net-like models [7], have been widely used for cardiac segmentation and have achieved state-of-the-art results. The purpose of this study is to develop a DCNN that can perform multi-class segmentation on full-resolution volumetric CT and MR data with no post-prediction resampling or subvolume-fusion operations. This is necessary because of the loss of information introduced by interpolation and the extra complexity of post-processing.

The original U-Net is an entirely 2D architecture, as are most DCNN-based whole-heart segmentation methods [8,9]. To process volumetric data, some models take three perpendicular 2D slices as input and fuse the multi-view information for 3D segmentation [10,11]. 3D U-Net-like DCNNs, in which the 2D operations are replaced by their 3D counterparts [12,13], have been adopted in different applications. Very few works have applied volumetric U-Nets to 3D whole-heart CT or MRI data for multi-class segmentation [14]. Due to the limited memory capacity of GPUs, these 3D DCNN methods have to either make predictions on down-sampled volumes, which leads to a loss of resolution in the final results, or process subvolumes of the data followed by an extra post-processing step to merge the overlapped predictions, as in [14]. Methods that preserve the original data resolution often use a relatively shallow U-Net architecture or have only been tested on low-resolution MR images.

In this paper, we propose a two-stage DCNN framework built by concatenating two U-Net-like networks. A new multi-stage learning pipeline is adopted in the training process. This framework, with auxiliary outputs, segments 3D CT and MR data through dynamic ROI extraction. Experiments with limited training data demonstrate that our model outperforms well-trained 3D U-Nets that require post-processing steps.
2 Method

The proposed DCNN model classifies all the voxels within an axial slice based on a pre-defined neighborhood of axial slices around it. As shown in Fig. 1, the complete model consists of two concatenated modified U-Nets. Generally following the architecture of the original 2D U-Net, the basic block of this model consists of two convolutional layers, each followed by a nonlinear activation, and a pooling layer.

Fig. 1.
The concatenated U-Net architecture proposed in this work. The nonlinear activation and pooling layers within each U-Net block are not shown, for clarity.
The first network (Net1 in Fig. 1) uses a down-sampled 3D volume to make a coarse prediction of the voxel labels. The produced label volume is then resampled to the original resolution. To capture information from a larger effective receptive field, we use slightly dilated 5 × 5 × 5 convolutional kernels: in the n-th block of the contracting path, the dilation rate of the convolutional kernel is $2^n$. This pattern is reversed in the expansive path. Each convolutional layer is followed by a rectified linear unit (ReLU), and a dropout layer with a 0.2 dropout rate is attached to each U-Net block. In the test phase, a dynamic-tile layer is introduced between Net1 and Net2 to crop out a region of interest (ROI) from both the input and output volumes of Net1. This layer is removed when performing end-to-end training, to simplify implementation.
The architecture of Net2 is inspired by the deep Super-Resolution Convolutional Neural Network (SRCNN) [15] with skip connections and recursive units [16]. The input of this network is a two-channel 4D volume composed of the output of Net1 and the original data. The convolutional kernel size in the contracting path is 3 × 3 × 3, compared with 5 × 5 × 5 in Net1, and the size of the 3D pooling kernels in the contracting path is 2 × 2 × 1. 1 × 1 × (K − 1) convolutional kernels are introduced before the expansive path, where K is the number of neighboring slices used to label one single axial slice. No zero-padding is used, so that every K input slices generate one single axial feature map. Furthermore, K should always be an odd number, to prevent generating labels for interpolated slices. The following layers, before the output of Net2, perform 2D convolutions and pooling.
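To make the block structure concrete, the following is a minimal PyTorch sketch of one contracting-path block of Net1 as described above, assuming the dilation rate $2^n$ is passed in per block; the class and argument names are illustrative, not taken from the authors' code.

```python
import torch
import torch.nn as nn

class DilatedBlock3D(nn.Module):
    """One contracting-path block: two dilated 5x5x5 convs with ReLU, a 0.2
    dropout attached to the block, then max-pooling between blocks."""
    def __init__(self, in_ch: int, out_ch: int, dilation: int):
        super().__init__()
        pad = 2 * dilation  # preserves spatial size for a dilated 5x5x5 kernel
        self.convs = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=5, dilation=dilation, padding=pad),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=5, dilation=dilation, padding=pad),
            nn.ReLU(inplace=True),
            nn.Dropout3d(p=0.2),
        )
        self.pool = nn.MaxPool3d(kernel_size=2)

    def forward(self, x):
        skip = self.convs(x)          # skip-connection source for the expansive path
        return self.pool(skip), skip
```

For instance, the first block would be instantiated as `DilatedBlock3D(1, 32, dilation=2)` if the dilation rates follow $2^n$ with n starting at 1; the channel width is our placeholder.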
Table 1.

Purposes and loss functions of each step in the training process.

Step  Input                          Purpose                           Loss
1     full volumetric data           foreground localization           L_ROI
2     subvolumes                     coarse multi-class segmentation   L_ROI + L_1
3     subvolumes                     end-to-end joint segmentation     L_1 + L_2
4     subvolumes of K axial slices   fine axial-slice segmentation     L_2

The two U-Net-like DCNNs of the proposed model are flexible enough to be trained either separately or end-to-end with changing sizes of input data. In this study, we combined both approaches into a four-step training procedure (Table 1). At the beginning, Net1 is trained alone (steps 1 and 2); the two networks are then trained jointly end-to-end (step 3); finally, only Net2 is updated (step 4).
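For reference, the schedule of Table 1 can be written down as plain data; the strings below restate the table, and the trainable-nets column follows the step descriptions in this section (a sketch, not the authors' configuration). The losses $L_{ROI}$, $L_1$ and $L_2$ are defined in the following subsections.

```python
# Four-step training schedule of Table 1, expressed as plain data.
TRAINING_SCHEDULE = [
    # (step, input granularity,           trainable nets,    loss terms)
    (1, "full volumetric data",           ["Net1"],          ["L_ROI"]),
    (2, "subvolumes",                     ["Net1"],          ["L_ROI", "L_1"]),
    (3, "subvolumes",                     ["Net1", "Net2"],  ["L_1", "L_2"]),
    (4, "K complete axial slices",        ["Net2"],          ["L_2"]),
]
```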
Dice Score
A commonly used similarity metric for single-class segmentation is the soft Dice score. Let $p^i_{n,c}$ denote the probability that voxel $n$ belongs to class $c$, $c \in \{0, \cdots, C\}$, given by the softmax layer of $Net_i$, and let $t_{n,c} \in \{0, 1\}$ represent the ground-truth one-hot label. The soft Dice score can be defined as:

$$S^i_c = \frac{2 \sum_n^{N_c} t_{n,c} \, p^i_{n,c} + \epsilon}{\sum_n^{N_c} \left( t_{n,c} + p^i_{n,c} \right) + \epsilon}, \quad (1)$$

where $N_c$ is the number of voxels labeled as class $c$ and $\epsilon$ is a smoothing factor. To perform multi-class segmentation, we simply define our loss function using Dice scores weighted by the voxel counts:

$$L_i = 1 - \sum_c^{C} \frac{S^i_c}{N_c}. \quad (2)$$

Nothing prevents the use of a more sophisticated loss function, as shown in [17]. In different steps of the training process, the losses of the two networks are combined for different stage targets.
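A direct reading of equations (1) and (2) as code, for example in PyTorch, could look as follows; the flattened (voxels × classes) tensor layout and the function name are assumptions, not the paper's reference implementation.

```python
import torch

def weighted_dice_loss(probs: torch.Tensor, onehot: torch.Tensor,
                       eps: float = 1e-5) -> torch.Tensor:
    """probs, onehot: (num_voxels, C + 1) softmax outputs and one-hot labels."""
    inter = (onehot * probs).sum(dim=0)                               # per-class overlap
    dice = (2.0 * inter + eps) / ((onehot + probs).sum(dim=0) + eps)  # eq. (1)
    n_c = onehot.sum(dim=0).clamp(min=1.0)                            # N_c, guarded against 0
    return 1.0 - (dice / n_c).sum()                                   # eq. (2)
```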
Foreground Localization

In this step, Net1 is trained with full volumetric data to roughly localize the foreground, i.e., a soft ROI containing the segmented object. All other content in the data is considered background. The parameters of Net1 are optimized with a foreground loss $L_{ROI}$ computed from its output. With the background defined as class 0, the foreground score is:

$$S_{ROI} = \frac{2 \sum_n^{N} \left(1 - t_{n,0}\right) \left(1 - p_{n,0}\right) + \epsilon}{\sum_n^{N} \left(2 - t_{n,0} - p_{n,0}\right) + \epsilon}, \quad (3)$$

where $N$ is the number of background voxels. The corresponding foreground loss is:

$$L_{ROI} = 1 - \frac{S_{ROI}}{N}. \quad (4)$$

We use the reversed labels to calculate a foreground score, rather than the Dice score of the background, to reduce the imbalance introduced by the large background. Net1 is pre-trained with $L_{ROI}$ to quickly localize the foreground of the object.
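Under the same assumed tensor layout as above, equations (3) and (4) translate to the following sketch; the function name is ours.

```python
import torch

def foreground_roi_loss(probs: torch.Tensor, onehot: torch.Tensor,
                        eps: float = 1e-5) -> torch.Tensor:
    """probs, onehot: (num_voxels, C + 1); channel 0 is the background class."""
    fg_t = 1.0 - onehot[:, 0]                  # reversed background label
    fg_p = 1.0 - probs[:, 0]                   # predicted foreground probability
    s_roi = (2.0 * (fg_t * fg_p).sum() + eps) / ((fg_t + fg_p).sum() + eps)  # eq. (3)
    n_bg = onehot[:, 0].sum().clamp(min=1.0)   # N: number of background voxels
    return 1.0 - s_roi / n_bg                  # eq. (4)
```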
Multi-class Segmentation
After pre-training with $L_{ROI}$, $L_1 + L_{ROI}$ is used as the loss for coarse multi-class segmentation in the second step, where $L_1$ is the Dice loss defined by equation (2). In this step, only Net1 is trained. In the third step, the complete framework (Net1 and Net2) is trained end-to-end with the loss $L_1 + L_2$ to evolve both the coarse 3D segmentation and the fine-level axial-slice segmentation. As both networks are fully convolutional, the sampling strategy of the input data stays the same as in step 2. In the final step, the inputs of the framework are subvolumes, each consisting of K complete axial slices. The output of Net2 labels the $\frac{K+1}{2}$-th slice of an input subvolume. In this step, the parameters of Net1 are fixed and only Net2 is updated with $L_2$.

Because the framework is mostly trained with subvolumes of the 3D data, except in the first step, we use a hierarchical sampling strategy similar to [18]. Each batch is generated from a small number of volumes. To deal with highly imbalanced data, we first select the class of the central voxel of the sampled subvolume from a uniform distribution. Once the label of the central voxel is fixed, the subvolume is generated by randomly picking its centre from all voxels labeled as the selected class. In this way, the probability that the central voxel belongs to any given class is $\frac{1}{C+1}$. For optimization, we use the Adam optimizer. In the final training step, we set K = 9, which means that Net2 labels each axial slice based on its 9-slice neighborhood.
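A minimal NumPy sketch of this class-balanced sampling for the final training step is given below, assuming dense (D, H, W) label volumes; the fallback for a class absent from the current volume is our addition.

```python
import numpy as np

def sample_axial_subvolume(image, labels, num_classes, k, rng=np.random):
    """Draw K complete axial slices whose central slice contains a voxel of a
    uniformly chosen class; image and labels are (D, H, W) arrays, K is odd."""
    c = rng.randint(num_classes + 1)        # uniform over the C + 1 classes
    zs = np.nonzero(labels == c)[0]         # axial indices containing that class
    if zs.size == 0:                        # class absent from this volume:
        zs = np.arange(labels.shape[0])     # fall back to any slice
    z = rng.choice(zs)                      # centre slice of the subvolume
    z0 = int(np.clip(z - k // 2, 0, labels.shape[0] - k))
    return image[z0:z0 + k], labels[z0 + k // 2]
```

With K = 9 as in the final step, each call returns a 9-slice input stack together with the label map of its central slice.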
3 Experiments

The MICCAI 2017 Multi-Modality Whole Heart Segmentation (MM-WHS) challenge benchmarks existing whole-heart segmentation algorithms. For training, the challenge provides 40 volumes (20 cardiac CT and 20 cardiac MR) acquired in a real clinical environment. The data were acquired with different scanners, which leads to varying voxel sizes, resolutions and imaging qualities. An extra 80 test images are available from the challenge for one-shot validation. In this dataset, the manually delineated anatomical structures include the left ventricle blood cavity (LV), the myocardium of the left ventricle (Myo), the right ventricle blood cavity (RV), the left atrium blood cavity (LA), the right atrium blood cavity (RA), the ascending aorta (AA) and the pulmonary artery (PA).

One may argue that the proposed framework could be trained directly using only the final training step. To demonstrate the effectiveness of the first three steps, we trained our model omitting each of them in turn and visually assessed the segmentation results, which can be found in the next section. Except for the final step, the training steps do not have to be run until convergence. In this step-wise experiment, all results were obtained with 200 epochs per step, each epoch comprising 16 iterations of backpropagation; the whole training process thus contains 12800 iterations in total.

Besides qualitatively evaluating the visualized segmentations, for each modality we use 15 volumes for 3-fold cross-validated training and 5 volumes for validation. To compare our framework with state-of-the-art U-Net-based models, we trained two 3D U-Nets, one for each modality, which predict on data resampled to a resolution of 2 × 2 × 2 mm; the output volumes are then resampled to the original resolution using 2nd-order BSpline interpolation.
Intensities of all images are rescaled to [−1, 1] with no further preprocessing. Three metrics were used to assess segmentation quality for each class: the binary Dice score (Dice), the binary Jaccard index (Jaccard), and the average surface distance (ASD). After the step-wise experiments, we retrained the networks end-to-end from scratch and submitted the results to the MM-WHS challenge for further evaluation on its test data.
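As an illustration of the baseline's post-prediction resampling and of the intensity rescaling, a SciPy-based sketch follows; SciPy's order-2 spline zoom stands in for the 2nd-order BSpline interpolation, and the function names are ours.

```python
import numpy as np
from scipy.ndimage import zoom

def resample_to_original(volume: np.ndarray, src_spacing, dst_spacing) -> np.ndarray:
    """Resample a predicted volume from its coarse voxel spacing (src) back to
    the original spacing (dst) with 2nd-order spline interpolation."""
    factors = [s / d for s, d in zip(src_spacing, dst_spacing)]
    return zoom(volume, factors, order=2)

def rescale_intensity(volume: np.ndarray) -> np.ndarray:
    """Linearly rescale intensities to [-1, 1], with no further preprocessing."""
    lo, hi = volume.min(), volume.max()
    return 2.0 * (volume - lo) / (hi - lo) - 1.0
```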
4 Results

Examples of visualized segmentation results are shown in Fig. 2. Omitting the foreground localization of the first training step may lead to misclassification of background voxels, as shown in the top row of Fig. 2. The middle row shows that, without the coarse segmentation of the second step, the model failed to label the left atrium; inhomogeneous segmentations of the aorta were produced when the joint end-to-end training of Net1 and Net2 in the third step was skipped, as shown in the bottom row.
Fig. 2.
Visualization of segmentation results overlaid on the original data. Ground-truth segmentations are shown on the left; the middle column shows the results obtained by omitting the first, second and third training steps (from top to bottom); the right column shows the results obtained with the full proposed training process.

Quantitative cross-validation results for the 3D U-Net baseline and the proposed Net1 and Net2 are reported in Tables 2 and 3. As the MR data have a relatively lower resolution, the volume size changed less after resampling.
Table 2.

Comparison of CT segmentation results obtained by the 3D U-Net and our proposed Net1 and Net2 (columns: the seven cardiac substructures).

Model      Metric
3D U-Net   Dice      0.6451  0.8301  0.7873  0.7768  0.6784  0.8306  0.7123
           Jaccard   0.4889  0.7126  0.6572  0.6397  0.5217  0.7143  0.5560
Net1       Dice      0.6774  0.8107  0.8136  0.8118  0.7997  0.8889  0.8086
           Jaccard   0.5399  0.6979  0.6977  0.6908  0.6717  0.8030  0.6802
Net2       Dice
           Jaccard
Table 3.

Comparison of MR segmentation results obtained by the 3D U-Net and our proposed Net1 and Net2 (columns: the seven cardiac substructures).

Model      Metric
3D U-Net   Dice      0.8296  0.9141
           Jaccard
Net1       Dice      0.8811  0.9367  0.9131  0.9334  0.8572  0.8750  0.9204
           Jaccard   0.7877  0.8813  0.8430  0.8757  0.7694  0.7833  0.8528
Net2       Dice
           Jaccard
Table 4.

Quantitative validation results obtained using the MM-WHS test data. Evaluation metrics: average Dice score (Dice), average Jaccard index (Jaccard), and average surface distance (ASD).

Modality  Metric   LV      Myo     RV      LA      RA      AA      PA      WH
CT        Dice     0.7995  0.7293  0.7857  0.9044  0.7936  0.8735  0.6482  0.8060
          Jaccard  0.6999  0.6091  0.6841  0.8285  0.6906  0.8113  0.5169  0.6970
          ASD      4.4067  5.4854  4.8816  1.3978  4.1707  3.7898  6.0041  4.1971
MR        Dice     0.8632  0.7443  0.8485  0.8524  0.8396  0.8236  0.7876  0.8323
          Jaccard  0.7693  0.6049  0.7469  0.7483  0.7404  0.7095  0.6657  0.7201
          ASD      1.9916  2.3106  1.8925  1.7081  2.7566  4.2610  2.9296  2.4718

Note that the purpose of this study is to build a model that performs better than the state-of-the-art U-Net when segmenting high-resolution data; this has been shown in the experiments described above.
5 Conclusion

In this paper, we described a two-stage U-Net-like framework for multi-class segmentation. Unlike other U-Net-based DCNNs for 3D data segmentation, the proposed method can directly make predictions on data at the original resolution thanks to its SRCNN-inspired architecture. A novel four-step training procedure was applied to the framework. Validated on data from the MM-WHS 2017 competition, it produced more accurate multi-class segmentation results than the state-of-the-art U-Net. With far fewer training iterations and without any further post-processing, our method achieved segmentation accuracies comparable to the winner of the MM-WHS 2017 competition.
References
1. Pace, D.F., Dalca, A.V., Geva, T., Powell, A.J., Moghari, M.H., Golland, P.: Interactive whole-heart segmentation in congenital heart disease. (2015) 80–88
2. Petitjean, C., Zuluaga, M.A., Bai, W., Dacher, J.N., Grosgeorge, D., Caudron, J., Ruan, S., Ayed, I.B., Cardoso, M.J., Chen, H.C., et al.: Right ventricle segmentation from cardiac MRI: a collation study. Medical Image Analysis (1) (2015) 187–202
3. Arrieta, C., Uribe, S., Sing-Long, C., Hurtado, D., Andia, M., Irarrazaval, P., Tejos, C.: Simultaneous left and right ventricle segmentation using topology preserving level sets. Biomedical Signal Processing and Control (2017) 88–95
4. Zhuang, X., Shen, J.: Multi-scale patch and multi-modality atlases for whole heart segmentation of MRI. Medical Image Analysis (2016) 77–87
5. Gonzalez-Mora, J., De la Torre, F., Murthi, R., Guil, N., Zapata, E.L.: Bilinear active appearance models. In: Computer Vision, 2007. ICCV 2007. IEEE 11th International Conference on, IEEE (2007) 1–8
6. Marsland, S., Twining, C.J., Taylor, C.J.: Groupwise non-rigid registration using polyharmonic clamped-plate splines. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer (2003) 771–779
7. Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional networks for biomedical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer (2015) 234–241
8. Wolterink, J.M., Leiner, T., Viergever, M.A., Išgum, I.: Dilated convolutional neural networks for cardiovascular MR segmentation in congenital heart disease. In: Reconstruction, Segmentation, and Analysis of Medical Images. Springer (2016) 95–102
9. Moeskops, P., Wolterink, J.M., van der Velden, B.H., Gilhuijs, K.G., Leiner, T., Viergever, M.A., Išgum, I.: Deep learning for multi-task medical image segmentation in multiple modalities. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer (2016) 478–486
10. Mortazi, A., Burt, J., Bagci, U.: Multi-planar deep segmentation networks for cardiac substructures from MRI and CT. arXiv preprint arXiv:1708.00983 (2017)
11. Luo, G., Dong, S., Wang, K., Zuo, W., Cao, S., Zhang, H.: Multi-views fusion CNN for left ventricular volumes estimation on cardiac MR images. IEEE Transactions on Biomedical Engineering (2017)
12. Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-Net: Learning dense volumetric segmentation from sparse annotation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer (2016) 424–432
13. Roth, H.R., Oda, H., Hayashi, Y., Oda, M., Shimizu, N., Fujiwara, M., Misawa, K., Mori, K.: Hierarchical 3D fully convolutional networks for multi-organ segmentation. arXiv preprint arXiv:1704.06382 (2017)
14. Yu, L., Cheng, J.Z., Dou, Q., Yang, X., Chen, H., Qin, J., Heng, P.A.: Automatic 3D cardiovascular MR segmentation with densely-connected volumetric convnets. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer (2017) 287–295
15. Dong, C., Loy, C.C., He, K., Tang, X.: Image super-resolution using deep convolutional networks. IEEE Transactions on Pattern Analysis and Machine Intelligence 38