Cox-nnet v2.0: improved neural-network based survival prediction extended to large-scale EMR dataset
Di Wang, Kevin He, Lana X Garmire*
Department of Biostatistics, University of Michigan, Ann Arbor, MI
Department of Computational Medicine and Bioinformatics, University of Michigan, Ann Arbor, MI
* Corresponding author: [email protected]
Abstract
Summary:
Cox-nnet is a neural-network based prognosis prediction method, originally applied to genomics data. Here we propose version 2 of Cox-nnet, with significant improvements in efficiency and interpretability that make it suitable for predicting prognosis from large-scale electronic medical record (EMR) datasets. We also add permutation-based feature importance scores and the direction of feature coefficients. Applied to an EMR dataset of OPTN kidney transplantation, Cox-nnet v2.0 reduces the training time of Cox-nnet by up to 32-fold (n=10,000) and achieves better prediction accuracy than Cox-PH (p<0.05).
Availability and implementation:
Cox-nnet v2.0 is freely available to the public at https://github.com/lanagarmire/Cox-nnet-v2.0
Introduction
Large-scale electronic medical records (EMR) are informative and easily accessible data sources frequently used for patient survival prediction. Prediction models built on EMR data tend to perform better than those using administrative data (Mahmoudi et al., 2020). Machine learning based models have also been found to outperform conventional models, such as the Cox Proportional Hazards (Cox-PH) model (Cox, 1972), the Random Survival Forests (RSF) model (Ishwaran et al., 2008) and elastic net regression (Fan et al., 2010), in predicting coronary artery disease mortality from EMR data (Steele et al., 2018). Although it is challenging to develop prediction models driven by EMR data, the large sample size and rich clinical features in EMR data provide valuable information for survival prediction (Goldstein et al., 2017). We previously proposed Cox-nnet (Ching et al., 2018), a deep learning based neural network prognosis prediction model, which achieved comparable or better performance than Cox-PH on high-throughput omics data. We recently applied Cox-nnet to histopathology imaging data with pre-extracted features, and demonstrated its advantage in combining gene expression data and image data for survival prediction (Zhan et al.). However, it remains to be tested whether Cox-nnet is suitable for predicting survival in large-scale EMR data, where the number of patients is usually orders of magnitude larger than in genomics data. Toward this goal, we propose Cox-nnet v2.0, which significantly improves computational speed with enhanced interpretability. Additionally, Cox-nnet v2.0 also achieves better prediction accuracy than Cox-PH.
Methods
Cox-nnet method improvement: The original Cox-nnet is a neural-network based extension of the Cox-PH method, using the log partial likelihood as its loss function. In Cox-nnet v2.0, we have made the following improvements:

(1) Speed-up in calculating the log partial likelihood loss function. The log partial likelihood is calculated by

$$pl(\beta) = \sum_{i} C_i \left( \theta_i - \log \sum_{t_j \geq t_i} \exp(\theta_j) \right)$$

where $\theta_i$ is the linear predictor of patient $i$ and $C_i = I(\text{patient } i \text{ is not censored})$. To avoid nested summation in Theano, the previous version of Cox-nnet calculates the log partial likelihood by matrix multiplication:

$$pl(\beta) = \{\theta - \log(R \cdot \exp(\theta))\}^{T} C$$

where $C$ and $\theta$ are the vectors of $C_i$ and $\theta_i$, respectively, and $R$ is an $n \times n$ at-risk-set indicator matrix with entries

$$R_{ij} = I(t_j \geq t_i)$$

where $n$ is the sample size of the input data, and $t_i$ and $t_j$ are the event times of patients $i$ and $j$, respectively. This implementation is memory intensive and time consuming for large sample sizes. In the new version, instead of pairwise comparison we sort the observations by event time. By the definition of the at-risk set, $R$ then becomes an upper triangular matrix filled with 1s, so $R \cdot \exp(\theta)$ can be computed as a cumulative summation that no longer requires storing the $R$ matrix or nested summation (double loops).

(2) Adding permutation-based feature importance scores. Previously, the variable importance score of Cox-nnet was calculated by pseudo drop-out, which replaces a variable with its mean; the drawback is that this is hard to interpret for categorical variables. Here we introduce a more general feature evaluation method using the permutation importance score (Breiman, 2001).
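The cumulative-summation speed-up can be illustrated with a minimal NumPy sketch (the actual Cox-nnet v2.0 implementation is in Theano; the function name here is illustrative only):

```python
import numpy as np

def log_partial_likelihood(theta, time, event):
    """Log partial likelihood via the cumulative-sum trick.

    theta: linear predictors; time: event/censoring times;
    event: censoring indicators (1 = event observed, 0 = censored).
    """
    order = np.argsort(time)               # sort by event time (ascending)
    theta, event = theta[order], event[order]
    # After sorting, the risk set of patient i is {j : j >= i}, so the sum
    # of exp(theta_j) over the risk set is a reversed cumulative sum; a
    # running max keeps the exponentials numerically stable.
    m = theta.max()
    risk_sum = np.cumsum(np.exp(theta - m)[::-1])[::-1]
    log_risk = np.log(risk_sum) + m
    return np.sum(event * (theta - log_risk))
```

This computes the same quantity as the O(n^2) matrix form but needs only O(n) memory, which is what makes the method practical on EMR-scale cohorts.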
The main idea is to measure the increase in model error after shuffling a feature's values, since the permutation breaks the relationship between the feature and the outcome. We implement the algorithm proposed in Fisher et al. (2019).

(3) Adding the directionality of the feature coefficients. Analogous to estimating the sign of $\beta$ in Cox-PH, we develop a framework that approximates the direction of feature coefficients in Cox-nnet. The linear predictor in Cox-nnet is

$$\theta_i = G(W x_i + b)\beta$$

where $G$ is the activation function, $W$ is the coefficient weight matrix between the input and hidden layers, $b$ is the bias term, and $\beta$ is the output coefficient vector. We define a reference value for each feature $j$ by

$$x_j^{ref} = \bar{x}_j \, I(x_j \text{ is a continuous variable}) + \text{mode}(x_j) \, I(x_j \text{ is a categorical variable})$$

Similar to the interpretation of $\beta$ in Cox-PH, the direction of each feature coefficient in Cox-nnet is approximated by the sign of

$$\frac{1}{n}\sum_{i} \Delta\theta_{ij} = \frac{1}{n}\sum_{i} \left(\theta_i - \theta_i^{ref,j}\right) = \frac{1}{n}\sum_{i} \left\{ G(W x_i + b)\beta - G(W x_i^{ref,j} + b)\beta \right\}$$

where $x_i^{ref,j} = (x_j^{ref}, x_i^{(-j)})$ replaces the $j$-th feature of patient $i$ with its reference value while keeping the remaining features $x_i^{(-j)}$ unchanged. Intuitively, the risk increases with the feature if the sign of $\frac{1}{n}\sum_{i} \Delta\theta_{ij}$ is positive.

(4) Adding additional optimization algorithms and activation functions for parameter tuning. We add the Adam optimizer (Kingma and Ba, 2014) as an alternative optimization strategy, which further accelerates the training process. We also add the Scaled Exponential Linear Unit (SELU) activation function (Klambauer et al., 2017) to Cox-nnet v2.0.
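The permutation importance idea is model-agnostic and can be sketched generically; note that `model_error` below is a hypothetical placeholder for any fitted model's error function (e.g. 1 - C-index for a survival model), not the Cox-nnet API:

```python
import numpy as np

def permutation_importance(model_error, X, y, n_repeats=5, seed=0):
    """Permutation feature importance (Breiman, 2001; Fisher et al., 2019).

    The importance of feature j is the mean increase in error after
    shuffling column j, which breaks the feature-outcome relationship
    while leaving the feature's marginal distribution intact.
    """
    rng = np.random.default_rng(seed)
    baseline = model_error(X, y)
    importance = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        for _ in range(n_repeats):
            X_perm = X.copy()
            rng.shuffle(X_perm[:, j])      # permute feature j only
            importance[j] += model_error(X_perm, y) - baseline
    return importance / n_repeats
```

Unlike pseudo drop-out, this treats continuous and categorical features identically, since shuffling never produces values outside a categorical feature's observed levels.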
Evaluation Metrics: As in Cox-nnet v1.0, we evaluate prediction accuracy by C-IPCW (Uno et al., 2011), the C-index weighted by the inverse censoring probability.

Dataset: The EMR data used for this study are kidney transplant data obtained from the U.S. Organ Procurement and Transplantation Network (OPTN) (https://optn.transplant.hrsa.gov/data/). A total of 80,019 patients, comprising all patients older than 18 who received a transplant from a deceased donor between January 2005 and January 2013, were used in the analysis, along with 117 clinical variables describing characteristics up to transplantation.
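For intuition, a simplified sketch of the IPCW-weighted C-index follows; it ignores tied times and the truncation time tau of the full Uno et al. (2011) estimator, so it is an illustration rather than the evaluation code used in the paper:

```python
import numpy as np

def ipcw_cindex(theta, time, event):
    """Toy IPCW-weighted C-index (simplified from Uno et al., 2011).

    theta: risk scores (higher = higher risk); time: observed times;
    event: 1 if the event was observed, 0 if censored.
    """
    n = len(time)
    # Kaplan-Meier estimate of the censoring survival function G(t),
    # treating censored observations (event == 0) as the "events".
    order = np.argsort(time)
    at_risk = n - np.arange(n)
    step = np.where(event[order] == 0, 1.0 - 1.0 / at_risk, 1.0)
    G = np.empty(n)
    G[order] = np.cumprod(step)
    num = den = 0.0
    for i in range(n):
        if event[i] == 0 or G[i] <= 0:
            continue                        # only uncensored i can anchor a pair
        w = 1.0 / G[i] ** 2                 # inverse censoring-probability weight
        for j in range(n):
            if time[i] < time[j]:           # comparable pair: i fails before j
                den += w
                num += w * (theta[i] > theta[j])
    return num / den
```

The weights 1/G(t_i)^2 correct for the over-representation of early events under censoring, which is what makes C-IPCW less biased than the plain C-index.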
Results
The structure of Cox-nnet v2.0 is shown in Fig. 1A, with the newly updated functionalities highlighted. We randomly split the kidney transplant EMR data into training (80%) and testing (20%) sets, and used C-IPCW to evaluate performance on the hold-out testing set. We repeated this process 10 times to assess the overall prediction performance. Cox-nnet v2.0 is not sensitive to the sample size and dramatically reduces the training time, compared to Cox-nnet v1.0 where the computing time increases polynomially with the sample size (Fig. 1B). Cox-nnet v2.0 also achieves significantly better C-IPCW than Cox-PH (Fig. 1C), without any drop in C-IPCW compared to Cox-nnet v1.0. We performed feature evaluation by calculating the feature importance scores with the new permutation method, whose values are close to those from the previous pseudo drop-out method. With the directionality (+/- signs) of the feature coefficients, our feature evaluation results are more interpretable: a positive (+) sign indicates increased risk of graft failure, whereas a negative (-) sign means decreased risk of graft failure. As additional confirmation, the pattern of importance scores matches well with that of the coefficients obtained from Cox-PH (Fig. 1D). In summary, Cox-nnet v2.0 significantly accelerates the training process of Cox-nnet without loss in prediction accuracy, and it also enables better interpretation of all features in the model. Cox-nnet v2.0 is thus the new version suitable for prognosis prediction in large-scale EMR datasets.
Author's Contribution
LXG conceived the project, DW conducted model improvement and data analysis, KH provided the dataset and helped with the analysis. DW and LXG wrote the manuscript with the help of KH. All authors read, revised and approved the manuscript.
Declaration of Conflict of Interest
The authors declare no conflict of interest.
Acknowledgement
References
Breiman, L. (2001) Random forests. Mach. Learn., 5–32.
Ching, T. et al. (2018) Cox-nnet: an artificial neural network method for prognosis prediction of high-throughput omics data. PLoS Comput. Biol., e1006076.
Cox, D.R. (1972) Regression models and life-tables. J. R. Stat. Soc. Series B Stat. Methodol., 187–202.
Fan, J. et al. (2010) High-dimensional variable selection for Cox's proportional hazards model. In: Borrowing Strength: Theory Powering Applications - A Festschrift for Lawrence D. Brown. Institute of Mathematical Statistics, pp. 70–86.
Fisher, A. et al. (2019) All models are wrong, but many are useful: learning a variable's importance by studying an entire class of prediction models simultaneously. J. Mach. Learn. Res., 1–81.
Goldstein, B.A. et al. (2017) Opportunities and challenges in developing risk prediction models with electronic health records data: a systematic review. J. Am. Med. Inform. Assoc., 198–208.
Ishwaran, H. et al. (2008) Random survival forests. Ann. Appl. Stat., 841–860.
Kingma, D.P. and Ba, J. (2014) Adam: a method for stochastic optimization. arXiv [cs.LG].
Klambauer, G. et al. (2017) Self-normalizing neural networks. In: Guyon, I. et al. (eds), Advances in Neural Information Processing Systems 30. Curran Associates, Inc., pp. 971–980.
Mahmoudi, E. et al. (2020) Use of electronic medical records in development and validation of risk prediction models of hospital readmission: systematic review. BMJ, m958.
Steele, A.J. et al. (2018) Machine learning models in electronic health records can outperform conventional survival models for predicting patient mortality in coronary artery disease. PLoS One, e0202344.
Uno, H. et al. (2011) On the C-statistics for evaluating overall adequacy of risk prediction procedures with censored survival data. Stat. Med., 1105–1117.
Zhan, Z. et al. Two-stage biologically interpretable neural-network models for liver cancer prognosis prediction using histopathology and transcriptomic data.