Personalized Survival Prediction with Contextual Explanation Networks
Maruan Al-Shedivat
Carnegie Mellon University [email protected]
Avinava Dubey
Carnegie Mellon University [email protected]
Eric P. Xing
Carnegie Mellon University [email protected]
Abstract
Accurate and transparent prediction of cancer survival times on the level of individual patients can inform and improve patient care and treatment practices. In this paper, we design a model that concurrently learns to accurately predict patient-specific survival distributions and to explain its predictions in terms of patient attributes such as clinical tests or assessments. Our model is flexible and, being based on a recurrent network, can handle various modalities of data including temporal measurements, and yet constructs and uses simple explanations in the form of patient- and time-specific linear regression. For analysis, we use two publicly available datasets and show that our networks outperform a number of baselines in prediction while providing a way to inspect the reasons behind each prediction.
In survival analysis, the goal is to estimate the occurrence time and the risk of an unfavorable event in the future (e.g., death of a patient) that can inform our decisions at present time (e.g., help to select a treatment). The classical models for this task are Aalen's additive model [1] and Cox's proportional hazard model [2], which linearly regress attributes of a patient to the hazard function. While suitable for comparing populations of patients, these models were not designed for patient-specific prediction. By reformulating survival analysis as a multi-task classification problem, Yu et al. [3] show that a set of temporally ordered linear classifiers provides much more accurate predictions.

Here, we follow the same classification approach and show that using deep learning methods further improves predictive performance on survival data. While promising, straightforward use of neural networks leads to black-box predictors that lack the transparency offered by the linear models. To overcome this issue, we employ contextual explanation networks [CEN, 4], a class of models that learn to predict by generating and leveraging intermediate explanations. Explanations here are defined as instance-specific simple (linear) models that not only help to interpret predictions but are selected by the network to make predictions for each patient at each time interval. CENs can be based on arbitrary deep architectures and can process a variety of input data modalities while interpreting predictions in terms of selected attributes. As we demonstrate in experiments, this approach attains both the best performance and interpretability.
First, we present the setup used by Yu et al. [3]. The data are represented by patient-specific attributes, X, and the times of the occurrence of the event, T. These times are converted into m-dimensional binary vectors, Y := (y_1, ..., y_m), that indicate the corresponding follow-up time. If the death occurred at time t ∈ [t_i, t_{i+1}), then y_j = 0 for all j ≤ i and y_k = 1 for all k > i. If the data point was censored (i.e., we lack information for times after t ∈ [t_i, t_{i+1})), the targets (y_{i+1}, ..., y_m) are regarded as latent variables. Note that only m + 1 sequences are valid, i.e., assigned non-zero probability by the model, which allows us to write the following linear model:

$$p(Y = (y_1, \dots, y_m) \mid x, \Theta) = \frac{\exp\bigl(\sum_{t=1}^{m} y_t\, x^\top \theta_t\bigr)}{\sum_{k=0}^{m} \exp\bigl(\sum_{t=k+1}^{m} x^\top \theta_t\bigr)} \qquad (1)$$

The model is trained by optimizing a regularized log-likelihood w.r.t. Θ := {θ_t}, t = 1, ..., m. After training, we get a set of linear models, one for each time interval, used for predicting the survival probability. Here, we take the same structured prediction approach but consider a slightly different setup.

Machine Learning for Healthcare (ML4H) Workshop, NIPS 2017, Long Beach, CA, USA.

Figure 1: CEN architectures used in our survival analysis experiments: (a) architecture used for SUPPORT2; (b) architecture used for PhysioNet. Context encoders were a time-distributed single-hidden-layer MLP (a) and an LSTM (b) that produced inputs for another LSTM over the output time intervals. Each hidden state of the output LSTM was used to generate the corresponding θ_t, which were further used to construct the log-likelihood for the CRF.
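The target encoding and the likelihood (1) can be sketched in a few lines of NumPy. This is an illustrative implementation, not the authors' code; the function names and dimensions are chosen for the example.

```python
import numpy as np

def encode_survival_time(i, m):
    """Binary target vector for a death in interval [t_i, t_{i+1}):
    y_j = 0 for all j <= i and y_k = 1 for all k > i. Setting i = m
    yields the all-zero sequence (no event within the horizon)."""
    y = np.zeros(m, dtype=int)
    y[i:] = 1  # zero-based slice: intervals i+1, ..., m get label 1
    return y

def sequence_log_likelihood(y, x, Theta):
    """Log of Eq. (1). Theta has shape (m, d); x has shape (d,)."""
    scores = Theta @ x                      # scores[t-1] = x^T theta_t
    log_num = scores[y == 1].sum()          # numerator for the observed y
    m = len(scores)
    # One term per valid sequence: ones starting at interval k+1.
    seq_scores = np.array([scores[k:].sum() for k in range(m + 1)])
    return log_num - np.logaddexp.reduce(seq_scores)
```

Because only the m + 1 valid sequences carry probability mass, the normalizer is a sum over m + 1 terms rather than 2^m, and the probabilities of the valid sequences sum to one.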
In particular, we assume that each data instance (patient record) is represented by three variables: the context, C, the attributes, X, and the targets, Y. Our goal is to learn a model, p_w(Y | X, C), parametrized by w, that can predict Y from X and C. Note that inputs have two representations, X and C, where X is a set of attributes that will be used to interpret predictions. Contextual explanation networks (CENs) are defined as models that assume the following form:

$$Y \sim p(Y \mid X, \theta), \quad \theta \sim p_w(\theta \mid C), \qquad p_w(Y \mid X, C) = \int p(Y \mid X, \theta)\, p_w(\theta \mid C)\, d\theta \qquad (2)$$

where p(Y | X, θ) is a predictor parametrized by θ. Such predictors are called explanations, since they explicitly relate interpretable variables, X, to the targets, Y. The conditional distribution p_w(θ | C) is called the context encoder; it processes the context representation, C, and generates parameters for the explanation, θ.

For survival analysis, we want explanations to be in the form of linear CRFs as given in (1). Hence, our contextual networks with CRF-based explanations are defined as follows:

$$\theta_t \sim p_w(\theta_t \mid C), \quad t \in \{1, \dots, m\}, \qquad Y \sim p(Y \mid X, \theta_{1:m}),$$
$$p(Y = (y_1, \dots, y_m) \mid x, \theta_{1:m}) \propto \exp\Bigl\{\sum_{t=1}^{m} y_t\, (x^\top \theta_t) + \omega(y_t, y_{t+1})\Bigr\},$$
$$p_w(\theta_t \mid C) := \delta\bigl(\theta_t, \phi_{t,w,D}(c)\bigr), \quad \phi_{t,w,D}(c) := \alpha(h_t)^\top D, \quad h_t := \mathrm{RNN}(h_{t-1}, c) \qquad (3)$$

A few things are worth noting here. First, the model generates explanations for each patient and for each time interval.
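The dictionary-based encoder in (3) can be sketched as follows: an attention vector over a global dictionary of atoms, D, produces each θ_t as a convex combination of atoms. The recurrence is abstracted to a single tanh RNN cell; all names, shapes, and initializations here are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

class DictionaryEncoder:
    """theta_t = D^T alpha(h_t), with h_t = tanh(W_h h_{t-1} + W_c c)."""

    def __init__(self, ctx_dim, hid_dim, dict_size, expl_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.W_h = rng.normal(scale=0.1, size=(hid_dim, hid_dim))
        self.W_c = rng.normal(scale=0.1, size=(hid_dim, ctx_dim))
        self.W_a = rng.normal(scale=0.1, size=(dict_size, hid_dim))
        self.D = rng.normal(scale=0.1, size=(dict_size, expl_dim))

    def explain(self, c, m):
        """Generate explanation parameters theta_1, ..., theta_m for context c."""
        h = np.zeros(self.W_h.shape[0])
        thetas = []
        for _ in range(m):
            h = np.tanh(self.W_h @ h + self.W_c @ c)
            alpha = softmax(self.W_a @ h)    # soft attention over atoms
            thetas.append(self.D.T @ alpha)  # convex combination of atoms
        return np.stack(thetas)              # shape (m, expl_dim)
```

Because each θ_t is a convex combination of dictionary atoms, inspecting the attention weights and the selected atoms gives a direct handle on how the explanation was formed.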
Second, depending on the nature of the context representation, C, CENs process it and generate θ_t for each time step using a recurrent encoder (Figure 1).¹ We use a deterministic RNN-based encoder, φ_t, that selects parameters for explanations from a global dictionary, D, using soft attention (for details on dictionary-based context encoding, see [4]). Finally, the potentials between the attributes, x, and the targets, y_{1:m}, are linear functions parameterized by θ_{1:m}; the pairwise potentials between targets, ω(y_i, y_{i+1}), ensure that configurations (y_i = 1, y_{i+1} = 0) are improbable, i.e., ω(1, 0) = −∞, while ω(0, 0) = ω₀₀, ω(0, 1) = ω₀₁, and ω(1, 1) = ω₁₁ are learnable parameters.

Given these constraints, the likelihood of an uncensored event at time t ∈ [t_j, t_{j+1}) is

$$p(T = t \mid x, \Theta) = \exp\Bigl\{\sum_{i=j+1}^{m} x^\top \theta_i\Bigr\} \Big/ \sum_{k=0}^{m} \exp\Bigl\{\sum_{i=k+1}^{m} x^\top \theta_i\Bigr\} \qquad (4)$$

and the likelihood of an event censored at time t ∈ [t_j, t_{j+1}) is

$$p(T \ge t \mid x, \Theta) = \sum_{k=j+1}^{m} \exp\Bigl\{\sum_{i=k+1}^{m} x^\top \theta_i\Bigr\} \Big/ \sum_{k=0}^{m} \exp\Bigl\{\sum_{i=k+1}^{m} x^\top \theta_i\Bigr\} \qquad (5)$$

The joint log-likelihood of the data consists of two parts:

$$\mathcal{L}(Y, X; \Theta) = \sum_{i \in NC} \log p(T = t_i \mid x_i, \Theta) + \sum_{j \in C} \log p(T \ge t_j \mid x_j, \Theta) \qquad (6)$$

where NC is the set of non-censored instances (for which we know the outcome times, t_i) and C is the set of censored instances (for which we only know the censorship times, t_j). The objective is optimized using stochastic gradient descent; see [4] for more details.

¹It is common for the data to come in multiple representations, some of which are low-level or unstructured (e.g., image pixels, sensory inputs), while others are high-level or human-interpretable (e.g., categorical variables). To ensure interpretability, we would like to use deep networks to process the low-level representation (the context) and construct explanations as context-specific probabilistic models on the high-level features.

In our experiments, we consider the datasets, models, and metrics described below. We compare CENs with a number of baselines quantitatively and visualize the learned explanations.
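The likelihood terms (4) and (5) share the same set of per-sequence scores, so both can be computed stably with log-sum-exp. Below is a minimal NumPy sketch under the hard-constraint view of the CRF (the −∞ pairwise potential restricts the support to the m + 1 valid sequences); the names and shapes are our own, not the authors' code.

```python
import numpy as np

def valid_sequence_scores(Theta, x):
    """s[k] = sum_{i=k+1}^{m} x^T theta_i for k = 0, ..., m: one score per
    valid label sequence (death in interval [t_k, t_{k+1})), with s[m] = 0
    for surviving past the modeled horizon."""
    scores = Theta @ x                                    # shape (m,)
    # reverse cumulative sums give s[k] = scores[k:].sum()
    return np.concatenate([np.cumsum(scores[::-1])[::-1], [0.0]])

def event_log_likelihood(Theta, x, j, censored):
    """log p(T = t | x) for death in [t_j, t_{j+1}) (Eq. 4), or
    log p(T >= t | x) for censoring in [t_j, t_{j+1}) (Eq. 5)."""
    s = valid_sequence_scores(Theta, x)
    log_Z = np.logaddexp.reduce(s)                        # log denominator
    if censored:
        return np.logaddexp.reduce(s[j + 1:]) - log_Z     # Eq. (5)
    return s[j] - log_Z                                   # Eq. (4)
```

Summing (6) over minibatches of censored and non-censored instances and differentiating with an autodiff framework recovers the SGD training objective described above.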
Datasets.
We use two publicly available datasets for survival analysis of intensive care unit (ICU) patients: (a) SUPPORT2, and (b) data from the PhysioNet 2012 challenge. The data were preprocessed and used as follows:

• SUPPORT2:
The data had 9,105 patient records and 73 variables. We selected 50 variables for both the C and X features. Categorical features (such as race or sex) were one-hot encoded. The values of all features were non-negative, and we filled the missing values with −1. For CRF-based predictors, the survival timeline was capped at 3 years and converted into 156 discrete intervals of 7 days each. We used 7,105 patient records for training, 1,000 for validation, and 1,000 for testing.

• PhysioNet:
The data had 4,000 patient records, each represented by a 48-hour, irregularly sampled, 37-dimensional time series of measurements taken during the patient's stay in the ICU. We resampled and mean-aggregated the time series at 30-minute intervals. This resulted in a large number of missing values, which we filled with 0. The resampled time series were used as the context, C, while for the attributes, X, we took the values of the last available measurement for each variable in the series. For CRF-based predictors, the survival timeline was capped at 60 days and converted into 60 discrete intervals.

Models.
For baselines, we use the classical Aalen and Cox models and the CRF from [3], all of which use X as inputs. Next, we combine CRFs with neural encoders in two ways:

(i) We apply CRFs to the outputs from the neural encoders (denoted MLP-CRF and LSTM-CRF, all trainable end-to-end). Similar models have been shown to be very successful in natural language applications [5]. Note that the parameters of the CRF layer assign weights to the latent features and are no longer interpretable in terms of the attributes of interest.

(ii) We use CENs with CRF-based explanations, which process the context variables, C, using the same neural networks as in (i) and output parameters for CRFs that act on the attributes, X.

SUPPORT2 data: http://biostat.mc.vanderbilt.edu/wiki/Main/DataSets. PhysioNet 2012 challenge: https://physionet.org/challenge/2012/.

Table 1: Performance of the classical Cox and Aalen models, CRF-based models, and CENs that use an LSTM or MLP for context embedding and a CRF for explanations. The numbers are averages from 5-fold cross-validation; the standard deviations are on the order of the least significant digit. @K denotes the temporal quantile, i.e., the time point such that K% of the patients in the data have died or were censored before that point.

SUPPORT2: Model | Acc@25 | Acc@50 | Acc@75 | RAE — PhysioNet Challenge 2012: Model | Acc@25 | Acc@50 | Acc@75 | RAE
[Table 1 rows: Cox, Aalen, CRF, MLP-CRF, MLP-CEN for SUPPORT2; Cox, Aalen, CRF, LSTM-CRF, LSTM-CEN for PhysioNet. The numeric entries were not recoverable from the extraction.]
Figure 2: Weights of the CEN-generated CRF explanations for two patients from the SUPPORT2 dataset (panels: Patient ID 3520, died, and Patient ID 1100, survived; x-axis: time after leaving hospital, in weeks) for a set of the most influential features: dementia (comorbidity), avtisst (avg. TISS, days 3–25), slos (days from study entry to discharge), hday (day in hospital at study admit), ca_yes (the patient had cancer), sfdm2_Coma or Intub (intubated or in coma at month 2), sfdm2_SIP (sickness impact profile score at month 2). Higher weight values correspond to higher feature contributions to the risk of death after a given time point.
Figure 3: CEN-predicted survival curves for 500 random patients from the SUPPORT2 test set (x-axis: time after leaving hospital, in weeks; y-axis: survival probability). Color indicates death within 1 year after leaving the hospital.
Metrics.
Following Yu et al. [3], we use two metrics specific to survival analysis: (a) the accuracy of correctly predicting survival of a patient at times that correspond to the 25%, 50%, and 75% population-level temporal quantiles (i.e., time points such that the corresponding percentage of the patients in the data had their last follow-up before that point due to censorship or death), and (b) the relative absolute error (RAE) between the predicted and actual time of death for non-censored patients.
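Both metrics can be phrased directly in terms of predicted and true event times. The sketch below is our own illustration, not the evaluation code used in the paper; in particular, capping the RAE at 1 is an assumption we make so that large over-predictions are not unboundedly penalized.

```python
import numpy as np

def acc_at_quantile(pred_times, true_times, q):
    """Accuracy of predicting survival status at the q-th temporal
    quantile of the population's last-follow-up times."""
    t_q = np.quantile(true_times, q)
    # a patient 'survives' t_q if their (predicted/true) time exceeds it
    return np.mean((pred_times > t_q) == (true_times > t_q))

def relative_absolute_error(pred_times, true_times, cap=1.0):
    """RAE between predicted and actual death times for non-censored
    patients; each per-patient error is capped at `cap` (our assumption)."""
    rae = np.abs(pred_times - true_times) / np.maximum(true_times, 1e-8)
    return np.mean(np.minimum(rae, cap))
```

For Acc@25/50/75, one would call `acc_at_quantile` with q = 0.25, 0.5, and 0.75 over the full population, restricting the RAE computation to the non-censored subset.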
Quantitative results.
The results for all models are given in Table 1. Our implementation of the CRF baseline reproduces (and even slightly improves on) the performance reported by Yu et al. [3]. CRFs built on representations learned by deep networks (the MLP-CRF and LSTM-CRF models) improve upon the plain CRFs but, as we noted, can no longer be interpreted in terms of the original variables. On the other hand, CENs outperform the neural CRF models on certain metrics (and closely match them on the others) while providing explanations for the survival probability predictions for each patient at each point in time.
Qualitative results.
To inspect predictions of CENs qualitatively, for any given patient, we can visualize the weights assigned by the corresponding explanation to the respective attributes at each time interval. Figure 2 shows explanation weights for a subset of the most influential features for two patients from the SUPPORT2 dataset who were predicted as a survivor and a non-survivor, respectively. These explanations allow us to better understand the patient-specific temporal dynamics of the factors contributing to the survival rates predicted by the model (Figure 3). This information can be used for model diagnostics (i.e., to help us understand whether we can trust a particular prediction) and as more fine-grained information useful for decision support.

References

[1] O. O. Aalen. "A linear regression model for the analysis of life times". In: Statistics in Medicine 8.8 (1989), pp. 907–925.

[2] D. R. Cox. "Regression Models and Life-Tables". In: Journal of the Royal Statistical Society, Series B (Methodological) (1972), pp. 187–220.

[3] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin & Vickie Baracos. "Learning patient-specific cancer survival distributions as a sequence of dependent regressors". In: Advances in Neural Information Processing Systems. 2011, pp. 1845–1853.

[4] Maruan Al-Shedivat, Avinava Dubey & Eric P. Xing. "Contextual Explanation Networks". In: arXiv preprint arXiv:1705.10301 (2017).

[5] Ronan Collobert, Jason Weston, Léon Bottou, Michael Karlen, Koray Kavukcuoglu & Pavel Kuksa. "Natural language processing (almost) from scratch". In: Journal of Machine Learning Research (2011).