Automatic Data Augmentation for 3D Medical Image Segmentation
Ju Xu⋆, Mengzhang Li⋆, and Zhanxing Zhu
Center for Data Science, Peking University, Beijing, China
Canon Medical Systems, Beijing, China
School of Mathematical Sciences, Peking University, Beijing, China
{xuju, mcmong, zhanxing.zhu}@pku.edu.cn

Abstract.
Data augmentation is an effective and universal technique for improving the generalization performance of deep neural networks. It can enrich the diversity of training samples, which is essential in medical image segmentation tasks because 1) medical image datasets are typically small, which increases the risk of overfitting; and 2) the shape and modality of different objects such as organs or tumors are unique, thus requiring a customized data augmentation policy. However, most data augmentation implementations are hand-crafted and suboptimal for medical image processing. To fully exploit the potential of data augmentation, we propose an efficient algorithm to automatically search for the optimal augmentation strategies. We formulate the coupled optimization w.r.t. network weights and augmentation parameters into a differentiable form by means of stochastic relaxation. This formulation allows us to apply alternating gradient-based methods to solve it, i.e. a stochastic natural gradient method with adaptive step-size. To the best of our knowledge, this is the first time that differentiable automatic data augmentation is employed in medical image segmentation tasks. Our numerical experiments demonstrate that the proposed approach significantly outperforms the built-in data augmentation of existing state-of-the-art models.
Keywords:
Medical Image Segmentation · Data Augmentation · AutoML
In the past few years, deep neural networks have achieved incredible progress in medical image segmentation tasks and promoted the booming development of computer-assisted intervention. This has benefited research and clinical treatment in disease diagnosis, treatment design and prognosis evaluation [13,14]. Given the training data, researchers have proposed various 2D/3D medical image segmentation models for supervised or semi-supervised tasks [8,6]. However, the performance of deep learning models heavily relies on large-scale, well-labeled data. Currently, data augmentation is a widely used and effective technique to increase the amount and diversity of available data, and thus improve models' generalization performance. In the domain of natural image processing, typical data augmentation strategies include manually cropping, rotating or adding random noise to the original images. Besides these ad-hoc approaches, generative models [7] and unsupervised learning models [15] are also employed for generating extra data. Unfortunately, those augmentation techniques might not be optimal for a specific task, and thus a customized data augmentation strategy is required. Recently, researchers proposed to search for the augmentation policy by reinforcement learning [5] or density matching [10], inspired by previous works of automatic machine learning (AutoML) on neural architecture search (NAS [11,12,18]).

For medical image segmentation tasks, data augmentation techniques are also used in UNet and its variants such as nnUNet [8] and R2U-Net [2]. However, these methods are simple and hand-crafted, and the resulting improvement in segmentation accuracy is limited. In [16], the authors proposed to utilize reinforcement learning to search for augmentation strategies; however, it costs 768 GPU hours, and only the probability of each augmentation strategy is searched. Moreover, the differences between natural and medical images, such as spatial contextual correlation, smaller dataset scale and the unique patterns of specific organs or tumors, make augmentation strategies adopted for natural images difficult to transfer to medical domains.

In this paper, we propose an automatic data augmentation framework (ASNG) that searches for the optimal augmentation policy, particularly for 3D medical image segmentation tasks. It is the first automatic data augmentation work in the whole semantic segmentation field.

⋆ Equal contributions.
The contributions of our paper are as follows:
– It is the first time that the auto-augmentation problem is formulated as a bi-level optimization problem, and we apply an approximate algorithm to solve it.
– The designed search space in the medical image field is novel. Different from previous methods, which searched for a fixed magnitude of each operation, we search for an interval of magnitudes.
– Different from previous methods, which searched for a fixed augmentation strategy, the augmentation strategy searched by our method changes dynamically during training. Besides, we do not need to retrain the target network after the search process.
– Experiments demonstrate that ASNG indeed achieves state-of-the-art performance.
In our method, we formulate the problem of finding the optimal augmentation policy as a discrete search problem. Our method consists of two components: the design of the search space and the search algorithm. The search algorithm samples a data augmentation policy S from the search space consisting of the proposed operations, and then decides the magnitude of each operation and the probability of applying it. The framework of our method is shown in Fig. 1.

Fig. 1. The framework of our proposed method. D_train and D_val represent the training and validation datasets, respectively. p_θ is the distribution of c.

Search Space. Since this is the first work applying AutoAugment-style strategies in the medical image area, we have to design the search space for our ASNG algorithm. In our search space, a policy consists of seven image operations to be applied in a sequential manner. Each image operation is associated with two hyperparameters: 1) the probability of applying the operation, and 2) the interval of magnitude for the operation.

The seven image operations used in our experiments come from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators), a popular Python image library, namely Scale, RotateX/Y/Z, Alpha (magnitude of the elastic deformation), Sigma (scale of the elastic deformation) and Gamma (as in gamma correction for photos or computer monitors). To increase the diversity of augmentation policies, we do not fix a specific magnitude for an operation as in previous work [5], but set an interval of magnitudes for each operation. Therefore we have to decide the left boundary (LB) and the right boundary (RB) of each interval. To decrease the search complexity, we discretize the range of each magnitude into 11 values with uniform spacing, so that a discrete search algorithm can be used to find them. Besides the magnitudes of the transformation operations, we also search for the probabilities of conducting these transformations, i.e. the probabilities of applying the scale transformation, rotation, gamma transformation and elastic deformation, denoted as p_scale, p_rot, p_gamma and p_eldef, respectively. Similarly, we discretize the probability of applying each operation into 11
values with uniform spacing. Table 1 summarizes the ranges of magnitudes and probabilities for the seven operations. Fig. 2 shows an example of an augmented image and label produced by our method, in which the image transformations come from the defined search space. Note that naively searching one augmentation strategy is already a search problem with 11^18 possibilities (14 interval boundaries plus 4 probabilities, each taking one of 11 discrete values). The search space is so huge that an efficient algorithm is required, as proposed in the following.

Table 1.
The ranges of the parameters in the strategies we search for.

Operation   LB          RB          Probability  Range
Scale       [0.5, 1.0]  [1.0, 1.5]  p_scale      [0, 1]
RotationX   [-π, 0]     [0, π]      p_rot        [0, 1]
RotationY   [-π, 0]     [0, π]      p_rot        [0, 1]
RotationZ   [-π, 0]     [0, π]      p_rot        [0, 1]
Alpha       [0, 450]    [450, 900]  p_eldef      [0, 1]
Sigma       [0, 7]      [7, 14]     p_eldef      [0, 1]
Gamma       [0.5, 1]    [1, 1.5]    p_gamma      [0, 1]

We denote f(w, c) as the objective function, where w ∈ W are the network parameters and c ∈ C are the data augmentation strategies. f_{train} and f_{val} are the training and validation losses, respectively. Both losses are determined not only by the augmentation policy c but also by the weights w. The goal of augmentation strategy search is to find the c* that minimizes the validation loss f_{val}(w*, c*), where the weights are obtained by minimizing the training loss, w* = \arg\min_w f_{train}(w, c*). Augmentation strategy search is thus a bi-level optimization problem:

\min_c f_{val}(w^*(c), c)    (1)
s.t.  w^*(c) = \arg\min_w f_{train}(w, c)    (2)

Solving the above problem is not easy, since we cannot obtain the gradient w.r.t. c, and thus it is hard to optimize c via gradient descent. Although simple grid search or the reinforcement learning proposed in [5] could be used to search for c, the computational cost is extremely high if we evaluate the performance of every c. To this end, we propose to solve this optimization problem efficiently, first by stochastic relaxation and then by applying natural gradient descent [3], as described in the following.
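As a concrete illustration, the discretized search space of Table 1 can be sketched in a few lines of Python. This is our own sketch, not the released code: the helper names (`grid`, `BOUNDS`, `sample_policy`) are invented, and each parameter's categorical distribution is assumed to be stored as an 11-entry probability vector.

```python
import numpy as np

N_LEVELS = 11  # each searchable parameter is discretized into 11 uniform values

def grid(lo, hi, n=N_LEVELS):
    """Uniformly spaced candidate values for one searchable parameter."""
    return np.linspace(lo, hi, n)

# (LB range, RB range) per operation, following Table 1.
BOUNDS = {
    "scale": ((0.5, 1.0), (1.0, 1.5)),
    "rot_x": ((-np.pi, 0.0), (0.0, np.pi)),
    "rot_y": ((-np.pi, 0.0), (0.0, np.pi)),
    "rot_z": ((-np.pi, 0.0), (0.0, np.pi)),
    "alpha": ((0.0, 450.0), (450.0, 900.0)),
    "sigma": ((0.0, 7.0), (7.0, 14.0)),
    "gamma": ((0.5, 1.0), (1.0, 1.5)),
}
PROB_NAMES = ["p_scale", "p_rot", "p_gamma", "p_eldef"]

# Candidate values for every searchable parameter:
# 7 ops x 2 interval boundaries + 4 probabilities = 18 parameters.
CANDIDATES = {}
for op, (lb, rb) in BOUNDS.items():
    CANDIDATES[op + "_lb"] = grid(*lb)
    CANDIDATES[op + "_rb"] = grid(*rb)
for p in PROB_NAMES:
    CANDIDATES[p] = grid(0.0, 1.0)

def sample_policy(theta, rng):
    """Sample one augmentation policy, where theta[k] is a categorical
    distribution over the 11 candidate values of parameter k."""
    return {k: v[rng.choice(N_LEVELS, p=theta[k])]
            for k, v in CANDIDATES.items()}
```

With 18 parameters and 11 levels each, exhaustively evaluating all 11^18 configurations is clearly infeasible, which motivates the gradient-based search over the policy distribution.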
Stochastic Relaxation. We turn the original optimization problem into the optimization of a differentiable objective J by stochastic relaxation [1].

Fig. 2. Visualization of the proposed data augmentation on a 2D section of the Task01 BrainTumour dataset. Top: original image and label. Bottom: the image and label generated by the searched data augmentation policy, in which the operations are from Table 1.

The basic idea of stochastic relaxation is the following: instead of directly optimizing w.r.t. c, we consider a distribution p_θ(c) over c parametrized by θ, and minimize the expected value of the validation loss f_{val} w.r.t. θ, i.e.,

\min_\theta J(\theta) = \int_{c \in C} f_{val}(w^*(c), c)\, p_\theta(c)\, dc    (3)
s.t.  w^*(c) = \arg\min_w f_{train}(w, c)    (4)

The stochastic relaxation makes J differentiable w.r.t. both w and θ, so we can update w and θ by gradient descent. However, the gradient ∇_w J(w, θ) is not tractable, because we cannot evaluate the mean performance over c ∈ C in closed form. We therefore estimate the gradient by Monte-Carlo (MC) sampling using ∇_w f_{train}(w_t, c_i) with i.i.d. samples c_i ∼ p_{θ_t}(c), i = 1, ..., N_w, namely

G_w(w_t, \theta_t) = \frac{1}{N_w} \sum_{i=1}^{N_w} \nabla_w f_{train}(w_t, c_i)    (5)

Approximating ∇_w J(w, θ) with the stochastic gradient G_w(w_t, θ_t), w_t can be updated as

w_{t+1} = w_t - \epsilon_w G_w(w_t, \theta_t),    (6)

where ε_w is the learning rate for the network parameters. Since the distance between two probability distributions is not Euclidean, updating θ directly by gradient descent as for w is not appropriate. We therefore resort to the natural gradient (NG [3]) designed for parametric probability distributions,

\theta_{t+1} = \theta_t - \epsilon_\theta F(\theta_t)^{-1} \nabla_\theta J(w, \theta),    (7)

where F(θ_t) is the Fisher information matrix and ε_θ is the learning rate. The probability distribution we use for c ∈ C is the multinomial distribution. How to calculate the Fisher matrix is described in [1]. As in [1], we use an adaptive step-size ε_θ to make the learning process faster.
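The Monte-Carlo weight update of Eqs. (5)-(6) just averages training gradients over a handful of sampled policies before taking one descent step. A minimal sketch, assuming `sample_policy` draws a policy from p_θ and `grad_train` stands in for ∇_w f_train evaluated on a mini-batch augmented with that policy (both are illustrative placeholders):

```python
import numpy as np

def update_weights(w, sample_policy, grad_train, lr, n_w):
    """Eqs. (5)-(6): estimate grad_w J by Monte-Carlo over n_w sampled
    augmentation policies, then take one SGD step on the weights."""
    policies = [sample_policy() for _ in range(n_w)]
    g = np.mean([grad_train(w, c) for c in policies], axis=0)  # G_w(w_t, theta_t)
    return w - lr * g, policies
```

In practice N_w is kept small, since each sample requires augmenting and forward-passing a full 3D mini-batch.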
Monte-Carlo sampling is also adopted to approximate ∇_θ J(w, θ), which gives

\theta_{t+1} = \theta_t - \epsilon_\theta F(\theta_t)^{-1} \frac{1}{N_\theta} \sum_{j=1}^{N_\theta} f_{val}(w_{t+1}, c_j) \nabla_\theta \ln p_\theta(c_j)    (8)

We summarize the procedure of our proposed approach in Algorithm 1.
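Because the policy distribution is a product of categoricals, Eq. (8) simplifies nicely: for an exponential family in expectation parameters, the natural gradient of the log-likelihood, F(θ)^{-1} ∇_θ ln p_θ(c), reduces to onehot(c) − θ, so the Fisher matrix never needs to be inverted explicitly. A minimal sketch for a single categorical parameter; the adaptive step-size and the utility transformation of the losses used in [1] are omitted here:

```python
import numpy as np

def update_theta(theta, val_losses, samples, lr):
    """Eq. (8) for one categorical parameter in expectation parameters:
    theta is a probability vector over K candidate values, samples holds
    the indices of the sampled values, val_losses their validation losses."""
    K = len(theta)
    nat_grad = np.zeros(K)
    for loss, c in zip(val_losses, samples):
        # f_val(w_{t+1}, c_j) * F(theta)^{-1} grad_theta ln p_theta(c_j)
        nat_grad += loss * (np.eye(K)[c] - theta)
    nat_grad /= len(samples)
    theta = theta - lr * nat_grad
    # stay on the probability simplex
    theta = np.clip(theta, 1e-6, None)
    return theta / theta.sum()
```

Descending this estimate lowers the probability of candidate values that produced a high validation loss, which is the behaviour the search relies on.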
Algorithm 1 ASNG
Input: w, θ, ε_w, ε_θ, N_w, N_θ
Input: Training dataset D_train, validation dataset D_val, test dataset D_test
1: for i = 1 to epoch do
2:   for t = 1 to T do
3:     Generate N_w policies according to p_{θ_t};
4:     Augment training data from D_train with the N_w policies, respectively;
5:     Obtain the training losses f_train(w_t, c_i) (i = 1, ..., N_w) on D_train;
6:     Update w_t according to Eq. (6) to obtain w_{t+1};
7:     Generate N_θ policies according to p_{θ_t};
8:     for j = 1 to N_θ do
9:       Augment training data according to policy c_j;
10:      Update w_t to obtain ŵ_t;
11:      Obtain the validation loss f_val(ŵ_t)_j on D_val;
12:      Restore the network parameters, ŵ_t = w_t;
13:    end for
14:    Use the validation losses f_val(ŵ_t)_j and policies c_j (j = 1, ..., N_θ) to update θ_t according to Eq. (8);
15:  end for
16: end for
17: Test the network on D_test;
18: return the final network.

Datasets. We evaluate the proposed method on three 3D segmentation tasks from the Medical Segmentation Decathlon challenge (MSD, http://medicaldecathlon.com/): (1) Task01 BrainTumour (484 labeled images, 3 classes), (2) Task02 Heart (20 labeled images, 1 class) and (3) Task05 Prostate (32 labeled images, 2 classes). Each dataset is collected for a specific task; their various input sizes, voxel spacings and foreground targets are suitable for demonstrating our algorithm's generalization ability. We evaluate performance with 5-fold cross validation, as in [4,9], since the ground-truth labels for the test sets are not publicly available.

Compared Methods include 3D U-ResNet [17], SCNAS [9], nnUNet without data augmentation (nnUNet NoDA) and 3D nnUNet [8] with its default data augmentation strategy. 3D U-ResNet utilizes residual blocks and attention gates, and SCNAS is the latest neural architecture search model for 3D medical image segmentation, which applies scalable gradient-based optimization to find the optimal model architecture. The search method of SCNAS cannot be used for differentiable auto-augmentation strategy search: SCNAS creates a mixed operation by summing all operations in its search space weighted by their importance, but the transformation results of different augmentation strategies cannot be summed in this way. Therefore we do not apply the method of SCNAS to our augmentation strategy search task. Note that the code of nnUNet (https://github.com/MIC-DKFZ/nnUNet/) already implements random augmentation: the magnitude of each operation is sampled from a predefined interval in every training epoch. AutoAugment [5] costs 5000 GPU hours to search for a policy, and FastAutoAugment [10] needs to split the training dataset into K folds; however, datasets in the medical image area are quite small, and a model trained on such a small split would overfit.
Therefore, we do not compare our method with AutoAugment and FastAutoAugment. Prediction is performed using a sliding window with a stride of half the patch size, ensuring 50% overlap, i.e., each voxel is inferred at least twice at test time.
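The start offsets of that sliding window (stride equal to half the patch size, with a last window clamped to the volume border) can be computed per axis as follows; the helper is our own sketch, not nnUNet's implementation:

```python
def window_starts(dim, patch):
    """Start offsets along one axis for sliding-window inference with
    stride = patch // 2, i.e. 50% overlap between neighbouring windows."""
    if dim <= patch:
        return [0]                    # volume no larger than one patch
    stride = max(patch // 2, 1)
    starts = list(range(0, dim - patch + 1, stride))
    if starts[-1] + patch < dim:      # clamp a last window to the border
        starts.append(dim - patch)
    return starts
```

For a 128-voxel axis and a 64-voxel patch this yields starts [0, 32, 64]; overlapping predictions are then averaged per voxel.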
Implementation Details. We preprocess the data with the same pipeline used in 3D nnUNet. Since the voxel spacing differs across cases, we unify it by proper interpolation, resampling the volumes to a fixed target spacing of 0.7 mm × 0.7 mm × 0.7 mm. The initial learning rate is 3 × 10^{-4} and is reduced by 80% if the training loss is not reduced over 30 epochs. Besides ASNG, the training of the other benchmarks lasts for 500 epochs, as long as the learning rate remains larger than 10^{-6}. Following [9] and [8], the loss function for 3D U-ResNet and SCNAS is the Jaccard distance, and for nnUNet and ASNG it is the sum of a negative Dice similarity term and cross entropy. Considering the training time, ASNG is trained for 50, 200 and 200 epochs on Brain Tumour, Heart and Prostate, respectively, with a batch size of 2. This takes about 10 days on one NVIDIA TITAN RTX GPU, compared with

Fig. 3.
Results on Task05 Prostate for selected architectures. Left: an example of inference, where the green mask represents the peripheral zone and the red mask the transitional zone. Right: the trend of the loss on the validation set.

the fact that one integrated nnUNet training procedure takes about 3 days. The number of sampling times T in ASNG is 2 because of the limited GPU memory, although a larger T could produce better numerical results. Our code can be found at https://github.com/MengzhangLI/ASNG.

Our results are shown in Table 2. Because the labels of the test sets are unavailable and the number of online submissions is restricted, AutoML models for 3D medical image segmentation are all evaluated on the validation sets; in this paper we follow the same protocol for fair comparison. ASNG outperforms the other architectures, especially 3D nnUNet, the best medical image segmentation framework with default data augmentation. It should be noted that, since Heart and Prostate have only 20 and 32 labeled images, in [9] the architecture obtained by SCNAS on the first fold of the 484 labeled Brain Tumour images is transferred to the Heart and Prostate tasks to avoid overfitting. Remarkably, our method, applied only on the basic network architecture, still achieves the best prediction accuracy. This clearly demonstrates the necessity and effectiveness of data augmentation policy search in 3D medical image segmentation.

Figure 3 shows example segmentation results and the validation loss w.r.t. the number of epochs on the Prostate task. We observe that ASNG produces better predictions and achieves more stable improvement during training than the compared methods.

The method in [16] utilizes reinforcement learning to search for augmentation strategies, which costs 768 GPU hours, while our method costs less than 100. Their result on Task02 (Dice 0.92) is also worse than ours (0.933). Besides, [16] only searches the probability of each augmentation strategy, while our method searches both the probability and the magnitude.
             Brain Tumour                                  Heart          Prostate
Label        Edema   Non-Enhancing  Enhancing  Average     Left atrium    Peripheral  Transitional  Average
U-ResNet     79.10   58.38          77.37      71.61       91.48          48.37       79.17         63.77
nnUNet NoDA  81.27   60.92          77.90      73.36       92.85          58.61       83.61         71.11
nnUNet       81.68   61.29          77.97      73.65       93.21          63.14       86.53         74.84
SCNAS        80.41   59.85          78.50      72.92       91.91          53.81       82.02         67.92
ASNG
Table 2. Average Dice similarity coefficients (%) for the Brain Tumour, Heart and Prostate 3D segmentation tasks of MSD.
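For completeness, the per-class Dice similarity coefficient reported in Table 2 is 2|P ∩ T| / (|P| + |T|) for predicted and ground-truth binary masks; a minimal sketch (the smoothing constant `eps` is our own choice to keep empty masks well-defined):

```python
import numpy as np

def dice(pred, target, eps=1e-8):
    """Dice similarity coefficient between two binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)
```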
We have proposed an automatic data augmentation strategy for 3D medical image segmentation tasks. By configuring a proper search space followed by gradient-based optimization, a customized data augmentation strategy for each task can be obtained. The numerical results on different segmentation tasks show that it outperforms the state-of-the-art models widely used in this area. Furthermore, the proposed approach shows that, compared with searching network architectures, searching for the optimal data augmentation policy is also important. As future work, designing better search spaces and accelerating the search process can be considered.
References
1. Akimoto, Y., Shirakawa, S., Yoshinari, N., Uchida, K., Saito, S., Nishida, K.: Adaptive stochastic natural gradient method for one-shot neural architecture search. arXiv preprint arXiv:1905.08537 (2019)
2. Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M., Asari, V.K.: Recurrent residual convolutional neural network based on U-Net (R2U-Net) for medical image segmentation. arXiv preprint arXiv:1802.06955 (2018)
3. Amari, S.I.: Natural gradient works efficiently in learning. Neural Computation 10(2), 251–276 (1998)
4. Bae, W., Lee, S., Lee, Y., Park, B., Chung, M., Jung, K.H.: Resource optimized neural architecture search for 3D medical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 228–236. Springer (2019)
5. Cubuk, E.D., Zoph, B., Mane, D., Vasudevan, V., Le, Q.V.: AutoAugment: Learning augmentation strategies from data. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 113–123 (2019)
6. Ganaye, P.A., Sdika, M., Triggs, B., Benoit-Cattin, H.: Removing segmentation inconsistencies with semi-supervised non-adjacency constraint. Medical Image Analysis 58, 101551 (2019)
7. Huang, S.W., Lin, C.T., Chen, S.P., Wu, Y.Y., Hsu, P.H., Lai, S.H.: AugGAN: Cross domain adaptation with GAN-based data augmentation. In: Proceedings of the European Conference on Computer Vision (ECCV). pp. 718–731 (2018)
8. Isensee, F., Petersen, J., Kohl, S.A., Jäger, P.F., Maier-Hein, K.H.: nnU-Net: Breaking the spell on successful medical image segmentation. arXiv preprint arXiv:1904.08128 (2019)
9. Kim, S., Kim, I., Lim, S., Baek, W., Kim, C., Cho, H., Yoon, B., Kim, T.: Scalable neural architecture search for 3D medical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 220–228. Springer (2019)
10. Lim, S., Kim, I., Kim, T., Kim, C., Kim, S.: Fast AutoAugment. In: Advances in Neural Information Processing Systems. pp. 6662–6672 (2019)
11. Liu, H., Simonyan, K., Yang, Y.: DARTS: Differentiable architecture search. arXiv preprint arXiv:1806.09055 (2018)
12. Pham, H., Guan, M.Y., Zoph, B., Le, Q.V., Dean, J.: Efficient neural architecture search via parameter sharing. arXiv preprint arXiv:1802.03268 (2018)
13. Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional networks for biomedical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 234–241. Springer (2015)
14. Tajbakhsh, N., Shin, J.Y., Gurudu, S.R., Hurst, R.T., Kendall, C.B., Gotway, M.B., Liang, J.: Convolutional neural networks for medical image analysis: Full training or fine tuning? IEEE Transactions on Medical Imaging 35