Deep learning based prediction of Alzheimer's disease from magnetic resonance images
Manu Subramoniam, Aparna T. R., Anurenjan P. R., Sreeni K. G
Computer Vision Lab, Dept. of ECE, College of Engineering, Thiruvananthapuram
Abstract.
Alzheimer’s disease (AD) is an irreversible, progressive neurodegenerative disorder that slowly destroys memory and thinking skills and, eventually, the ability to carry out the simplest tasks. In this paper, a deep neural network based prediction of AD from magnetic resonance images (MRI) is proposed. State-of-the-art image classification networks such as VGG and residual networks (ResNet) show promising results with transfer learning: the performance of pre-trained versions of these networks is improved by retraining them on the target data. A ResNet-based architecture with a large number of layers (ResNet-101) is found to give the best results in predicting the different stages of the disease. The experiments are conducted on a Kaggle dataset.
Keywords:
CNN, ResNet, MRI, Alzheimer’s disease, Transfer learning
Alzheimer’s disease (AD) is a progressive neurodegenerative disease that affects nearly 50 million people worldwide [1]. The disease causes irreversible damage to the brain that affects cognition, memory and other functions, and it eventually leads to the death of the individual from complete brain failure. The economic burden of the disease is huge and well studied [2]. Since the disease is incurable, early diagnosis and medications that delay its progression are the only treatments available [3]. Genetic biomarkers such as amyloid-β precursor protein (AβPP) can be measured in blood tests [4] and may be used for diagnosing AD. AD also produces senile plaques and neurofibrillary tangles throughout the brain, which are considered a definitive biomarker. These plaques and tangles tend to shrink the brain volume. This shrinkage is evident in MR images and is used as a criterion for clinical diagnosis [5].
The rest of the paper is organized as follows. Related prior work is discussed in Section 2. The details of the proposed work, including the dataset used and the various architectures experimented with, are given in Section 3. The experiments and results are discussed in Section 4, followed by a summary of the work in Section 5.
Alzheimer’s disease (AD) is an irreversible, progressive brain disorder with no existing treatment that cures it. Hence, a great deal of effort has been made to develop strategies for early detection, especially at pre-symptomatic stages of the disease. In particular, advanced neuroimaging techniques such as magnetic resonance imaging (MRI) have been used to identify AD-related diseases. To predict AD in subjects with mild cognitive impairment (MCI), Simon F. Eskildsen et al. [7] investigated the possibility of using patterns of cortical thickness measurements: specific patterns of atrophy were identified, and features were selected as regions of interest from these patterns. In Claudia Plant et al. [8], a data mining framework in combination with three different classifiers, namely a support vector machine (SVM), Bayes statistics and voting feature intervals (VFI), was used to derive a quantitative index of pattern matching for the prediction. In that study, the multivariate pattern-matching methods reached a clinically relevant accuracy for the a priori prediction of progression from MCI to AD. To jointly predict multiple variables from multi-modal data, Daoqiang Zhang et al. [9] studied a multi-modal multi-task (M3T) learning method, in which multi-task feature selection picks a common subset of relevant features for the multiple variables from each modality, and a multi-modal support vector machine fuses these features to predict multiple (regression and classification) variables. In D. P. Devanand et al. [10], MRI surface morphometry mapping is used to evaluate local deformations of the hippocampus, parahippocampal gyrus and entorhinal cortex to predict conversion from MCI to AD. Amongst traditional machine learning methods, SVM is the most popular; as noted in Rathore et al. [11], SVM-based methods extract high-dimensional, informative features to build classification models that facilitate the automation of clinical diagnosis.
However, feature extraction and definition rely on manual outlining of brain structures, which is laborious, and on complex image pre-processing, which is computationally demanding and time-consuming. To overcome these difficulties, deep learning, an emerging area of machine learning research that uses raw neuroimaging data to generate features, is attracting considerable attention in the field of large-scale, high-dimensional medical imaging analysis, as discussed in Plis et al. [12]. Deep learning requires little or no image pre-processing and can automatically infer an optimal representation of the data from the raw images without prior feature selection, resulting in a more objective and less bias-prone process. Recently, deep learning has been successfully applied to the Alzheimer’s Disease Neuroimaging Initiative (ADNI) [13] dataset to identify AD patients, as reported in Vieira et al. [14].
In Suk et al. [15], deep learning algorithms without a priori feature selection (using gray matter [GM] volumes as input) are used to predict AD development from ADNI structural MRI scans. Silvia Basaia et al. [16] and Weiming Lin et al. [17] use convolutional neural networks (CNNs) on 3D T1-weighted images from the ADNI dataset. To avoid complicated activations, response normalization and max-pooling, Basaia et al. [16] used standard convolutional layers with a stride of 2 (an ‘all convolutional network’) instead of max-pooling layers, which reduces the number of network parameters. In Lin et al. [17], the MRI images are age-corrected and local patches are extracted from them; an extreme learning machine variant that avoids random generation of the input weight matrix is then used for classification, with both CNN-based features and FreeSurfer [18] based features. Karteek Popuri et al. [19] developed a method to quantify the structural patterns in a structural MRI and derive a score of similarity to patterns seen in dementia of Alzheimer’s type (DAT) images, employing an ensemble-learning framework to create an aggregate measure of neurodegeneration in the brain. V. P. Subramanyam Rallabandi et al. [20] used FreeSurfer analysis to measure the regional cortical thickness of both the left and right hemispheres, and a non-linear support vector machine with a radial basis function kernel was used to classify the different stages of dementia. Manhua Liu et al. [21] proposed a multi-modal deep learning framework based on CNNs for joint automatic hippocampal segmentation and AD classification: a multi-task deep CNN model is constructed to jointly learn hippocampal segmentation and classification, and a 3D densely connected convolutional network (3D DenseNet) is constructed to learn features of the image patches. These studies indicate that deep learning algorithms are well suited to detecting subtle abnormalities.
In this paper, axial slices of MRI images are used as the input data for the classification task. As shown in Fig. 1, the MRI slices are fed to a neural network that performs feature extraction and classification. The classifier is a CNN model that assigns each image to one of four classes: Non-Demented, Very Mild Demented, Mild Demented and Moderately Demented. A pre-trained ResNet-101 model is used for the classifier CNN block; the details of the ResNet-101 architecture are discussed in Section 3.1. Among the architectures evaluated, this model gives the most accurate classifications for this application.
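The final labelling step can be pictured with a minimal sketch in plain Python, assuming the classifier CNN outputs one raw score (logit) per class and that the class order is the one listed above. The helper names and the logits in the usage line are hypothetical, not outputs of the trained model.

```python
import math

# Class names as listed in the text; this index order is an assumption.
CLASSES = ["Non-Demented", "Very Mild Demented", "Mild Demented", "Moderately Demented"]

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_label(logits):
    """Return (class name, probability) for the highest-probability class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASSES[best], probs[best]

# Hypothetical logits for one MRI slice (not taken from the paper).
label, p = predict_label([0.2, 2.5, 0.1, -1.0])
print(label, round(p, 3))
```

Because softmax is monotonic, the predicted class is simply the one with the largest logit; the probability is only needed when a confidence score is reported.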
In this architecture, we use a residual neural network (ResNet-101) for classification. A ResNet is built by stacking residual blocks, where each residual block consists of three layers: 1 × 1, 3 × 3 and 1 × 1 convolutions. This is referred to as the bottleneck building block [6]. The 1 × 1 layers are responsible for reducing and then restoring the number of channels, leaving the 3 × 3 layer as a bottleneck with smaller input and output dimensions.
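As a rough illustration of why the bottleneck design is used, the following sketch counts convolution weights (ignoring biases, batch normalization and projection shortcuts) using the standard ResNet-101 stage widths of He et al. It also checks that one 7 × 7 convolution, 33 bottleneck blocks of three layers each, and one fully connected layer give the 101 weighted layers the name refers to. The function names are hypothetical.

```python
# Standard ResNet-101 stages: (bottleneck width, output channels, blocks stacked).
RESNET101_STAGES = {
    "conv2_x": (64, 256, 3),
    "conv3_x": (128, 512, 4),
    "conv4_x": (256, 1024, 23),
    "conv5_x": (512, 2048, 3),
}

def bottleneck_weights(in_ch, width, out_ch):
    """Conv weights (biases ignored) of one 1x1 -> 3x3 -> 1x1 bottleneck block."""
    reduce = 1 * 1 * in_ch * width     # 1x1 conv: reduce channels
    spatial = 3 * 3 * width * width    # 3x3 conv at the reduced width
    restore = 1 * 1 * width * out_ch   # 1x1 conv: restore channels
    return reduce + spatial + restore

def weighted_layer_count(stages):
    """conv1 + three conv layers per bottleneck block + final fc layer."""
    return 1 + sum(3 * n for _, _, n in stages.values()) + 1

print(weighted_layer_count(RESNET101_STAGES))  # 101
# One bottleneck block on a 256-channel input ...
print(bottleneck_weights(256, 64, 256))        # 69632
# ... versus two plain 3x3 convolutions kept at 256 channels.
print(2 * 3 * 3 * 256 * 256)                   # 1179648
```

The 1 × 1 reduction lets the expensive 3 × 3 convolution run at a quarter of the channel count, which is what keeps a 101-layer network tractable.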
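Fig. 1. Block diagram illustration of the proposed method of Alzheimer’s disease classification from brain MRI image slices.
layer name | output size | 101-layer
conv1 | 112 × 112 | 7 × 7, 64, stride 2
conv2_x | 56 × 56 | 3 × 3 max pool, stride 2; [1 × 1, 64; 3 × 3, 64; 1 × 1, 256] × 3
conv3_x | 28 × 28 | [1 × 1, 128; 3 × 3, 128; 1 × 1, 512] × 4
conv4_x | 14 × 14 | [1 × 1, 256; 3 × 3, 256; 1 × 1, 1024] × 23
conv5_x | 7 × 7 | [1 × 1, 512; 3 × 3, 512; 1 × 1, 2048] × 3
Table 1.
The details of the ResNet-101 architecture used in the proposed method, showing the building blocks (in brackets) along with the number of blocks stacked.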
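The output-size column of Table 1 follows from the usual formula ⌊(n + 2p − k)/s⌋ + 1 for a layer with kernel size k, stride s and padding p. The sketch below assumes a 224 × 224 input (consistent with the 112 × 112 output of conv1) and the paddings commonly used in ResNet implementations, with each later stage downsampling through a stride-2 3 × 3 convolution; these are assumptions, since the paper does not state them.

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet101_output_sizes(input_size=224):
    """Trace the feature-map side length through the stages of Table 1."""
    sizes = {}
    s = conv_output_size(input_size, kernel=7, stride=2, padding=3)  # conv1
    sizes["conv1"] = s
    s = conv_output_size(s, kernel=3, stride=2, padding=1)  # 3x3 max pool
    sizes["conv2_x"] = s  # the conv2_x blocks themselves use stride 1
    for stage in ("conv3_x", "conv4_x", "conv5_x"):
        # Assumed: the first block of each stage downsamples with a stride-2 3x3 conv.
        s = conv_output_size(s, kernel=3, stride=2, padding=1)
        sizes[stage] = s
    return sizes

print(resnet101_output_sizes())
```

The same formula gives the feature-map sizes of the smaller CNNs trained from scratch, once their paddings are fixed.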
DNN architecture, with two convolutional layers of 3 × 3 filters with a stride of 1 and one softmax layer of 2 × 2 filter with a stride of 2; its performance is shown in Table 2.
The data consists of 6400 magnetic resonance images collected and released as part of a Kaggle competition [22]. The MR images are categorized as non-demented, very mildly demented, mildly demented and moderately demented based on the level of neurological degeneration. The complete dataset is divided into non-overlapping training and test sets: the training set has 5121 images and the test set consists of 1279 images. Some classes, such as moderately demented, are under-represented; to compensate for this class imbalance, data augmentation is used during training.
Fig. 2.
Bottleneck building blocks of different layers of the ResNet-101 architecture. A. conv2_x layer B. conv3_x layer C. conv4_x layer D. conv5_x layer.
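The paper does not specify which augmentations were used to balance the classes. A minimal sketch of one common approach, oversampling each minority class with augmented copies until it matches the largest class, is shown below; the horizontal flip, the helper names and the toy class sizes are all hypothetical and do not reflect the Kaggle class counts.

```python
import random

def hflip(image):
    """Mirror a 2-D image, stored as a list of rows, left to right."""
    return [row[::-1] for row in image]

def balance_by_augmentation(samples_by_class, augment=hflip, seed=0):
    """Oversample every class up to the size of the largest one.

    The extra samples are augmented copies (here: horizontal flips)
    of randomly chosen originals from the same class.
    """
    rng = random.Random(seed)
    target = max(len(samples) for samples in samples_by_class.values())
    balanced = {}
    for label, samples in samples_by_class.items():
        extra = [augment(rng.choice(samples)) for _ in range(target - len(samples))]
        balanced[label] = list(samples) + extra
    return balanced

# Toy 2x2 "images" with hypothetical class sizes.
toy = {
    "Non-Demented": [[[0, 1], [2, 3]]] * 6,
    "Moderately Demented": [[[4, 5], [6, 7]]],
}
balanced = balance_by_augmentation(toy)
print({label: len(s) for label, s in balanced.items()})
```

In practice a richer, randomized augmentation pipeline would be applied on the fly during training rather than materializing copies up front.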
For the experimental setup, Google Colab is used, with a GPU hardware accelerator during training. The Kaggle AD dataset provides 5121 samples for training and 1279 samples for testing, spread across four classes: Non-Demented, Very Mild Demented, Mild Demented and Moderately Demented. The networks are implemented with fastai [23], and data augmentation is used to balance the dataset.
Sl. No. | Architecture | Accuracy (%)
1 | Vanilla DNN | 95.31
2 | CNN-DNN | 95.32
Table 2.
Performance comparison of different neural networks initialized with random weights.
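Class | Model | Accuracy (%) | Precision | Recall | F1-score
Mild Demented | VGG-16 | 98.63 | 0.93 | 0.98 | 0.95
Mild Demented | VGG-19 | 94.04 | 0.72 | 0.87 | 0.79
Mild Demented | ResNet-18 | 98.54 | 0.98 | 0.93 | 0.95
Mild Demented | ResNet-34 | 94.63 | 0.75 | 0.89 | 0.81
Mild Demented | ResNet-50 | 96.29 | 0.81 | 0.94 | 0.87
Mild Demented | ResNet-101 | 99.51 | 0.99 | 0.98 | 0.98
Moderate Demented | VGG-16 | 100 | 1.0 | 1.0 | 1.0
Moderate Demented | VGG-19 | 100 | 1.0 | 1.0 | 1.0
Moderate Demented | ResNet-18 | 100 | 1.0 | 1.0 | 1.0
Moderate Demented | ResNet-34 | 99.80 | 1.0 | 0.80 | 0.89
Moderate Demented | ResNet-50 | 100 | 1.0 | 1.0 | 1.0
Moderate Demented | ResNet-101 | 100 | 1.0 | 1.0 | 1.0
Non Demented | VGG-16 | 96.78 | 0.97 | 0.96 | 0.97
Non Demented | VGG-19 | 87.60 | 0.90 | 0.85 | 0.87
Non Demented | ResNet-18 | 97.85 | 0.96 | 0.99 | 0.98
Non Demented | ResNet-34 | 89.16 | 0.89 | 0.89 | 0.89
Non Demented | ResNet-50 | 92.29 | 0.93 | 0.91 | 0.92
Non Demented | ResNet-101 | 99.61 | 0.99 | 1.0 | 1.0
Very Mild Demented | VGG-16 | 96.39 | 0.95 | 0.95 | 0.95
Very Mild Demented | VGG-19 | 85.35 | 0.80 | 0.79 | 0.80
Very Mild Demented | ResNet-18 | 97.75 | 0.98 | 0.96 | 0.97
Very Mild Demented | ResNet-34 | 86.91 | 0.85 | 0.80 | 0.82
Very Mild Demented | ResNet-50 | 91.89 | 0.90 | 0.88 | 0.89
Very Mild Demented | ResNet-101 | 99.71 | 1.0 | 0.99 | 1.0
Table 3.
Performance comparison of the proposed classification method (based on ResNet-101) with other existing algorithms, in terms of accuracy, precision, recall and F1-score.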
Table 2 shows the accuracy of two networks initialized with random weights: a simple vanilla DNN and a CNN-DNN. The vanilla DNN consists of three hidden layers with ReLU activations. The CNN-DNN has three initial CNN layers followed by three DNN layers. The accuracy improves only slightly, from 95.31% to 95.32%, when moving from the simple DNN to the CNN-DNN network.
Inspired by the fact that a simple CNN-DNN model worked well on the test data, we tried the classic VGG-16 [24] network. A pre-trained VGG-16 network is taken, its last layer is replaced with three DNN layers, and the network is retrained on the Kaggle data for a few epochs. The results listed in Table 3 show that VGG-16 with transfer learning helps to improve the prediction. We also applied this transfer learning approach to other classic networks: VGG-19, ResNet-18 [25], ResNet-34, ResNet-50 and ResNet-101. Among the VGG architectures, VGG-16 performs better than VGG-19. Among the residual network architectures, ResNet-101 gives the highest (or joint-highest) accuracy for every class, although the trend is not monotonic in depth: for example, ResNet-18 is more accurate than ResNet-34 and ResNet-50 on the Non Demented and Very Mild Demented classes. For all methods, the number of epochs is limited to 75.
The detailed experimental results are shown in Table 3. Accuracy, precision, recall and F1-score are used for performance evaluation. Accuracy is the fraction of all samples that are correctly classified by the classifier:
$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
where
TP is the number of predictions in which the classifier correctly predicts the positive class as positive,
TN is the number of predictions in which the classifier correctly predicts the negative class as negative,
FP is the number of predictions in which the classifier incorrectly predicts the negative class as positive, and
FN is the number of predictions in which the classifier incorrectly predicts the positive class as negative.
Precision is the fraction of positive-class predictions that are actually positive:
$$\text{Precision} = \frac{TP}{TP + FP}$$
Recall is the fraction of all positive samples that are correctly predicted as positive by the classifier:
$$\text{Recall} = \frac{TP}{TP + FN}$$
The F1-score is the harmonic mean of precision and recall:
$$F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}$$
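Since Table 3 reports these metrics per class, each class is presumably treated one-vs-rest: for class k, TP, FP, FN and TN are read off a multi-class confusion matrix. The sketch below assumes that interpretation; the function name and the confusion matrix are made up for illustration and are not from the paper's experiments.

```python
def per_class_metrics(confusion, k):
    """One-vs-rest accuracy, precision, recall and F1 for class index k.

    confusion[i][j] counts samples whose true class is i and predicted class is j.
    """
    total = sum(sum(row) for row in confusion)
    tp = confusion[k][k]
    fp = sum(confusion[i][k] for i in range(len(confusion))) - tp  # predicted k, truly not k
    fn = sum(confusion[k]) - tp                                    # truly k, predicted not k
    tn = total - tp - fp - fn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    return accuracy, precision, recall, f1

# Hypothetical 4-class confusion matrix (rows: true class, columns: predicted class).
cm = [
    [50, 2, 0, 0],
    [3, 40, 1, 0],
    [0, 2, 20, 0],
    [0, 0, 0, 5],
]
for k in range(4):
    print(k, [round(m, 3) for m in per_class_metrics(cm, k)])
```

Note that one-vs-rest accuracy includes the large TN count of the other classes, which is why per-class accuracy can be high even when precision or recall for a small class is noticeably lower.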
This paper proposes a deep neural network based classification of MRI data. We use pre-trained networks such as VGGNet and ResNet and retrain them on the in-domain data. Experiments on the Kaggle data show promising results, with the ResNet-101 architecture matching or outperforming all other architectures on every evaluation metric. Future work should focus on experiments with clinically validated datasets such as ADNI and OASIS. The possibility of exploiting multi-modal cues such as PET scans, blood test evaluations and MMSE scores could also be pursued.