Brain Age Estimation Using LSTM on Children's Brain MRI
Sheng He, Randy L. Gollub, Shawn N. Murphy, Juan David Perez, Sanjay Prabhu, Rudolph Pienaar, Richard L. Robertson, P. Ellen Grant, Yangming Ou
Boston Children’s Hospital, Harvard Medical School, Boston, USA (Sheng He, Juan David Perez, Sanjay Prabhu, Rudolph Pienaar, Richard L. Robertson, P. Ellen Grant, Yangming Ou)
Massachusetts General Hospital, Harvard Medical School, Boston, USA (Randy L. Gollub, Shawn N. Murphy)
ABSTRACT
Brain age predicted from children’s brain MRI is an important biomarker for assessing brain health and brain development. In this paper, we treat the 3D brain MRI volume as a sequence of 2D images and propose a new framework that uses a recurrent neural network for brain age estimation. The proposed method, named 2D-ResNet18+Long Short-Term Memory (LSTM), consists of four parts: a 2D ResNet18 for feature extraction on 2D images, a pooling layer for feature reduction over the sequence, an LSTM layer, and a final regression layer. We apply the proposed method to the public multisite NIH-PD dataset and evaluate its generalization on a second multisite dataset. The results show that the proposed 2D-ResNet18+LSTM method outperforms a traditional 3D-based neural network for brain age estimation.
Index Terms — MRI, Age Prediction, ResNet, LSTM
1. INTRODUCTION
Brain age predicted from MRI is becoming an important biomarker for brain health and brain development analysis [1, 2]. The difference between the predicted and chronological age can be used to predict neurocognitive disorders [3] or brain anomalies caused by disease [2].

Following advances in deep learning, 3D convolutional neural networks (3D-CNNs) have been used to predict brain age from 3D brain MRI [4]. However, a 3D-CNN requires more computation and memory than a 2D-CNN, and its computations are harder to parallelize. A 2017 study used a 3D-CNN on MRI to predict the brain ages of N = 2001 subjects aged 18-90 years and reported an average prediction error of 4.16 years [2]. We are particularly interested in ages 0-20 years, especially 0-6 years, where data are hard to obtain, sample sizes are smaller, the brain develops rapidly, and inter-subject variability is larger. For all these reasons, a 3D-CNN faces an additional risk of over-fitting.

To address these problems, we propose a novel framework that treats the 3D brain MRI volume as a sequence of 2D images and uses a recurrent neural network (such as the typical Long Short-Term Memory, or LSTM [5]) to capture age information over the sequence of 2D slices. A similar approach (treating a 3D MRI as a sequence of 2D slices) was recently developed for voxel-wise abnormality detection [6], but we address a different task, patient-level prediction rather than voxel-wise classification, which calls for a different formulation. In our framework, the feature representations of the 2D slices are extracted by a 2D-CNN, which is easy to parallelize, and are then fed into an LSTM to capture global age information. We assume that this design is less likely to be trapped in local minima and less likely to over-fit the small samples available in this age range.

In most brain age estimation studies, all available data are split into training and testing sets, so some training and testing subjects share imaging sites and protocols. Thus, although the testing subjects themselves are unseen during training, their imaging sites and protocols may already have been exposed to training. In contrast, we use two multisite datasets: the NIH Pediatric Data (NIH-PD) [7] (GE/Siemens 1.5T scanners, 12 sites, spin-echo T1-weighted MRI) and the MGHBCH dataset (Massachusetts General and Boston Children’s Hospitals, Siemens 3.0T Trio scanners, T1-weighted MPRAGE sequence). The two datasets have very different imaging sites and protocols. By training and testing on NIH-PD (the Discovery Cohort) and replicating on the MGHBCH data (the Replication Cohort, Fig. 1), we evaluate the generalization ability of the 3D-CNN and of our proposed algorithm on a truly ”unseen” dataset, in which not only the testing subjects but also their imaging sites and protocols are unseen.

[Fig. 1 graphic: NIH-PD (multisite, 12 sites; ages 0-22; N = 1212; 1.5T, T1-weighted) is split into Train and Test sets; MGHBCH (multisite, 2 sites; ages 0-6; N = 428; 3.0T, T1-weighted) is used to Generalize. The age distribution of each dataset is plotted.]
Fig. 1: The proposed framework for age estimation. Using NIH-PD as the Discovery Cohort, we train two different models (the typical 3D-CNN and the proposed 2D-CNN+LSTM) on the training set (80% of NIH-PD data) and test them on the testing set (20% of NIH-PD data). Then, using MGHBCH as the Replication Cohort, we quantify the generalization ability across datasets.

[Fig. 2 graphic: 2D-ResNet → Pooling → LSTM → Regressor → Age]
Fig. 2: Illustration of the 2D-ResNet with LSTM framework for age prediction. The 3D MRI is treated as a sequence of 2D images. The 2D-ResNet extracts features, which are then pooled. The LSTM captures contextual information over the sequence, and the final regressor estimates the age.
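The idea in Fig. 2 of treating a 3D MRI as a sequence of 2D images can be illustrated with a small, dependency-free Python sketch. It is not the authors’ code: the slicing axis is an assumption (the text does not state which anatomical axis forms the sequence), and `toy_extractor` is a hypothetical stand-in for the 2D-ResNet18 that simply returns the mean and maximum of each slice. What the sketch does show is the structural point: one shared per-slice function turns a volume into an ordered sequence of feature vectors.

```python
from typing import Callable, List

# A volume is stored as nested lists indexed volume[a][b][c];
# a slice is a 2D nested list. Plain lists keep the sketch dependency-free.
Volume = List[List[List[float]]]
Slice = List[List[float]]


def slice_sequence(volume: Volume, axis: int = 0) -> List[Slice]:
    """Split a 3D volume into an ordered sequence of 2D slices along `axis`."""
    A, B, C = len(volume), len(volume[0]), len(volume[0][0])
    if axis == 0:  # slice a has shape B x C
        return [[list(row) for row in volume[a]] for a in range(A)]
    if axis == 1:  # slice b has shape A x C
        return [[list(volume[a][b]) for a in range(A)] for b in range(B)]
    if axis == 2:  # slice c has shape A x B
        return [[[volume[a][b][c] for b in range(B)] for a in range(A)]
                for c in range(C)]
    raise ValueError("axis must be 0, 1 or 2")


def toy_extractor(x: Slice) -> List[float]:
    """Hypothetical stand-in for a 2D CNN: returns [mean, max] of the slice."""
    values = [v for row in x for v in row]
    return [sum(values) / len(values), max(values)]


def extract_features(volume: Volume,
                     extractor: Callable[[Slice], List[float]],
                     axis: int = 0) -> List[List[float]]:
    """Apply the same extractor to every slice: one feature vector per slice."""
    return [extractor(x) for x in slice_sequence(volume, axis)]


# Toy 2 x 3 x 4 volume whose voxel value encodes its (a, b, c) index.
volume = [[[a * 100 + b * 10 + c for c in range(4)] for b in range(3)]
          for a in range(2)]
features = extract_features(volume, toy_extractor, axis=0)
print(features)  # [[11.5, 23], [111.5, 123]]
```

In a real pipeline the slices would be resized (to 50 × 50 in this paper) and passed through a trained 2D network; swapping `toy_extractor` for such a model leaves the sequence-building logic unchanged.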
2. METHODS
In this section, we describe the details of the proposed method, named 2D-ResNet18+LSTM, which consists of four main parts (shown in Fig. 2): feature extraction by a 2D residual network with 18 layers, pooling for feature reduction, an LSTM for context modeling over the sequence, and a final regressor for age prediction.
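The last three parts (pooling, LSTM and regressor) can be sketched end to end in plain Python. This is an illustrative forward pass under stated assumptions, not the authors’ implementation: it uses non-overlapping average pooling as in Eq. (2), a bidirectional LSTM whose cell update follows Eq. (3), and a single fully connected output unit over the concatenated hidden states. The helper names, the random initialization, the toy sizes (hidden size 2 rather than the paper’s 64), and the choice to drop trailing slices when n is not a multiple of k are all hypothetical.

```python
import math
import random
from typing import Dict, List, Tuple

Vec = List[float]
Mat = List[List[float]]
# Per-gate parameters (W_x, W_h, b) for the gates "i", "f", "o", "g" of Eq. (3).
Params = Dict[str, Tuple[Mat, Mat, Vec]]


def avg_pool(features: List[Vec], k: int = 3) -> List[Vec]:
    """Non-overlapping average pooling over the slice axis.
    Assumption: trailing slices that do not fill a window are dropped."""
    m = len(features) // k
    dim = len(features[0])
    return [[sum(features[s * k + j][d] for j in range(k)) / k for d in range(dim)]
            for s in range(m)]


def matvec(W: Mat, v: Vec) -> Vec:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(p: Vec, h_prev: Vec, c_prev: Vec, params: Params) -> Tuple[Vec, Vec]:
    """One LSTM cell update: gates, new cell state c_t and hidden state h_t."""
    pre = {}
    for gate, (Wx, Wh, b) in params.items():
        pre[gate] = [u + v + w for u, v, w in zip(matvec(Wx, p), matvec(Wh, h_prev), b)]
    i = [sigmoid(z) for z in pre["i"]]
    f = [sigmoid(z) for z in pre["f"]]
    o = [sigmoid(z) for z in pre["o"]]
    g = [math.tanh(z) for z in pre["g"]]
    c = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c_prev, i, g)]
    h = [oj * math.tanh(cj) for oj, cj in zip(o, c)]
    return h, c


def run_lstm(seq: List[Vec], params: Params, hidden: int) -> List[Vec]:
    h, c = [0.0] * hidden, [0.0] * hidden
    outputs = []
    for p in seq:
        h, c = lstm_step(p, h, c, params)
        outputs.append(h)
    return outputs


def bilstm(seq: List[Vec], fwd: Params, bwd: Params, hidden: int) -> List[Vec]:
    """Concatenate forward and (time-aligned) backward hidden states per step."""
    out_f = run_lstm(seq, fwd, hidden)
    out_b = run_lstm(seq[::-1], bwd, hidden)[::-1]
    return [hf + hb for hf, hb in zip(out_f, out_b)]


def random_params(in_dim: int, hidden: int, rng: random.Random) -> Params:
    def mat(rows: int, cols: int) -> Mat:
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]
    return {gate: (mat(hidden, in_dim), mat(hidden, hidden), [0.0] * hidden)
            for gate in ("i", "f", "o", "g")}


def regress_age(outputs: List[Vec], w: Vec, b: float) -> float:
    """Flatten H = [h_1, ..., h_m] and apply one fully connected output unit."""
    H = [x for h in outputs for x in h]
    return sum(wi * xi for wi, xi in zip(w, H)) + b


rng = random.Random(0)
n, feat_dim, hidden, k = 9, 4, 2, 3  # toy sizes; the paper uses hidden = 64
features = [[rng.gauss(0.0, 1.0) for _ in range(feat_dim)] for _ in range(n)]
pooled = avg_pool(features, k)  # m = n // k = 3 pooled steps
outs = bilstm(pooled, random_params(feat_dim, hidden, rng),
              random_params(feat_dim, hidden, rng), hidden)
w = [rng.uniform(-0.1, 0.1) for _ in range(len(pooled) * 2 * hidden)]
age = regress_age(outs, w, b=10.0)
print(len(pooled), len(outs[0]), round(age, 3))
```

With the paper’s hidden size of 64, each bidirectional output has 2 × 64 = 128 entries, so the concatenated vector passed to the regressor has 128m entries, where m = n/k is the number of pooled steps.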
Feature extraction.
Given a sequence of 2D slices of a 3D brain MRI volume, $X = \{x_1, x_2, \dots, x_n\}$ (where $n$ is the number of slices), a 2D-ResNet [8] is used to extract feature representations. In this paper, we use ResNet18 without the fully connected layer as the backbone to extract a feature $f_t$ for each 2D image $x_t$:

$$ f_t = \mathrm{ResNet}(\Theta, x_t) \tag{1} $$

where $f_t$ is the feature representation of $x_t$ and $\Theta$ is the set of parameters of ResNet18. We use ReLU as the activation function after each convolutional layer.

Pooling.
Looking at 2D slices one by one at a constant pace may not be an optimal way to gather age information. Pooling information from several adjacent 2D slices is likely to smooth out noise in single slices and to better reveal the true age information. This motivates us to combine features across adjacent 2D slices with a pooling operation. The most common pooling method is average pooling, which is defined as:

$$ p_s = \frac{1}{k} \sum_{j=k(s-1)+1}^{ks} f_j, \qquad s = 1, \dots, m \tag{2} $$

where $k$ is the kernel size of the pooling. The sequence is thus reduced to $P = \{p_1, p_2, \dots, p_m\}$, where $m = n/k$. In the experiments, we set $k = 3$.

[Fig. 3 graphic: LSTM cell with inputs $p_t$, $h_{t-1}$ and $c_{t-1}$; sigmoid ($\sigma$) gates $i_t$, $f_t$, $o_t$; tanh units producing $g_t$ and the output; outputs $c_t$ and $h_t$.]
Fig. 3: The structure of the LSTM cell, which consists of a memory cell $c_t$, a hidden state $h_t$, a forget gate $f_t$, an input gate $i_t$ and an output gate $o_t$.

LSTM.
As mentioned above, each 3D brain MRI can be considered as a sequence of 2D slices, which, after pooling, form $m$ stacks of 2D slices represented by the feature vectors $P = \{p_1, p_2, \dots, p_m\}$. This formulation motivates us to use recurrent neural networks, such as the LSTM [5], to capture the age information. The LSTM cell (shown in Fig. 3) is defined as:

$$
\begin{aligned}
i_t &= \sigma(W_{ix} \cdot p_t + W_{ih} \cdot h_{t-1} + b_i) \\
f_t &= \sigma(W_{fx} \cdot p_t + W_{fh} \cdot h_{t-1} + b_f) \\
o_t &= \sigma(W_{ox} \cdot p_t + W_{oh} \cdot h_{t-1} + b_o) \\
g_t &= \phi(W_{gx} \cdot p_t + W_{gh} \cdot h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \phi(c_t)
\end{aligned} \tag{3}
$$

where $\sigma(\cdot)$ and $\phi(\cdot)$ are the sigmoid and hyperbolic tangent $\tanh(\cdot)$ functions, respectively, and $\odot$ denotes the element-wise product. $W_{*x}$, $W_{*h}$ and $b_*$ are parameters learned during training. We use a bidirectional LSTM with a hidden dimension of 64 to balance performance and computing cost.

Age regressor.
All outputs of the LSTM cells are concatenated into the feature vector $\vec{H} = [h_1, h_2, \dots, h_m]$, and a fully connected layer is used as the regressor for age prediction.

Instance Normalization (IN).
Using 2D/3D batch normalization when the batch size is small (such as 1) may quickly exhaust computer memory. Therefore, in this paper, we resort to instance normalization [9], which is defined as:

$$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \tag{4} $$

where the input $x$ is a 2D/3D feature map, $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are its mean and variance, respectively, and $\epsilon$ is a small constant for numerical stability. Instance normalization is applied after each 2D/3D convolutional layer.

3. EXPERIMENTAL RESULTS

We use the NIH-PD dataset as the Discovery Cohort. We partitioned the 1212 subjects in the NIH-PD data (0-22 years of age at the time of MRI scanning) into training (80%) and testing (20%) sets to build the model and evaluate its performance during the Discovery phase. Then, we use the MGHBCH dataset, which contains 428 normally developing subjects 0-6 years of age, as a totally ”unseen” Replication Cohort.

Our second setting for testing the generalization ability of age prediction in the MGHBCH Replication Cohort (0-6 years of age) compares the accuracies obtained by training on NIH-PD subjects aged 0-22 years and by training on NIH-PD subjects aged 0-6 years. The age distribution of each dataset is shown in Fig. 1.

The T1-weighted MPRAGE MRI of every subject is registered to the SRI atlas (image size [128, 160, 112] and voxel size × × mm [10]). We compared the SRI atlas (adults) with a 2-year-old pediatric atlas and found little difference in the final accuracies. We resize the 2D images in the sequence to 50 × 50 voxels to reduce computation time and memory use. The voxel values of each subject are normalized by subtracting the mean value of the subject’s image and dividing by the variance.

We use the corresponding 3D-ResNet18 as the baseline for comparison; it consists of 3D convolutional layers with a kernel size of 3 × 3 × 3, global average pooling, and a fully connected layer for age prediction. All networks, including 2D-ResNet18+LSTM and 3D-ResNet18, are trained with the Mean Absolute Error (MAE) loss. The Adam optimizer is used with an initial learning rate of 0.0001, which is halved every 15 epochs. Each network is trained for 60 epochs with a batch size of 1.

The models are evaluated using two metrics, Mean Absolute Error (MAE) and Cumulative Score (CS):
$$ \mathrm{MAE} = \frac{1}{N}\sum_{i=1}^{N} \lvert y_i - \hat{y}_i \rvert, \qquad \mathrm{CS}(\alpha) = \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\left\{ \lvert y_i - \hat{y}_i \rvert \le \alpha \right\} \times 100\% \tag{5} $$

where $y_i$ is the ground-truth age, $\hat{y}_i$ is the estimated age, $\alpha$ is the error level, $\mathbb{1}\{\cdot\}$ is the indicator function, and $N$ is the number of test samples. CS($\alpha$) is thus the percentage of subjects whose age prediction errors are smaller than or equal to $\alpha$. A smaller MAE and a higher CS($\alpha$) indicate a more accurate algorithm.

Table 1 shows the MAEs of the different methods in different age groups on the NIH-PD testing set. The proposed 2D-ResNet18+LSTM provides lower MAEs in the 6-10, 11-15 and 16-22 year groups, and a lower average MAE, than 3D-ResNet18. Fig. 4 shows the CS scores as a function of the error level $\alpha$. The proposed 2D-ResNet18+LSTM provides better results than 3D-ResNet18 in terms of both MAE and CS.

Table 1: Accuracy in the Discovery Cohort. The MAE (in years) by age group (years) for brain age estimation on the NIH-PD testing set. Numbers in bold are more accurate results.

Method              0-5    6-10   11-15  16-22  Average
3D-ResNet18
2D-ResNet18+LSTM

[Fig. 4 plot: CS (percentage, %) versus error level $\alpha$ (years)]
Fig. 4: Cumulative score (CS) of the different methods at error levels from 0 to 2 years on the Discovery Cohort (NIH-PD testing set). A higher CS is better.

We next report age estimation results on the MGHBCH dataset using the models trained on the NIH-PD training set, in order to evaluate the generalization of the 2D-ResNet18+LSTM and 3D-ResNet18 methods. Table 2 shows the MAEs of the two methods on the MGHBCH dataset (0-6 years) after training on NIH-PD subjects 0-22 years of age. The proposed 2D-ResNet18+LSTM provides much better results than 3D-ResNet18 on the MGHBCH dataset. Fig. 5 shows the CS curves of the two methods: 2D-ResNet18+LSTM provides better results than 3D-ResNet18 at all error levels $\alpha$.

Table 2: Accuracy in the Replication Cohort, training on Discovery Cohort subjects 0-22 years of age. The MAE (in years) by age group (years) of the different methods for brain age estimation, trained on the NIH-PD training set and replicated on the truly ”unseen” MGHBCH dataset to examine generalization ability. Numbers in bold are more accurate results.

Method              0-1    1-2    2-3    3-4    4-5    5-6    Average
3D-ResNet18         2.57   2.43   2.80   2.64   2.85   2.75   2.64
2D-ResNet18+LSTM

[Fig. 5 plot: CS (percentage, %) versus error level $\alpha$ (years)]
Fig. 5: Cumulative score (CS) at error levels from 0 to 2 years for models trained on the Discovery Cohort (NIH-PD training set) and applied to the Replication Cohort. A higher CS is better.

We repeat the above experiment on the Replication (MGHBCH) Cohort, but train on NIH-PD subjects 0-6 years of age. Table 3 shows the results.

Table 3: Accuracy in the Replication Cohort, training on Discovery Cohort subjects 0-6 years of age. The MAE (in years) by age group (years) of the different methods for brain age estimation, trained on NIH-PD training subjects aged 0-6 years and replicated on the truly ”unseen” MGHBCH dataset to examine generalization. Numbers in bold are more accurate results.

Method              0-1    1-2    2-3    3-4    4-5    5-6    Average
3D-ResNet18         0.66
2D-ResNet18+LSTM

When the Discovery and Replication Cohorts share the same age range, the average MAEs of 2D-ResNet18+LSTM and 3D-ResNet18 are approximately the same, but the computation time of 2D-ResNet18+LSTM is shorter than that of 3D-ResNet18 (3 hours versus 7 hours on an NVIDIA K80 GPU).
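The two evaluation metrics in Eq. (5) are simple to compute. The sketch below is a minimal Python version; the function names and the example ages are hypothetical and are not data from this study. Evaluating `cumulative_score` over a grid of α values from 0 to 2 years yields a CS curve of the kind plotted in Figs. 4 and 5.

```python
from typing import Sequence


def _check(y_true: Sequence[float], y_pred: Sequence[float]) -> None:
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """MAE in years: average absolute difference between true and predicted ages."""
    _check(y_true, y_pred)
    return sum(abs(y - yh) for y, yh in zip(y_true, y_pred)) / len(y_true)


def cumulative_score(y_true: Sequence[float], y_pred: Sequence[float],
                     alpha: float) -> float:
    """CS(alpha): percentage of subjects whose absolute error is <= alpha years."""
    _check(y_true, y_pred)
    hits = sum(1 for y, yh in zip(y_true, y_pred) if abs(y - yh) <= alpha)
    return 100.0 * hits / len(y_true)


# Hypothetical ages in years, for illustration only.
y_true = [2.0, 5.0, 10.0, 16.0]
y_pred = [2.5, 4.0, 12.5, 16.0]
print(mean_absolute_error(y_true, y_pred))  # 1.0
print([cumulative_score(y_true, y_pred, a) for a in (0.0, 0.5, 1.0, 1.5, 2.0)])
# [25.0, 50.0, 75.0, 75.0, 75.0]
```

The same MAE expression also serves as the training loss described above, so a single definition can back both optimization and evaluation.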
4. CONCLUSION
Recent age prediction studies mainly focus on subjects aged 18-90 years and use 2000-3000 brain MRIs. We focus on ages 0-22 years, especially 0-6 years, where sample sizes are smaller. In such circumstances, we have shown that the proposed 2D-ResNet18+LSTM method provides comparable or better results than 3D-ResNet18.

While existing age prediction studies split the Discovery Cohort into training and testing sets, we further test the generalization ability on a truly unseen Replication Cohort (unseen subjects and unseen imaging sites/protocols). We show that the proposed 2D-ResNet18+LSTM method generalizes better to the Replication dataset when the age ranges of the Discovery and Replication Cohorts differ (Table 2). This opens a window to train on a large-scale Discovery Cohort with a wider age range (such as 0-20 years) and to adapt to a Replication Cohort where the age range of interest is narrower (0-2 or 0-6 years) and the sample size is smaller.

When the age ranges of the Discovery and Replication Cohorts are the same, 2D-ResNet18+LSTM computes faster and offers higher accuracy at both ends of the age range (Table 3), while 3D-ResNet18 has higher accuracy in the middle of the age range. This suggests exploring ways to combine the proposed 2D-ResNet18+LSTM and 3D-ResNet18 methods to achieve high accuracy throughout the age range of interest. Our future work also includes a more comprehensive evaluation on large multisite datasets and application to probing brain disorders in early life.
5. REFERENCES

[1] James H Cole, Stuart J Ritchie, Mark E Bastin, MC Valdés Hernández, S Muñoz Maniega, Natalie Royle, Janie Corley, Alison Pattie, Sarah E Harris, Qian Zhang, et al., “Brain age predicts mortality,” Molecular Psychiatry, 2017.
[2] James H Cole, Rudra PK Poudel, Dimosthenis Tsagkrasoulis, Matthan WA Caan, Claire Steves, Tim D Spector, and Giovanni Montana, “Predicting brain age with deep learning from raw imaging data results in a reliable and heritable biomarker,” NeuroImage, vol. 163, pp. 115–124, 2017.
[3] Habtamu Minassie Aycheh, Joon-Kyung Seong, Jeong-Hyeon Shin, Duk L Na, Byungkon Kang, Sang Won Seo, and Kyung-Ah Sohn, “Biological brain age prediction using cortical thickness data: a large scale cohort study,” Frontiers in Aging Neuroscience, vol. 10, pp. 252, 2018.
[4] Masaru Ueda, Koichi Ito, Kai Wu, Kazunori Sato, Yasuyuki Taki, Hiroshi Fukuda, and Takafumi Aoki, “An age estimation method using 3D-CNN from brain MRI images,” in ISBI, 2019, pp. 380–383.
[5] Sepp Hochreiter and Jürgen Schmidhuber, “Long short-term memory,” Neural Computation, vol. 9, no. 8, pp. 1735–1780, 1997.
[6] Qikui Zhu, Bo Du, Baris Turkbey, Peter Choyke, and Pingkun Yan, “Exploiting interslice correlation for MRI prostate image segmentation, from recursive neural networks aspect,” Complexity, Article 4185279, 2018.
[7] Alan C Evans, Brain Development Cooperative Group, et al., “The NIH MRI study of normal brain development,” NeuroImage, vol. 30, no. 1, pp. 184–202, 2006.
[8] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770–778.
[9] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky, “Instance normalization: The missing ingredient for fast stylization,” arXiv preprint arXiv:1607.08022, 2016.
[10] Torsten Rohlfing, Natalie M Zahr, Edith V Sullivan, and Adolf Pfefferbaum, “The SRI24 multichannel atlas of normal adult human brain structure,”