Improved Brain Age Estimation with Slice-based Set Networks
Umang Gupta⋆, Pradeep K. Lam†, Greg Ver Steeg⋆, Paul M. Thompson†

⋆ Information Sciences Institute, University of Southern California
† Imaging Genetics Center, Mark and Mary Stevens Institute for Neuroimaging and Informatics, Keck School of Medicine, University of Southern California
ABSTRACT
Deep learning for neuroimaging data is a promising but challenging direction. The high dimensionality of 3D MRI scans makes this endeavor compute- and data-intensive. Most conventional 3D neuroimaging methods use 3D-CNN-based architectures with a large number of parameters and require more time and data to train. Recently, 2D-slice-based models have received increasing attention, as they have fewer parameters and may require fewer samples to achieve comparable performance. In this paper, we propose a new architecture for BrainAGE prediction. The proposed architecture works by encoding each 2D slice in an MRI with a deep 2D-CNN model. Next, it combines the information from these 2D-slice encodings using set networks or permutation-invariant layers. Experiments on the BrainAGE prediction problem, using the UK Biobank dataset, showed that the model with the permutation-invariant layers trains faster and provides better predictions compared to other state-of-the-art approaches.
Index Terms — MRI, deep learning, brain age, machine learning, neuroimaging
1. INTRODUCTION
In this work, we focus on the problem of predicting brain age from 3D MRI scans. Brain Age Gap Estimation (BrainAGE) from structural MRI acts as an important biomarker for assessing and diagnosing an individual's risk of neurological diseases. A brain age prediction model is estimated by training on a large dataset of MRIs of healthy subjects to predict their chronological age. The deviation of this MRI-derived age from the true age is a useful biomarker for various neurological diseases [1].

Deep convolutional neural networks (CNNs) have shown tremendous improvements over traditional computer vision approaches. Directly extending computer vision successes to 3D neuroimaging data by substituting 2D-CNNs with 3D-CNNs is non-trivial and has received considerable attention [2, 3]. A 3D-CNN has more parameters than its 2D counterpart and therefore requires more data to train robust models. However, the number of samples for any 3D neuroimaging problem is often small, with the largest datasets
typically having tens of thousands rather than millions of samples. Conventional approaches that apply deep learning to neuroimaging have focused on designing better data augmentation techniques and better 3D-CNN models [4, 5, 6]. Other approaches have tried transfer learning [7, 8]; however, the scarcity of general-purpose pretrained 3D-CNN models has limited these methods to adapting 2D-CNN models, which is not ideal.

Recently, [9] proposed a model for the BrainAGE problem that encodes slices along the sagittal axis with a 2D-CNN encoder and then processes this ordered sequence of slices with a recurrent model (a long short-term memory network, LSTM). Their approach outperforms its 3D counterpart when trained from scratch. However, it relies on fixing an ordering of the slices, and the optimal ordering is unclear. Moreover, slices that occur early in the sequence may not be able to influence the predictions much.

Our approach alleviates the requirement of specifying an ordering over slices by using recently proposed set networks [10, 11]. Like [9], we employ a 2D encoder that encodes each slice. However, to combine the information in different slices, we treat the slices as a set and apply a permutation-invariant operation over this set. The output of a permutation-invariant operation does not change if the input elements are permuted, making the output independent of the order in which the slices are processed. We evaluate several permutation-invariant operations, namely mean, max, and a general weighted-average operation implemented via attention. We evaluated the proposed models on the BrainAGE prediction problem with the publicly available UK Biobank dataset [12] and show that our model trains faster and provides better predictions than the other competitive baselines (see Fig. 1).

Fig. 1. Smoothed training curves of MAE on the validation set over training epochs (original data shown with lower opacity in the same color). Our proposed architecture yields better predictive performance and trains faster than the other baseline architectures (see Sec. 4.1 for details).

Fig. 2. Model architecture with attention-based aggregation. Gray blocks are trainable parameters, whereas yellow blocks are operations only. Each scan is treated as a set of slices and transformed into a set of encodings via a 2D-CNN encoder. Attention scores are computed from these encodings and a trainable query vector. Finally, the aggregated embedding is passed through feed-forward layers to predict the age.
2. MODEL
Our model takes a 3D scan as input and encodes each slice using a 2D-CNN encoder. Next, it combines the slice encodings using an aggregation module (described in Sec. 2.2), resulting in a single embedding for the scan. Finally, we pass this embedding through feed-forward layers to predict the brain age. The model is trained end-to-end with an MSE loss. A high-level overview of our architecture is shown in Fig. 2.
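As a sketch of this pipeline (not the authors' implementation), the forward pass can be written as a shared per-slice encoder, a permutation-invariant aggregation, and a small prediction head. The linear "encoder", the dimensions, and all weights below are illustrative stand-ins for the 2D-CNN and feed-forward layers:

```python
import numpy as np

rng = np.random.default_rng(0)

def encode_slice(slice_2d, W):
    """Stand-in for the 2D-CNN encoder: one shared linear map + ReLU."""
    return np.maximum(W @ slice_2d.ravel(), 0.0)

def predict_age(scan_3d, W_enc, w_head, b_head):
    # 1) Encode every 2D slice with the *shared* encoder.
    encodings = np.stack([encode_slice(s, W_enc) for s in scan_3d])  # (p, d)
    # 2) Permutation-invariant aggregation (mean shown here).
    embedding = encodings.mean(axis=0)                               # (d,)
    # 3) Feed-forward head outputs a scalar age.
    return float(w_head @ embedding + b_head)

p, h, w, d = 8, 4, 4, 16          # slices, slice height/width, embedding dim
scan = rng.standard_normal((p, h, w))
W_enc = rng.standard_normal((d, h * w))
w_head = rng.standard_normal(d)

age = predict_age(scan, W_enc, w_head, 62.6)
# Mean aggregation ignores slice order, so shuffling slices changes nothing:
age_shuffled = predict_age(scan[rng.permutation(p)], W_enc, w_head, 62.6)
assert np.isclose(age, age_shuffled)
```

Because the encoder weights are shared across all slices, the parameter count is that of a single 2D model regardless of how many slices the scan has.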
The 2D-CNN encoder takes a single 2D slice as input and outputs a $d$-dimensional embedding. We use the same 2D encoder architecture as [9]; the only difference is that the number of filters in the last layer is $d$, which is determined by the output dimension of the aggregation module, described next.

Once we have an encoding for each slice, we need to combine information across this set of slices. To this end, we use permutation-invariant layers as the aggregation module; this makes the aggregation module's output independent of the slice order. The two most common permutation-invariant operations are mean and max over the set [11]; that is, we compute the element-wise mean and max of all the slice encodings, respectively. Further, the mean operation can be generalized to a weighted average of the encodings, where the weights are computed via attention [10]. The attention is implemented as follows. Let $q \in \mathbb{R}^{d'}$ be a trainable query vector and $r_i \in \mathbb{R}^{d}$ the encoding of the $i$-th slice. We first transform $r_i$ into key and value vectors, $k_i \in \mathbb{R}^{d'}$ and $v_i \in \mathbb{R}^{d''}$, via appropriate linear layers. Next, we compute an attention score for each encoding. If the number of slices is $p$ and $K \in \mathbb{R}^{d' \times p}$ is the matrix of all key vectors, the importance weights (attention) are computed as $w = \mathrm{softmax}\{q^T K / \sqrt{d'}\}$. Finally, we compute the weighted average of the value vectors as the embedding for the scan, $\sum_{i}^{p} w_i v_i$. Multiple attention heads can be used so that the model can focus on different slices for prediction; to obtain $m$ heads, we use $q \in \mathbb{R}^{d' \times m}$. When using the mean or max operation, we ignore the query and key vectors and compute $(\sum_{i}^{p} v_i)/p$ or $\max_{i}^{p} v_i$, respectively. For ease of reference, we name the models using the mean, max, and attention operations 2D-slice-mean, 2D-slice-max, and 2D-slice-attention, respectively.
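The attention aggregation above can be sketched in a few lines. This is a hedged illustration assuming a single head, with the two linear layers written as plain matrices `Wk` and `Wv` and small toy dimensions; it also checks the permutation-invariance property claimed in the text:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a vector of scores.
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_aggregate(R, Wk, Wv, q):
    """R: (p, d) slice encodings -> single (d'',) scan embedding."""
    K = R @ Wk.T                                  # keys,   (p, d')
    V = R @ Wv.T                                  # values, (p, d'')
    w = softmax(K @ q / np.sqrt(q.shape[0]))      # attention weights, (p,)
    return w @ V                                  # weighted average of values

rng = np.random.default_rng(1)
p, d, dk, dv = 6, 8, 5, 7                         # p slices; d, d', d'' dims
R = rng.standard_normal((p, d))
Wk = rng.standard_normal((dk, d))
Wv = rng.standard_normal((dv, d))
q = rng.standard_normal(dk)                       # trainable query vector

z = attention_aggregate(R, Wk, Wv, q)
# Permuting the slices permutes weights and values together,
# so the aggregated embedding is unchanged:
z_perm = attention_aggregate(R[rng.permutation(p)], Wk, Wv, q)
assert np.allclose(z, z_perm)
```

The mean and max variants drop `Wk` and `q` entirely and reduce to `V.mean(axis=0)` and `V.max(axis=0)` over the same value vectors.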
In our experiments, we varied $d = d' = d''$ and the number of heads $m$ over several values; however, we found that the results are not very sensitive to $d$ or $m$. Therefore, we fix $d = 32$ and $m = 1$ for all the models. We use a one-hidden-layer network with 64 activations as the feed-forward layer. We used slices along the sagittal axis; however, we find that the results do not change much if we slice along the coronal or axial direction, as discussed in Sec. 4.4.
3. EXPERIMENT SETUP

3.1. Dataset
We use the same dataset and setup as [9]. In particular, a subset of 10,446 subjects with no psychiatric diagnosis as defined by ICD-10 criteria was selected from 16,356 subjects in the UK Biobank dataset [12]. We used the same preprocessing, and hence the same final image dimensions, as [9]. The training, test, and validation set sizes were 7,312, 940, and 2,194, respectively, with a mean chronological age of 62.6 years and a standard deviation of 7.4 years.

Most conventional deep learning approaches for BrainAGE estimation use 3D-CNNs [4, 5]. They adapt conventional 2D-CNN architectures to work on 3D images by replacing 2D operations with 3D operations; for instance, 2D convolutions are replaced with 3D convolutions, 2D max-pooling with 3D max-pooling, and so on. We adapt the 2D encoder described in Fig. 2 and Sec. 2 to work with 3D images. Instead of an aggregation module, we pass the encodings through another 3D convolution to produce a single-node output. This architecture is the same as [4] but uses instance-norm instead of batch-norm after each convolutional layer, due to the instability of batch-norm with smaller batch sizes.
We also compare our approach to the recently proposed approach of [9]. Similar to our approach, they compute 2D encodings by taking slices along the sagittal axis; however, the sequence of encodings is aggregated by an LSTM. Their approach uses fewer parameters and has been shown to outperform the corresponding 3D-CNN architecture. We use the same feature embedding size (2) and LSTM hidden-state size (128), and apply gradient-norm clipping with value 1 during training, as in their paper.

Each model is trained for 100 epochs with the Adam optimizer (with weight decay) and a batch size of 8 under the MSE loss. The last layer's bias is initialized with the mean age of the training set (62.68 years). We pick the best model by monitoring performance on the validation set and report the mean absolute error (MAE) between the predicted and the true age on the test set. The code for all the experiments is publicly available at https://git.io/JtazG.

Table 1. MAE on the test set (lower is better).
Method                           MAE     Parameters
3D-CNN                           3.017   2,948,801
2D-slice-RNN                     3.002   1,070,403
2D-slice-attention (m=1, d=32)   2.855   —
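One detail worth making concrete is the last-layer bias initialization: setting the bias to the mean training age means the untrained network starts out predicting the cohort mean, which keeps the initial MSE small and stabilizes early training. A minimal sketch with toy ages (the zero output weights are only to make the effect visible, not how the real weights are initialized):

```python
import numpy as np

# Toy stand-in ages; the real training set has mean 62.68 years.
train_ages = np.array([55.0, 60.0, 62.0, 65.0, 71.0])

# Initialize the output layer so the untrained model predicts the mean age.
w_out = np.zeros(16)                 # zero weights, for illustration only
b_out = float(train_ages.mean())     # bias = mean training age

embedding = np.random.default_rng(2).standard_normal(16)
initial_pred = float(w_out @ embedding + b_out)
assert np.isclose(initial_pred, train_ages.mean())

# MSE loss against the true ages, as used for end-to-end training:
mse = float(np.mean((train_ages - initial_pred) ** 2))
```

With this initialization, the starting MSE equals the variance of the training ages, which is the best any constant predictor can do.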
4. RESULTS

4.1. Faster Training & Better Predictions
Table 1 summarizes the mean absolute error (MAE) and the number of parameters for all the methods. Our approach with the mean and attention operations outperforms all the other methods while also being parameter-efficient. It trains faster than 2D-slice-RNN and 3D-CNN, as shown in Fig. 1.

2D-slice-RNN suffers from having to process all the slices sequentially. Consequently, if an important slice is towards the beginning of the sequence, it may take many training steps to learn to propagate that information into the embedding. During the initial training phase of 2D-slice-RNN, the loss did not decrease (see Fig. 1), which supports this hypothesis. In contrast, with a permutation-invariant layer each slice encoding contributes directly to the embedding and therefore receives better gradient updates. We see that the max operation performs poorly. We attribute this to the following: 1) only the weights corresponding to the max neuron are updated in each iteration; 2) the max operation is susceptible to outliers, so a slight change in input can lead to a large change in output. This effect can be seen in Tables 2 and 3, where missing slices lead to less sensitivity to outliers; thus, the performance does not degrade, and even improves slightly, with the max operation. Attention generalizes the mean operation in theory; however, its performance is slightly worse than or the same as the mean operation.

Table 2. Test MAE when all but every k-th slice is dropped, for k ∈ {1, 2, 4, 5, 10}; k = 1 means no missing slices. * indicates evaluation with data imputation.

Table 3. Test MAE when slices are missing at random (averaged over 10 evaluation runs); columns indicate the percentage of slices available. * indicates evaluation with data imputation.

In practice, some clinical centers may use a sparser MRI acquisition (e.g., slices 5 mm apart), or scans may lack some slices due to artifacts or an incomplete field of view that fails to cover the entire brain.
It is also of interest whether a limited slice set is sufficient, allowing reduced file transfer or faster processing, and whether there is redundancy in the training data. We simulate these situations in two ways: 1) we remove all but every k-th slice from the scans (Table 2); 2) we keep a fixed percentage of slices chosen at random from the scans (Table 3). We do this for each scan in the test set and evaluate the models trained on complete data, i.e., without missing slices. As the 3D-CNN cannot be used without imputing the missing slices, we impute data by substituting each missing slice with the nearest available slice. Our method does not depend on the ordering of the slices; therefore, it does not require imputation and can handle missing slices gracefully. Our approach treats slices as a set rather than an ordered sequence, so it can tolerate missing elements in the set, and performance is only slightly worse than when all the data is present. Due to 2D-slice-RNN's dependence on the ordering, it performs better with data imputation.

Table 4. MAE on the test set when trained with n ∈ {1000, 2500, 5000} samples.

Axis →               Sagittal   Coronal   Axial
2D-slice-attention   2.855      2.948     3.102
2D-slice-RNN         3.002      3.266*    3.107
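The two missing-slice settings are straightforward to simulate. The sketch below (with an illustrative per-slice encoder, not the trained model) shows that a set-style model consumes whatever slices remain, with no imputation step:

```python
import numpy as np

rng = np.random.default_rng(3)

def keep_every_kth(scan, k):
    """Setting 1: keep only every k-th slice (k = 1 keeps everything)."""
    return scan[::k]

def keep_random_fraction(scan, frac):
    """Setting 2: keep a random subset of slices."""
    p = scan.shape[0]
    idx = rng.choice(p, size=max(1, int(round(frac * p))), replace=False)
    return scan[idx]

def set_model_embedding(scan):
    """A set model aggregates whatever slices are present -- no imputation."""
    encodings = scan.reshape(scan.shape[0], -1)   # stand-in per-slice encoder
    return encodings.mean(axis=0)                 # permutation-invariant mean

scan = rng.standard_normal((16, 4, 4))
full = set_model_embedding(scan)
sparse = set_model_embedding(keep_every_kth(scan, 4))      # 5-mm-like sampling
subset = set_model_embedding(keep_random_fraction(scan, 0.5))
assert full.shape == sparse.shape == subset.shape
```

An order-dependent model like an RNN would instead need each dropped slice replaced (e.g., by its nearest available neighbor) to keep the sequence length and positions consistent.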
Table 5. Effect of using slices along different axes; * indicates a model trained with a reduced learning rate.

Table 4 summarizes the results when fewer training samples are available. We use a subset of n samples from the training set and keep the number of updates the same as when training with all the data. The performance gap between our model and the baselines widens when fewer training samples are available, suggesting the proposed model's usefulness.

We summarize the results of using slices along the axial and coronal axes in Table 5 for the 2D-slice-attention and 2D-slice-RNN models. The performance was only slightly worse than with slices along the sagittal axis. 2D-slice-RNN was unable to learn when sliced along the coronal axis, and we found it necessary to reduce the learning rate for that configuration.
5. DISCUSSION
In this paper, we proposed a new 2D-slice-based architecture for BrainAGE estimation. By considering the slices as a set and using permutation-invariant layers instead of an LSTM (as in [9]), our model combines information across slices more efficiently. It converges faster and outperforms other deep learning architectures when trained from scratch. By avoiding any dependence on slice order, the proposed model is also tolerant to missing slices in the scans.

Other approaches have also employed 2D-slice-based CNN models for neuroimaging data. [7] used only the slices with the highest entropy to learn the model; however, such criteria may lead to poor outcomes, as a noisy slice can have high entropy but little information. [13] evaluated the possibility of using pretrained 2D residual networks for Alzheimer's disease diagnosis. [14] considers each slice as an independent sample for Alzheimer's disease diagnosis, which increases the number of samples available for training. 2D-CNNs were often chosen because pretrained networks are widely available for 2D (but not 3D) images. [8] used transfer learning with pretrained ImageNet models to predict brain age; they consider each slice an independent sample and output the median as the prediction. When trained from scratch, this procedure yields a much higher MAE (around 3.87) than the models we have discussed. 2D-slice-based approaches can be more efficient to train, as they share parameters across the slices, leading to fewer parameters in the model. Most of these approaches either use only a few slices selected via preprocessing or treat each slice as an independent sample, combining the results via ensembling. Thus, these models cannot be trained in an end-to-end fashion. Our approach combines information across all the slices using permutation-invariant operations, enabling end-to-end training and learning to ignore any slices that are not beneficial for the task. It is also possible to leverage transfer learning with our model.
For instance, one may use a pretrained 2D encoder and train the rest of the model from scratch. An extensive comparison with transfer learning and other classical approaches is left as future work.

In Sec. 4.3, we found that the performance gaps widen when fewer samples are available. This gap can be attributed to encoding slices with a parameter-efficient 2D-CNN rather than a 3D-CNN. Even though our model encodes slices with a 2D-CNN, it is a 3D architecture when viewed end-to-end; thus, it may provide the same expressiveness with fewer parameters. We believe that other neuroimaging prediction tasks may also benefit from this architecture.

Our proposed architecture provides improved brain age prediction for healthy subjects. To further validate the outputs as a biomarker of brain aging or neurological diagnosis [15, 16], we plan to further evaluate the model on (1) out-of-distribution samples, including people with neurodegenerative diseases, and (2) data from different scanners. We also plan to test whether the brain-age delta produced by our model is associated with health-related outcomes and future decline.

As this work's focus is proposing a new architecture, we considered a simplified scenario, testing the methods on one dataset without considering scanner effects, site, and other biases. Some recent works tackle these problems by proposing novel training objectives. For instance, [17] uses adversarial learning to learn a model for brain age prediction, focusing on generalization across three cohorts with different scanning protocols and age distributions. [18] proposed an unsupervised method to adjust for site effects. [19] used attention-based models for domain adaptation, which identify the most important brain regions to focus on. In contrast, our permutation-invariant attention layer, inspired by [10], works by identifying the most important slices. The proposed architecture is compatible with these objectives.
Future work should test how well the model generalizes to datasets with differences in scanning protocols and populations.

6. ACKNOWLEDGMENTS
This research was supported in part by DARPA contractHR0011-2090104, and NIH grants U01AG068057 andRF1AG051710.
7. COMPLIANCE WITH ETHICAL STANDARDS
This is a study of previously collected, anonymized, de-identified data available in a public repository. Data access was approved under UK Biobank Application Number 11559.
8. REFERENCES

[1] Katja Franke and Christian Gaser, "Ten Years of BrainAGE as a Neuroimaging Biomarker of Brain Aging: What Insights Have We Gained?," Frontiers in Neurology, vol. 10, pp. 789, 2019.
[2] Jens Kleesiek, Gregor Urban, Alexander Hubert, Daniel Schwarz, Klaus Maier-Hein, Martin Bendszus, and Armin Biller, "Deep MRI brain extraction: A 3D convolutional neural network for skull stripping," NeuroImage, vol. 129, pp. 460–469, 2016.
[3] Satya P Singh, Lipo Wang, Sukrit Gupta, Haveesh Goli, Parasuraman Padmanabhan, and Balázs Gulyás, "3D deep learning on medical images: a review," Sensors, vol. 20, no. 18, pp. 5097, 2020.
[4] Han Peng, Weikang Gong, Christian F. Beckmann, Andrea Vedaldi, and Stephen M. Smith, "Accurate brain age prediction with lightweight deep neural networks," Medical Image Analysis, vol. 68, pp. 101871, 2021.
[5] James H Cole, Rudra P K Poudel, Dimosthenis Tsagkrasoulis, Matthan W A Caan, Claire Steves, Tim D Spector, and Giovanni Montana, "Predicting brain age with deep learning from raw imaging data results in a reliable and heritable biomarker," NeuroImage, vol. 163, pp. 115–124, 2017.
[6] Nicola K Dinsdale, Emma Bluemke, Stephen M Smith, Zobair Arya, Diego Vidaurre, Mark Jenkinson, and Ana I L Namburete, "Learning patterns of the ageing brain in MRI using deep convolutional networks," NeuroImage, vol. 224, pp. 117401, 2021.
[7] M. Hon and N. M. Khan, "Towards Alzheimer's disease classification through transfer learning," 2017, pp. 1166–1169.
[8] Vishnu Bashyam et al., "MRI signatures of brain age and disease over the lifespan based on a deep brain network and 14468 individuals worldwide," Brain, vol. 143, no. 7, pp. 2312–2324, 2020.
[9] Pradeep K Lam, Vigneshwaran Santhalingam, Parth Suresh, Rahul Baboota, Alyssa H Zhu, Sophia I Thomopoulos, Neda Jahanshad, and Paul M Thompson, "Accurate brain age prediction using recurrent slice-based networks," bioRxiv, 2020.
[10] Juho Lee, Yoonho Lee, Jungtaek Kim, Adam Kosiorek, Seungjin Choi, and Yee Whye Teh, "Set transformer: A framework for attention-based permutation-invariant neural networks," in International Conference on Machine Learning. PMLR, 2019, pp. 3744–3753.
[11] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabás Póczos, Ruslan Salakhutdinov, and Alexander J Smola, "Deep Sets," in Advances in Neural Information Processing Systems, 2017, pp. 3394–3404.
[12] Karla L Miller et al., "Multimodal population brain imaging in the UK Biobank prospective epidemiological study," Nature Neuroscience, vol. 19, no. 11, pp. 1523–1536, 2016.
[13] Aly Valliani and Ameet Soni, "Deep residual nets for improved Alzheimer's diagnosis," in Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics, 2017, pp. 615–615.
[14] Jyoti Islam and Yanqing Zhang, "Brain MRI analysis for Alzheimer's disease diagnosis using an ensemble system of deep convolutional neural networks," Brain Informatics, vol. 5, no. 2, pp. 2, 2018.
[15] Ellyn R Butler et al., "Statistical Pitfalls in Brain Age Analyses," bioRxiv, 2020.
[16] Stephen M Smith, Diego Vidaurre, Fidel Alfaro-Almagro, Thomas E Nichols, and Karla L Miller, "Estimation of brain age delta from brain imaging," NeuroImage, vol. 200, pp. 528–539, 2019.
[17] Nicola K Dinsdale, Mark Jenkinson, and Ana I L Namburete, "Unlearning Scanner Bias for MRI Harmonisation," in Medical Image Computing and Computer Assisted Intervention – MICCAI 2020. Springer International Publishing, 2020, pp. 369–378.
[18] Daniel Moyer, Greg Ver Steeg, Chantal MW Tax, and Paul M Thompson, "Scanner invariant representations for diffusion MRI harmonization," Magnetic Resonance in Medicine, 2020.
[19] Hao Guan, Erkun Yang, Pew-Thian Yap, Dinggang Shen, and Mingxia Liu, "Attention-Guided Deep Domain Adaptation for Brain Dementia Identification with Multi-site Neuroimaging Data," in