Generalizable Machine Learning in Neuroscience using Graph Neural Networks
Paul Y. Wang, Sandalika Sapra, Vivek Kurien George, Gabriel A. Silva
Department of Physics, University of California San Diego; Department of Electrical and Computer Engineering, University of California San Diego; Department of Bioengineering, University of California San Diego; Department of Neuroscience, University of California San Diego; Center for Engineered Natural Intelligence, University of California San Diego
Although a number of studies have explored deep learning in neuroscience, the application of these algorithms to neural systems on a microscopic scale, i.e. at the level of parameters relevant to lower scales of organization, remains relatively novel. Motivated by advances in whole-brain imaging, we examined the performance of deep learning models on microscopic neural dynamics and resulting emergent behaviors using calcium imaging data from the nematode C. elegans. We show that neural networks perform remarkably well on both neuron-level dynamics prediction and behavioral state classification. In addition, we compared the performance of structure-agnostic neural networks and graph neural networks to investigate whether graph structure can be exploited as a favourable inductive bias. To perform this experiment, we designed a graph neural network which explicitly infers relations between neurons from neural activity and leverages the inferred graph structure during computations. In our experiments, we found that graph neural networks generally outperformed structure-agnostic models and excelled at generalization to unseen organisms, implying a potential path to generalizable machine learning in neuroscience.
Graph neural networks | Neuroscience | C. elegans | Machine learning
Correspondence: PW ([email protected]) and GS ([email protected])
Introduction
Constructing generalizable models in neuroscience poses a significant challenge because neural systems are typically complex: numerous interacting dynamical components collectively produce emergent behaviors. Analyzing these systems can be difficult because their interactions tend to be highly non-linear, they can exhibit chaotic behavior, and they are high-dimensional by definition. As such, indistinguishable macroscopic states can arise from numerous unique combinations of microscopic parameters, i.e. parameters relevant to lower scales of organization. Thus, bottom-up approaches to modeling neural systems often fail, since a large number of microscopic configurations can lead to the same observables (1) (2). Because neural systems are highly degenerate and complex, their analysis is not amenable to many conventional algorithms. For example, observed correlations between individual neurons and behavioral states of an organism may not generalize to other organisms or even to repeated trials in the same individual (3) (4) (5). Hence, individual variability of neural dynamics remains poorly understood and is a fundamental obstacle to model development, as evaluation on unseen individuals often leads to subpar results. Nevertheless, neural systems exhibit universal behavior: individuals of the same species behave similarly. Motivated by the need for robust and generalizable analytical techniques, researchers recently applied tools from dynamical systems analysis to simple organisms in hopes of discovering a universal organizational principle underlying behavior. These studies, made possible by advances in whole-brain imaging, reveal that neural dynamics live on low-dimensional manifolds which map to behavioral states (6) (7). This discovery implies that although microscopic neural dynamics differ between organisms, a macroscopic/global universal framework may enable generalizable algorithms in neuroscience.
Nevertheless, the need for significant hand-engineered feature extraction in these studies underscores the potential of deep learning models for scalable analysis of neural dynamics.
In this work, we examine the performance and generalizability of deep learning models applied to the neural activity of C. elegans (roundworm/nematode). C. elegans is a canonical species for investigating microscopic neural dynamics because it remains the only organism whose connectome (the mapping of all 302 neurons and their synaptic connections) is completely known and well studied (8) (9) (10) (11). Furthermore, the transparent body of these worms allows calcium imaging of whole-brain neural activity, which remains the only imaging technique capable of spatially resolving the dynamics of individual neurons (12). Leveraging these characteristics and insight gained from previous studies, we developed deep learning models that bridge recent advances in neuroscience and deep learning. Specifically, we first demonstrate state-of-the-art performance for classifying motor action states of C. elegans from calcium imaging data acquired in previous works. Next, we examine the generalization performance of our deep learning models on unseen worms, both within the same study and in worms from a separate study published two years later. We then show that graph neural networks exhibit a favourable inductive bias for analyzing both higher-order function and microscopic/neuron-level dynamics in C. elegans.
Background
In this section we discuss recent advances in neuroscience and machine learning upon which we build our model and experiments.
Wang et al. | arXiv | October 20, 2020 | 1–9
Universality/Generalizability in C. elegans models.
The motor action sequence of C. elegans is one of the only systems for which experiments on whole-brain microscopic neural activity may be performed and readily analyzed. As such, numerous efforts have focused on building models that can accurately capture the hierarchical nature of neural dynamics and resulting locomotive behaviors (13) (14). Taking advantage of this, Kato et al. (7) investigated neural dynamics corresponding to a pirouette, a motor action sequence in which worms switch from forward to backward crawling, turn, and then continue forward crawling. Their analysis showed that most variations (∼
Graph Neural Networks.
Graph Neural Networks (GNNs) are a class of neural networks that explicitly use graph structure during computations through message passing algorithms, where features are passed along edges between nodes and then aggregated for each node (19) (20). These networks were inspired by the success of convolutional neural networks in two-dimensional image processing and by the failures encountered when extending conventional convolutional networks to non-Euclidean domains (21). In essence, because graphs can have arbitrary structure, the inductive bias of convolutional neural networks (equivariance to translations (22)) often breaks down when applied to graphs. Addressing this issue, an early work on GNNs showed that one-hop message passing approximates spectral convolutions on graphs (23). Subsequent works have examined the representational power of GNNs in relation to the Weisfeiler-Lehman isomorphism test (24) and the limitations of GNNs when learning graph moments (25). From an applied perspective, GNNs have been successful in a wide variety of domains including relational inference (26) (27) (28), node classification (23) (29), point cloud segmentation (30), and traffic forecasting (31) (32). In neuroscience, GNNs have been used for tasks such as annotating cognitive state (33), and several frameworks based on graph neural networks have been proposed for analyzing fMRI data (34) (35).
Fig. 1. (A) Calcium signals recorded in one animal for ∼15 minutes by (7). Each row represents a single neuron. The top 15 rows (above the red line) correspond to neurons unambiguously identified in all animals (shared neurons). (B) Sample trace with the corresponding behavioral state colored. (C) Neural dynamics of two neurons for specific behavioral states. Colored solid lines are the mean activity for each animal, and the black dashed line is the mean activity for all animals. Shaded colored regions show 95% confidence intervals. (D) Probabilities that neural dynamics from different individuals were drawn from the same distribution. (E) Attempt by (18) to decode the onset of backward locomotion using neural dynamics for each animal and averaged neural dynamics across the other four animals. Reproduced with permission from (18).
Fig. 2. (A) Rendering of a calcium imaging experiment in which the activity of neurons in the head of the worm is recorded. Coloured arrows show the main motor action behavioral states. (B) and (C) Resulting manifolds from (18). (B) Manifold constructed from the activity of four worms, with coloured lines indicating the neural activity of the fifth worm. (C) Manifold constructed from the neural activity of uniquely identified neurons (n = 15) shared among all 5 worms. Black arrows correspond to the cyclical transition of the motor action sequence and colors correspond to motor action states. Modified with permission from (18).
Relational Inference.
Relational inference remains a long-standing challenge, with early works in neuroscience seeking to quantify correlations between neurons (36). Modern approaches to relational inference employ graph neural networks, as their explicit reliance on graph structure forms a relational inductive bias (37) (21). In particular, our model is inspired by the Neural Relational Inference (NRI) model, a variational autoencoder whose encoder infers the edges of an interaction graph and whose decoder predicts the trajectories of each object in the system (26). By inferring edges, the NRI model explicitly captures interactions between objects and leverages the resulting graph as an inductive bias for various machine learning tasks. This model was successfully used to predict the trajectories of coupled Kuramoto oscillators, particles connected by springs, the pick and roll play from basketball, and motion capture data. Subsequently, Amortized Causal Discovery, a framework based on the NRI model, was developed to infer causal relations from time-dependent data (27).
Deep Learning in Neuroscience.
With the success of convolutional neural networks, researchers have applied deep learning to numerous domains in neuroscience (38), including MRI (39) and connectomes (40), where algorithms can predict disorders such as autism (41). Similarly, brain-computer interfaces (BCIs) are a well-studied field related to our work, as they focus on decoding macroscopic variables from measurements of neural activity (42). These studies generally involve fMRI or EEG data, which characterize neural activity at the population level, with varying degrees of success (43) (44) (45) (46). Regardless, a challenge for the field is developing algorithms that generalize to individuals unseen during training (47).
Fig. 3. (A) Visualization of the temporal graph. The inset shows $x_n$ plotted against $t$, where the top is the calcium trace and the bottom is its derivative. The dashed line intercepts the feature vectors at timestep $t+1$ and denotes $x_n^{t+1}$. (B) and (C) are simplified visualizations of the MLP and GNN modules, respectively.
Model
In this section, we first present the general framework of our behavioral state classification and trajectory prediction models. Next, we detail the implementation of our neural network modules.
Framework.
We define the set of trajectories (calcium imaging traces) for each worm as $X_\alpha = \{x_1, \ldots, x_n\}$, where $\alpha$ denotes the label of the individual, $n$ the name (index) of the neuron, and $x_n$ the feature vector of the neuron. In our case, $x_n$ corresponds to the time-dependent normalized calcium trace of each neuron and its derivative. Likewise, $x_n^t$ corresponds to the feature(s) of neuron $n$ at timestep $t$. Finally, the behavioral states of an individual are encoded as $a_\alpha = (a^1, \ldots, a^t)$, where a behavioral state $a^t$ is assigned for each timestep $t$.
Separate models were developed for each task: behavioral state classification and trajectory prediction. In both cases, data from a worm $\alpha$ is structured as a temporal graph $\mathcal{G}_\alpha = (G_\alpha^1, \ldots, G_\alpha^t)$ (Figure 3A), where each timestep is represented by a static graph whose nodes correspond to neurons. Following the notation above, the trajectories of each neuron's calcium traces are encoded as node features $x_n$, and the behavioral state of an individual is interpreted as a graph feature $a_\alpha^t$. For behavioral state classification, our model consists of the following (we omit $\alpha$ and $t$ in intermediary steps to simplify notation):
$$H = f(X_\alpha^t) \tag{1}$$
$$p = \mathrm{Softmax}(H) \tag{2}$$
$$\hat{a}_\alpha^t = \arg\max_{c \in \{1, \ldots, k\}} p_c \tag{3}$$
where $f$ is a universal approximator/neural network module (described in the next section), $H$ are hidden features, $p$ is the vector of probabilities that the system is in each of the $k$ states, and $\hat{a}_\alpha^t$ is the most probable (predicted) state.
For trajectory prediction, we developed a Markovian model for inferring the trajectories at the next timestep:
$$H = f(X_\alpha^t) \tag{4}$$
$$\hat{X}_\alpha^{t+1} = X_\alpha^t + H \tag{5}$$
where $H$ and $f$ are the same as before. We also experimented with non-Markovian models (RNNs), for which a hidden state is carried across timesteps. The structure of our models allows us to substitute various modules for $f$.
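To make Eqs. (1)-(5) concrete, here is a minimal NumPy sketch, assuming that $f$ is the MLP module with concatenation as the aggregation, a two-layer ReLU network for $g$, random untrained weights, and synthetic inputs. The hidden width, initialization, and all function names are hypothetical choices for illustration, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

N_NEURONS = 15   # shared neurons in the Kato dataset
N_FEATURES = 2   # normalized calcium trace and its derivative
N_STATES = 7     # motor action states labeled in the Kato dataset
HIDDEN = 32      # hypothetical hidden width


def init_mlp(n_in, n_hidden, n_out):
    """Random (untrained) weights for a two-layer MLP g; hypothetical initialization."""
    return {
        "W1": rng.normal(0.0, 0.1, (n_in, n_hidden)), "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, 0.1, (n_hidden, n_out)), "b2": np.zeros(n_out),
    }


def mlp(params, h):
    """Two-layer MLP with a ReLU hidden layer."""
    z = np.maximum(0.0, h @ params["W1"] + params["b1"])
    return z @ params["W2"] + params["b2"]


def mlp_module(params, x_t):
    """Eqs. (6)-(7): aggregate node features (here: concatenation), then apply g."""
    return mlp(params, x_t.reshape(-1))


def softmax(h):
    e = np.exp(h - h.max())  # subtract the max for numerical stability
    return e / e.sum()


def classify(params, x_t):
    """Eqs. (1)-(3): H = f(X^t), p = Softmax(H), predicted state = argmax of p."""
    p = softmax(mlp_module(params, x_t))
    return int(np.argmax(p)), p


def markov_step(params, x_t):
    """Eqs. (4)-(5): f predicts an increment H, so that X^{t+1} = X^t + H."""
    return x_t + mlp_module(params, x_t).reshape(x_t.shape)


clf_params = init_mlp(N_NEURONS * N_FEATURES, HIDDEN, N_STATES)
traj_params = init_mlp(N_NEURONS * N_FEATURES, HIDDEN, N_NEURONS * N_FEATURES)

x_t = rng.random((N_NEURONS, N_FEATURES))  # synthetic stand-in for one timestep
state, probs = classify(clf_params, x_t)
x_next = markov_step(traj_params, x_t)
print(state, probs.round(3), x_next.shape)
```

The residual form of Eq. (5) means the module only has to learn the change between consecutive timesteps; with all output weights set to zero, `markov_step` returns its input unchanged.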
While we include results from several neural networks, we focus on two representative models: a multi-layer perceptron (MLP) agnostic to graph structure (Figure 3B) and a graph neural network (GNN) which explicitly computes on an inferred graph (Figure 3C).
Neural Network Modules: MLP and GNN.
Our MLP module aggregates the features of a graph and feeds the aggregated features into a two-layer MLP:
$$H = \mathrm{Aggregation}(x_1^t, \ldots, x_n^t) \tag{6}$$
$$H_{\mathrm{out}}^t = g(H) \tag{7}$$
where $g$ is an MLP. In contrast to the MLP module, our GNN relies on message passing between connected nodes and contains an encoder for edge weights $w_{ij}$:
$$H = g_{\mathrm{enc}}(X^t) \tag{8}$$
$$H_{ij} = g(\mathrm{Aggregation}(h_i, h_j)) \tag{9}$$
$$p_{ij} = \mathrm{Softmax}(H_{ij}) \tag{10}$$
$$w_{ij}^t = p_{ij,2} \tag{11}$$
where $h_i$ denotes the hidden features of node $i$ (the $i$-th row of $H$) and (9) encodes a hidden representation $H_{ij}$ for each edge. Applying the softmax function to $H_{ij}$ produces a two-dimensional probability vector $p_{ij}$ whose components sum to 1. We define its second component $p_{ij,2}$ as the weight $w_{ij}$ of the edge between nodes $i$ and $j$. The edge weights either change dynamically in each timestep's inferred graph $G^t$ or remain fixed for the whole temporal graph $\mathcal{G}$ of an individual worm. If the edges are static for the temporal graph, the aggregation step in (9) also averages hidden features across all timesteps.
After the edges are encoded, the GNN performs a message passing and aggregation step:
$$H_i = \sum_{j=1}^{N} w_{ij}^t x_j^t \tag{12}$$
$$H_{\mathrm{out}}^t = g(\mathrm{Aggregation}(H_i)) \tag{13}$$
The sum runs over all $N$ nodes in the graph, so that weighted messages are passed between connected nodes and potentially along self edges. The message passing step (12) can also be formulated in terms of an inferred weighted adjacency matrix $A^t$, with entries $A^t_{ij} = w^t_{ij}$, and node features $X^t$:
$$H^t = A^t X^t \tag{14}$$
Theoretically, an arbitrary number of message passing steps can be implemented; however, we did not find any improvement when using more than one step. In addition, we find that performance improves when using concatenation instead of summation during the aggregation step.
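The edge encoder and single message-passing step of Eqs. (8)-(14) can be sketched in NumPy as follows. This is a minimal illustration, assuming two-layer ReLU MLPs for $g_{\mathrm{enc}}$, the edge network, and the readout $g$, concatenation for every aggregation step, random untrained weights, and synthetic inputs. Layer sizes and the helper names `dense`, `mlp2`, `infer_edges`, and `gnn_forward` are hypothetical; the sketch is not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(1)
N, F, HID, N_OUT = 15, 2, 16, 7  # neurons, features, hidden width, outputs (hypothetical)


def dense(n_in, n_out):
    """Hypothetical random initialization of one dense layer."""
    return rng.normal(0.0, 0.1, (n_in, n_out)), np.zeros(n_out)


def mlp2(layers, h):
    """Two-layer ReLU MLP applied along the last axis of h."""
    (W1, b1), (W2, b2) = layers
    return np.maximum(0.0, h @ W1 + b1) @ W2 + b2


params = {
    "enc": [dense(F, HID), dense(HID, HID)],       # g_enc, Eq. (8)
    "edge": [dense(2 * HID, HID), dense(HID, 2)],  # edge network g, Eq. (9)
    "out": [dense(N * F, HID), dense(HID, N_OUT)], # readout g, Eq. (13)
}


def infer_edges(params, X):
    """Eqs. (8)-(11). X has shape (T, N, F); returns edge weights of shape (N, N).

    With T > 1 the node encodings are averaged over time (static edges);
    with T == 1 the edges are inferred for a single timestep.
    """
    H = mlp2(params["enc"], X).mean(axis=0)  # (N, HID)
    pairs = np.concatenate(                  # concatenation of (h_i, h_j)
        [np.repeat(H[:, None, :], N, axis=1), np.repeat(H[None, :, :], N, axis=0)],
        axis=-1,
    )                                        # (N, N, 2 * HID)
    logits = mlp2(params["edge"], pairs)     # H_ij, shape (N, N, 2)
    logits -= logits.max(axis=-1, keepdims=True)
    p = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # Eq. (10)
    return p[..., 1]                         # Eq. (11): second component (index 1)


def gnn_forward(params, X_t, A):
    """Eqs. (12)-(14): one message-passing step, then concatenate nodes and apply g."""
    messages = A @ X_t                       # H^t = A^t X^t, shape (N, F)
    return mlp2(params["out"], messages.reshape(-1))


X_t = rng.random((N, F))
A = infer_edges(params, X_t[None])           # per-timestep edges
out = gnn_forward(params, X_t, A)

# Eq. (12) written as an explicit sum agrees with the matrix form of Eq. (14).
H_sum = np.stack([sum(A[i, j] * X_t[j] for j in range(N)) for i in range(N)])
assert np.allclose(H_sum, A @ X_t)
```

Passing a window of shape (T, N, F) with T > 1 to `infer_edges` averages the node encodings over time before edge inference, which mirrors the static-edge variant described above. Because every pair, including $i = j$, receives a weight, self edges are included in the sum of Eq. (12).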
Experiments
Data.
Our experiments were performed with data acquired in (7) and (15). We summarize various details about the data in this section; however, we direct the reader to each respective publication for specific experimental details.
Calcium Imaging.
Kato et al. (7) showed that neural activity corresponding to the motor action sequence lives on low-dimensional manifolds. To record neuron-level dynamics, they performed whole-brain, genetically encoded Ca²⁺ imaging at single-cell resolution and measured ∼100 neurons for around 18 minutes. They then normalized each calcium trace by peak fluorescence and identified neurons using spatial position and previous literature (48). Aside from imaging freely moving worms, the authors also examined the robustness of topological features to changes in sensory stimuli, hub neuron silencing, and immobilization. For simplicity, we limited our experiments to data collected on freely moving worms.
Nichols et al. (15) focused on differences in the neural activity of C. elegans while awake or asleep and studied two different strains of worms, N2 (n = 11) and npr-1 (n = 10). Because experiments in both studies were performed by the same group, most experimental procedures were similar, allowing us to easily process the data to match the Kato dataset. While this dataset includes imaging data of each worm during quiescence, for consistency with the Kato dataset we only included data recorded before sleep was induced. Furthermore, we combined results for both strains of worms, as we did not notice any statistically significant differences between them.
Data Processing.
We normalized the calcium trace of each neuron and its derivative to [0, 1]. Normalization was performed over the entire recorded calcium trace of a worm instead of within each batch, because the relative magnitudes of the traces have been found to contain graded information about the worm's behavioral state (e.g. crawling speed). To create training batches, we separated each calcium trace of approximately 3000-4000 timesteps into batches of 8 timesteps, where each timestep corresponds to roughly 1/3 of a second. We chose batches of 8 timesteps because visualization of the calcium traces showed that most local variations occur within this time frame. Moreover, 8 timesteps corresponds to roughly 3 seconds (8 × 1/3 s ≈ 2.7 s), which is about the amount of time a worm needs to execute a behavioral change. Finally, the batches were shuffled before being divided into 10 folds later used for cross-validation, ensuring that each fold is representative of the whole dataset.
To compare with previous works, we performed our experiments on neurons uniquely identified across the datasets that we investigated. Identifying specific neurons is an experimental challenge, and as such, only a small fraction of neurons were unequivocally labeled. A total of 15 neurons were uniquely identified in all worms (n = 5) measured in the Kato dataset: (AIBL, AIBR, ALA, AVAL, AVAR, AVBL, AVER, RID, RIML, RIMR, RMED, RMEL, RMER, VB01, VB02). In addition, the Nichols dataset contained data from 21 worms, with 3 uniquely identified neurons shared among all worms in both datasets: (AIBR, AVAL, VB02).
Table 1. Classification accuracy (%) of forward and reverse crawling
Model | Training Set | Evaluation Set (Kato) | Evaluation Set (Nichols)
(18) | 83 | 81 | n/a
SVM | 98.8 ± .4 | 82.8 ± … | …
… | … ± .6 | 93.9 ± … | …
… | … ± .6 | 96.8 ± … | …
… | … ± .5 | 97.7 ± … | …
Results.
Following (18), we used data from (7) for training and evaluating our models and data from (15) as an extended evaluation set. Because whole-brain imaging is very difficult, our datasets were relatively small. To address this, we experimented with data augmentation by combining data from multiple worms in the Kato dataset during model training. For all experiments, we performed 10-fold cross-validation on all permutations of worms in our training set. More details, along with supplemental experiments, can be found in the Supplementary Information.
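The preprocessing described under Data Processing (restricting to shared neurons, normalizing over the whole recording, cutting non-overlapping 8-timestep batches, shuffling, and splitting into 10 folds) can be sketched as follows. This is a minimal sketch under stated assumptions: "normalized to [0, 1]" is implemented as per-neuron min-max scaling, the derivative is taken with `np.gradient`, an incomplete final batch is dropped, and the calcium data are synthetic. The function names are hypothetical.

```python
import numpy as np

WINDOW = 8    # timesteps per batch (~1/3 s each, i.e. ~2.7 s)
N_FOLDS = 10  # cross-validation folds

KATO_SHARED = ["AIBL", "AIBR", "ALA", "AVAL", "AVAR", "AVBL", "AVER", "RID",
               "RIML", "RIMR", "RMED", "RMEL", "RMER", "VB01", "VB02"]
ALL_SHARED = ["AIBR", "AVAL", "VB02"]  # shared by all worms in both datasets


def minmax_whole_trace(trace):
    """Scale each neuron over its full recording to [0, 1] (not per batch)."""
    lo = trace.min(axis=0, keepdims=True)
    hi = trace.max(axis=0, keepdims=True)
    return (trace - lo) / np.where(hi > lo, hi - lo, 1.0)


def make_batches(calcium, window=WINDOW):
    """calcium: (T, N). Returns (n_batches, window, N, 2): trace and derivative."""
    deriv = np.gradient(calcium, axis=0)  # finite differences (an assumption)
    feats = np.stack([minmax_whole_trace(calcium), minmax_whole_trace(deriv)], axis=-1)
    n_batches = feats.shape[0] // window  # drop an incomplete tail (an assumption)
    return feats[: n_batches * window].reshape(n_batches, window, *feats.shape[1:])


def shuffled_folds(n_batches, n_folds=N_FOLDS, seed=0):
    """Shuffle batch indices, then split them into n_folds nearly equal folds."""
    idx = np.random.default_rng(seed).permutation(n_batches)
    return np.array_split(idx, n_folds)


rng = np.random.default_rng(0)
calcium = rng.random((3000, len(KATO_SHARED))).cumsum(axis=0)  # synthetic worm
cols = [KATO_SHARED.index(name) for name in ALL_SHARED]        # restrict to 3 neurons
batches = make_batches(calcium[:, cols])
folds = shuffled_folds(len(batches))
print(batches.shape, [len(f) for f in folds])
```

Normalizing before batching, as in the text, keeps the relative magnitudes of a trace comparable across batches, which per-batch scaling would erase.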
Behavioral State Classification.
Our first experiment compared the performance of our models to state-of-the-art results reported in (18). Specifically, this experiment involved the classification of only two motor action states, forward and reverse crawling. Along with our models described above, we also experimented with a support vector machine (SVM) and a GNN that computes with edges derived from the physical connectome (8). In particular, we incorporated the connectome into our model to investigate whether physical/structural connections between neurons can serve as a favourable inductive bias for our GNN. Our results are shown in Table 1, where Training Set denotes test set accuracy after training on the same worm and Evaluation Set denotes evaluation/generalization accuracy on worms unseen during training.
Our deep learning models clearly outperformed the SVM and previous state-of-the-art results, demonstrating the ability of our models to successfully classify behavioral states and generalize to other worms. Interestingly, the SVM matched the performance of our deep learning models on test set accuracy; however, its generalization performance on unseen individuals was significantly worse. As such, the SVM distinctly illustrates the challenge that individual variability poses for model development in neural systems, despite the simplicity of our experiments, which involve the same set of unequivocally identified neurons. Similarly, our GNN using edges derived from the connectome performed well on the test set but generalized worse than when using inferred edges. We hypothesize that the detrimental effect of using the connectome may be attributed to the model's lack of expressiveness and to the distinction between inferred/functional and structural connectivity (see S.1.4.3).
Following the previous experiment, we applied our MLP and GNN models to the harder task of classifying all behavioral states labeled in the Kato dataset (Figure 4A).
Within this dataset, 7 states were labeled: Forward Crawling, Forward Slowing, Reverse 1, Reverse 2, Sustained Reverse Crawling, Dorsal Turn, and Ventral Turn. In comparison, only 4 states were labeled in the Nichols dataset: reverse crawling, forward crawling, ventral turn, and dorsal turn. For compatibility, we mapped the 7 states of the Kato dataset to the 4 states of the Nichols dataset when using the Nichols dataset as an extended evaluation set.
Despite the harder task of classifying 7 states, our models achieved a classification accuracy of ∼92% on the same worm (Figure 4A: Left). Moreover, our GNN trained on three worms in the Kato dataset generalized with an accuracy of 87% (Figure 4A: Right) when classifying 4 states on the remaining unseen worms. This substantially exceeds the performance of our MLP model and of (18), who report an 81% cross-animal accuracy on two states. Nevertheless, both MLP and GNN models generalized equally well (∼
Fig. 4. (A) Classification accuracy of our GNN and MLP models, where black vertical lines show statistical spread. Left: classification of 7 motor action states within the Kato dataset. Right: classification of 4 motor action states on both the Kato and Nichols datasets. (B) Confusion matrix: percent occurrence of predicted states against labeled states when evaluating on the Nichols dataset. (C) Mean squared error (MSE) of the GNN and various MLP models evaluated on the Nichols dataset. All models were trained using data from one worm or five worms in the Kato dataset. (D) Table of MSE values for all models for 1, 8, and 16 timesteps.
Neuron-level Trajectory Prediction.
For trajectory prediction, we predicted each neuron's calcium trace and its derivative (normalized to [0, 1]) for 8 timesteps during training and 16 timesteps during evaluation/validation. While training our Markovian models, we used scheduled sampling to reduce the accumulation of error (49). In addition to our Markovian models, we also experimented with RNN implementations trained with burn-in periods of four timesteps. For evaluation, we averaged the loss per prediction timestep across all batches. Our experiments primarily focused on the generalization performance of our models on the extended evaluation (Nichols) dataset (Figure 4C).
Predicting neuron-level trajectories using deep learning is fairly novel, since advances in whole-brain imaging are recent and limited to few organisms. Because calcium traces are notoriously noisy and our dataset is relatively small, the performance of our models is poor; however, inspecting the MSE as a function of prediction step (Figure 4C) demonstrates that all deep learning models are able to learn transitions in the system. Moreover, increasing the number of worms included during training also improved the generalization performance of our MLP and GNN models. Perhaps most surprisingly, our Markovian GNN outperformed all MLP models and their derived RNN variants. We attribute this result to the largely deterministic nature of neural dynamics, characterized by sparse bifurcations on the latent manifold, and to the inductive bias of GNNs. As a result, given a single timestep, our GNN model was able to predict future trajectories on unseen worms for at least 16 timesteps and clearly outperformed all other models.
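The evaluation protocol (an autoregressive rollout from a single observed timestep, with the MSE averaged separately for each prediction step) and the scheduled-sampling idea of (49) can be sketched as follows. The toy linear dynamics, the slightly mis-specified "model", the mixing probability `p_model`, and all function names are hypothetical; the sketch only illustrates the mechanics and is not the authors' training code.

```python
import numpy as np


def rollout(step_fn, x0, n_steps):
    """Markovian rollout: each prediction is fed back in as the next input."""
    preds, x = [], x0
    for _ in range(n_steps):
        x = step_fn(x)
        preds.append(x)
    return np.stack(preds)  # (n_steps, N, F)


def scheduled_sampling_inputs(step_fn, target, p_model, rng):
    """Training inputs for a ground-truth sequence `target` of shape (T, N, F).

    At each step the input is the model's own prediction from the previous input
    with probability p_model, and the ground truth otherwise (teacher forcing).
    """
    inputs, x = [target[0]], target[0]
    for t in range(1, len(target)):
        pred = step_fn(x)
        x = pred if rng.random() < p_model else target[t]
        inputs.append(x)
    return np.stack(inputs)


def mse_per_step(preds, target):
    """Mean squared error computed separately for each prediction step."""
    return ((preds - target) ** 2).mean(axis=tuple(range(1, preds.ndim)))


def true_step(x):
    return x + 0.10 * (0.5 - x)  # hypothetical toy dynamics


def model_step(x):
    return x + 0.08 * (0.5 - x)  # a slightly mis-specified stand-in "model"


rng = np.random.default_rng(0)
x0 = rng.random((15, 2))                # one observed timestep
target = rollout(true_step, x0, 16)     # 16-step evaluation horizon
errors = mse_per_step(rollout(model_step, x0, 16), target)

# Training loss on an 8-step window with scheduled sampling (p_model = 0.25)
window = target[:8]
inputs = scheduled_sampling_inputs(model_step, window, 0.25, rng)
loss = ((np.stack([model_step(x) for x in inputs[:-1]]) - window[1:]) ** 2).mean()
print(errors.round(5), loss)
```

With `p_model = 0` the inputs reduce to pure teacher forcing; in the original scheduled-sampling formulation the probability of feeding back model predictions is increased over the course of training, so that the model gradually learns to recover from its own errors.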
Discussion
For both tasks, our GNN consistently matched or exceeded our MLP model, which we attribute to its favourable inductive bias. Kato et al. (7) established that projecting neural dynamics onto three principal components for each worm reveals universal topological structures; however, attempts to project neural dynamics onto shared principal components of all worms failed to display any meaningful structure. Thus, the variability in each worm's neural activity, corresponding to low-dimensional manifolds in latent space, is represented by different linear combinations of neurons. In other words, relevant topological structures in latent space are loosely related by linear transformations of node features. We speculate that our GNN's performance stems from its explicit structure of message passing along inferred edges, which is analogous to learning linear transformations of node features (see equation (14)).
Notably, our model's performance was not significantly impacted by using 3 neurons (∼1% of all neurons) instead of 15 (∼5% of all neurons). This is not surprising, because neurons strongly coupled to the motor action sequence retain most of the information (50), a fact consistent with (18), who found that strategically choosing 1 neuron retains ∼75% of the information contained in the larger set of 15 neurons.
Finally, as a critical question, we ask whether our model's performance stems from choosing a stereotyped organism that is well studied and biologically simple, or whether our results imply a path towards generalizable/universal machine learning in neural systems. While the neurophysiology of C. elegans is quite complex, the motor action sequence we studied is relatively simple, especially in comparison to other organisms and cognitive functions. Moreover, organisms are adaptive and capable of learning new behavior, a fact not represented in our dataset. However, a recent study (51) measured neural dynamics in monkeys trained to perform action sequences and determined that learned latent dynamics live in low-dimensional manifolds that were conserved throughout the length of the study. By aligning latent dynamics, their model accurately decoded the actions of monkeys up to two years after the model was trained, despite changes in biology (e.g. neuron turnover, adaptation to implants). Consequently, we posit that techniques similar to those used in our model may broadly apply to more complex organisms and functions.
Conclusion
In this study, we examined the ability of neural networks to classify higher-order function and predict neuron-level dynamics. In addition, inspired by global organizational principles of behavior discovered in previous studies, we demonstrated the ability of neural networks to generalize to unseen organisms. Specifically, our models exceeded the performance of previous studies in behavioral state classification of C. elegans. Furthermore, our models successfully generalized to unseen organisms, both within the same study and in a separate experiment performed years later. We found that a simple MLP performs remarkably well on unseen organisms. Nevertheless, our graph neural network, which explicitly learns linear transformations of node features, matched or exceeded the performance of graph-agnostic models in all experiments.
We note that our results on generalization for both higher-order functions and neuron-level dynamics (macroscopic and microscopic) suggest wide applicability of our technique to numerous machine learning tasks in neuroscience and in hierarchical dynamical systems. A promising research direction is the hierarchical relationship between neuron-level and population-level dynamics. Breakthroughs in this direction may inform machine learning models working with population-level recording and imaging techniques, such as EEG or fMRI, which are readily available and widespread. In addition, in this study we only focused on simple machine learning tasks and imaging data taken under similar experimental conditions. Further studies may involve more complex tasks, such as those involving graded information in neural dynamics, changes in sensory stimuli, acquisition of learned behaviors, and higher-order functions comprising complicated sequences of behavior. From a machine learning perspective, the development of a recurrent graph neural network with a suitable attention kernel may greatly aid model performance.
Moreover, additional work is needed in examining and improving model performance on arbitrary sets of neurons, as neuron identification is experimentally challenging and limited to small systems. Finally, our results show that data augmentation through the inclusion of more individuals can significantly improve generalization performance in microscopic neural systems.
Conflict of Interest Statement
The authors declare that the research was conducted in the absence of any commercial or financial relationships that could be construed as a potential conflict of interest.
Author Contributions
Experiments and models were conceived by P.Y. Wang. S. Sapra assisted with the implementation of various algorithms. The manuscript was written and revised after numerous iterations by all the authors.
Funding
This work was supported by unrestricted funds to the Center for Engineered Natural Intelligence.
Acknowledgments
The authors thank the authors of (18) for graciously allowing reproductions of their figures. In addition, the authors thank the Zimmer Lab for making their data available online (data from (7) and (15) can be found here). P.Y. Wang is grateful to Ilya Valmianski for insightful discussion and guidance.
Data Availability Statement
The datasets (7) and (15) analyzed for this study can be found in the OSF repository here.
Bibliography
1. Jorge Golowasch, Mark S Goldman, LF Abbott, and Eve Marder. Failure of averaging in the construction of a conductance-based neuron model. Journal of neurophysiology, 87(2):1129–1131, 2002.
2. Astrid A Prinz, Dirk Bucher, and Eve Marder. Similar network activity from disparate circuit parameters. Nature neuroscience, 7(12):1345–1352, 2004.
3. Yves Frégnac. Big data and the industrialization of neuroscience: A safe roadmap for understanding the brain? Science, 358(6362):470–477, 2017.
4. Mark M Churchland, John P Cunningham, Matthew T Kaufman, Stephen I Ryu, and Krishna V Shenoy. Cortical preparatory activity: representation of movement or first cog in a dynamical machine? Neuron, 68(3):387–400, 2010.
5. Mark S Goldman, Jorge Golowasch, Eve Marder, and L. F. Abbott. Global structure, robustness, and modulation of neuronal models. Journal of Neuroscience, 21(14):5229–5238, 2001.
6. Robert Prevedel, Young-Gyu Yoon, Maximilian Hoffmann, Nikita Pak, Gordon Wetzstein, Saul Kato, Tina Schrödel, Ramesh Raskar, Manuel Zimmer, Edward S Boyden, et al. Simultaneous whole-animal 3d imaging of neuronal activity using light-field microscopy. Nature methods, 11(7):727–730, 2014.
7. Saul Kato, Harris S Kaplan, Tina Schrödel, Susanne Skora, Theodore H Lindsay, Eviatar Yemini, Shawn Lockery, and Manuel Zimmer. Global brain dynamics embed the motor command sequence of caenorhabditis elegans. Cell, 163(3):656–669, 2015.
8. John G White, Eileen Southgate, J Nichol Thomson, and Sydney Brenner. The structure of the nervous system of the nematode caenorhabditis elegans. Philos Trans R Soc Lond B Biol Sci, 314(1165):1–340, 1986.
9. Cornelia I Bargmann and Eve Marder. From the connectome to brain function. Nature methods, 10(6):483, 2013.
10. Lav R Varshney, Beth L Chen, Eric Paniagua, David H Hall, and Dmitri B Chklovskii. Structural properties of the caenorhabditis elegans neuronal network. PLoS Comput Biol, 7(2):e1001066, 2011.
11. Steven J Cook, Travis A Jarrell, Christopher A Brittin, Yi Wang, Adam E Bloniarz, Maksim A Yakovlev, Ken CQ Nguyen, Leo T-H Tang, Emily A Bayer, Janet S Duerr, et al. Whole-animal connectomes of both caenorhabditis elegans sexes. Nature, 571(7763):63–71, 2019.
12. Chentao Wen and Koutarou D Kimura. How do we know how the brain works?—analyzing whole brain activities with classic mathematical and machine learning methods. Japanese Journal of Applied Physics, 59(3):030501, 2020.
13. Gopal P Sarma, Chee Wai Lee, Tom Portegys, Vahid Ghayoomie, Travis Jacobs, Bradly Alicea, Matteo Cantarelli, Michael Currie, Richard C Gerkin, Shane Gingell, et al. Openworm: overview and recent advances in integrative biological simulation of caenorhabditis elegans. Philosophical Transactions of the Royal Society B, 373(1758):20170382, 2018.
14. Padraig Gleeson, David Lung, Radu Grosu, Ramin Hasani, and Stephen D Larson. c302: a multiscale framework for modelling the nervous system of caenorhabditis elegans. Philosophical Transactions of the Royal Society B: Biological Sciences, 373(1758):20170379, 2018.
15. Annika LA Nichols, Tomáš Eichler, Richard Latham, and Manuel Zimmer. A global brain state underlies c. elegans sleep behavior. Science, 356(6344), 2017.
16. Harris S Kaplan, Oriana Salazar Thula, Niklas Khoss, and Manuel Zimmer. Nested neuronal dynamics orchestrate a behavioral hierarchy across timescales. Neuron, 105(3):562–576, 2020.
17. Susanne Skora, Fanny Mende, and Manuel Zimmer. Energy scarcity promotes a brain-wide sleep state modulated by insulin signaling in c. elegans. Cell reports, 22(4):953–966, 2018.
18. Connor Brennan and Alexander Proekt. A quantitative model of conserved macroscopic dynamics predicts future motor commands. Elife, 8:e46814, 2019.
19. F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, and G. Monfardini. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
20. Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1263–1272, 2017.
21. Peter Battaglia, Jessica Blake Chandler Hamrick, Victor Bapst, Alvaro Sanchez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, Caglar Gulcehre, Francis Song, Andy Ballard, Justin Gilmer, George E. Dahl, Ashish Vaswani, Kelsey Allen, Charles Nash, Victoria Jayne Langston, Chris Dyer, Nicolas Heess, Daan Wierstra, Pushmeet Kohli, Matt Botvinick, Oriol Vinyals, Yujia Li, and Razvan Pascanu. Relational inductive biases, deep learning, and graph networks. arXiv, 2018.
22. Taco Cohen and Max Welling. Group equivariant convolutional networks. In International conference on machine learning, pages 2990–2999, 2016.
23. Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.
24. Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? arXiv preprint arXiv:1810.00826, 2018.
25. Nima Dehmamy, Albert-László Barabási, and Rose Yu. Understanding the representation power of graph neural networks in learning graph topology. In Advances in Neural Information Processing Systems, pages 15413–15423, 2019.
26. Thomas Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling, and Richard Zemel. Neural relational inference for interacting systems. In International Conference on Machine Learning, pages 2688–2697, 2018.
27. Sindy Löwe, David Madras, Richard Zemel, and Max Welling. Amortized causal discovery: Learning to infer causal graphs from time-series data, 2020.
28. David Raposo, Adam Santoro, David Barrett, Razvan Pascanu, Timothy Lillicrap, and Peter Battaglia. Discovering objects and their relations from entangled scene representations. arXiv preprint arXiv:1702.05068, 2017.
29. Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in neural information processing systems, pages 1024–1034, 2017.
30. Yue Wang, Yongbin Sun, Ziwei Liu, Sanjay E Sarma, Michael M Bronstein, and Justin M Solomon. Dynamic graph cnn for learning on point clouds. Acm Transactions On Graphics (tog), 38(5):1–12, 2019.
31. Bing Yu, Haoteng Yin, and Zhanxing Zhu. Spatio-temporal graph convolutional networks: A deep learning framework for traffic forecasting. arXiv preprint arXiv:1709.04875, 2017.
32. Yaguang Li, Rose Yu, Cyrus Shahabi, and Yan Liu. Diffusion convolutional recurrent neural network: Data-driven traffic forecasting. arXiv preprint arXiv:1707.01926, 2017.
33. Yu Zhang and Pierre Bellec. Functional annotation of human cognitive states using graph convolution networks. 2019.
34. Xiaoxiao Li and James Duncan. Braingnn: Interpretable brain graph neural network for fmri analysis. bioRxiv, 2020.
35. Byung-Hoon Kim and Jong Chul Ye. Understanding graph isomorphism network for brain mr functional connectivity analysis. arXiv preprint arXiv:2001.03690, 2020.
36. Clive WJ Granger. Investigating causal relations by econometric models and cross-spectral methods. Econometrica: journal of the Econometric Society, pages 424–438, 1969.
37. Peter Battaglia, Razvan Pascanu, Matthew Lai, Danilo Jimenez Rezende, et al. Interaction networks for learning about objects, relations and physics. In Advances in neural information processing systems, pages 4502–4510, 2016.
38. Joshua I Glaser, Ari S Benjamin, Roozbeh Farhoodi, and Konrad P Kording. The roles of supervised machine learning in systems neuroscience. Progress in neurobiology, 175:126–137, 2019.
39. Alexander Selvikvåg Lundervold and Arvid Lundervold. An overview of deep learning in medical imaging focusing on mri. Zeitschrift für Medizinische Physik, 29(2):102–127, 2019.
40. Colin J Brown and Ghassan Hamarneh. Machine learning on human connectome data from mri. arXiv preprint arXiv:1611.08699, 2016.
41. Colin J Brown, Jeremy Kawahara, and Ghassan Hamarneh. Connectome priors in deep neural networks to predict autism. In , pages 110–113. IEEE, 2018.
42. Gabriel A Silva. A New Frontier: The Convergence of Nanotechnology, Brain Machine Interfaces, and Artificial Intelligence. Frontiers in Neuroscience, 12:843, 2018. ISSN 1662-4548. doi: .
43. Pouya Bashivan, Irina Rish, Mohammed Yeasin, and Noel Codella. Learning representations from eeg with deep recurrent-convolutional neural networks. arXiv preprint arXiv:1511.06448, 2015.
44. No-Sang Kwak, Klaus-Robert Müller, and Seong-Whan Lee. A convolutional neural network for steady state visual evoked potential classification under ambulatory environment. PloS one, 12(2):e0172578, 2017.
45. Arthur Mensch, Julien Mairal, Danilo Bzdok, Bertrand Thirion, and Gaël Varoquaux. Learning neural representations of human cognition across many fmri studies. In Advances in neural information processing systems, pages 5883–5893, 2017.
46. Joseph G Makin, David A Moses, and Edward F Chang. Machine translation of cortical activity to text with an encoder–decoder framework. Technical report, Nature Publishing Group, 2020.
47. Xiang Zhang, Lina Yao, Xianzhi Wang, Jessica Monaghan, David Mcalpine, and Yu Zhang. A survey on deep learning based brain computer interface: Recent advances and new frontiers. arXiv preprint arXiv:1905.04149, 2019.
48. Z. F. Altun, L. A. Herndon, C. A. Wolkow, C. Crocker, R. Lints, and D. H. Hall. Worm atlas, 2002-2020.
49. Samy Bengio, Oriol Vinyals, Navdeep Jaitly, and Noam Shazeer. Scheduled sampling for sequence prediction with recurrent neural networks. In Advances in Neural Information Processing Systems, pages 1171–1179, 2015.
50. Peiran Gao and Surya Ganguli. On simplicity and complexity in the brave new world of large-scale neuroscience. Current opinion in neurobiology, 32:148–155, 2015.
51. Juan A Gallego, Matthew G Perich, Raeed H Chowdhury, Sara A Solla, and Lee E Miller. Long-term stability of cortical population dynamics underlying consistent behavior. Nature neuroscience, 23(2):260–270, 2020.
52. Matthias Fey and Jan E. Lenssen. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.
Wang et al. | Generalizable Machine Learning using GNNs | arXiv
53. Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. arXiv preprint arXiv:1710.10903, 2017.
54. Yujia Li, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493, 2015.
55. Weihua Hu, Bowen Liu, Joseph Gomes, Marinka Zitnik, Percy Liang, Vijay Pande, and Jure Leskovec. Strategies for pre-training graph neural networks. arXiv preprint arXiv:1905.12265, 2019.
56. Barry Horwitz. The elusive concept of brain connectivity. Neuroimage, 19(2):466–470, 2003.
Supplementary Information
Model and Experiments
Model Selection.
The two final models included in the main text were chosen for their performance and simplicity. Nevertheless, we experimented with numerous established models that could easily be substituted for f. For GNNs, we primarily used the PyTorch Geometric library (52). Tested modules included GIN-0/GIN-ε (24), GraphSAGE (29), GAT (53), and Global Attention (54). In particular, we expected the GIN to outperform the other modules because its expressiveness has been shown to aid transfer learning (55). However, because our edges are not explicitly known, we essentially applied the GIN on a fully connected graph. Under this formulation, the GIN-0 simply symmetrizes node features after a message-passing step, which is similar to the aggregation step of our MLP. We also found that the GIN-ε was prone to overfitting. Finally, we tested the GAT, which is similar to our model when edges are dynamically inferred at each timestep. We found that the GAT performs equally well on trajectory prediction but slightly worse on behavioral state classification.
Model Implementation.
The two-layer MLP corresponding to g in the main text consisted of linear layers followed by ReLU activation functions. We also applied batch normalization to the output of the two layers. The Node MLP in the main text refers to individual MLPs for each node. To construct RNN variants, we added an LSTM unit before the MLP.
We performed only minor hyperparameter optimization, as our combinatorial cross-validation was computationally expensive. Overall, we found our models relatively robust to different hyperparameters. For trajectory prediction, we used hidden layers with 256 dimensions; for behavioral state classification, we used hidden layers with 16 dimensions. Furthermore, we determined that dynamically evaluated edges (re-inferred at each timestep) worked better for trajectory prediction, whereas edges evaluated globally for each worm resulted in better performance for behavioral state classification. Finally, for trajectory prediction we optimized the mean squared error (MSE), and for behavioral state classification we optimized the negative log likelihood (NLL).
Experimental Procedures.
For the extended evaluation set, we chose data from the prelethargus phase, i.e. the part of larval development associated with higher-frequency pharyngeal pumping that precedes a cessation of pumping, during which the animal enters a brief lethargus. In this phase, 4 states were labeled: reverse, forward, dorsal turn, and ventral turn. For compatibility with the training dataset, we mapped reverse 1, reverse 2, and sustained reverse crawling to the reverse state. Similarly, we mapped forward crawling and forward slowing to forward. In addition to the 7 or 4 labeled states, there was another labeled state for unknown behavior or quiescence. This state comprised a very small portion of our data, and during training and evaluation we ignored the output whenever the target was unknown.
For all experiments in the main text, we performed 10-fold cross validation on all possible combinations of worms in our training set (Kato dataset). For example, for experiments trained on two worms, the possible combinations of worms are the following: {(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)}. Experiments labeled "Train on 2 worms" involved models trained separately on each of these combinations. Each combination then involved 10-fold cross validation, in which the test set was left out when performing hyperparameter optimization. In particular, for our experiments on behavioral state classification, we used 1 fold as the test ("leave-out") set and 1 fold as the validation set, which was used for optimization and as a metric for stopping training. In contrast, our experiments on trajectory prediction focused primarily on generalization performance rather than test set accuracy, so we used 1 fold as the validation set and evaluated on all worms in the extended evaluation set (Nichols dataset). As a note, we also attempted experiments where data from the extended dataset was used as a validation set.
Under this condition, we found that the MLP performed significantly better. However, we were concerned that the MLP was overfitting to the validation set, so we chose not to include those results.
We performed our experiments on an Intel i9-9900K CPU and an Nvidia GeForce RTX 2080 Ti graphics card. Since our models are relatively simple, we were able to train each model on the data from one worm in a single batch. Nevertheless, the number of worms and the cross-validation procedure made the experiments very computationally expensive: training and evaluating each model required roughly a week or two of continuous computation. For optimization, we used the Adam optimizer with a learning rate of − . We decayed the learning rate by a factor of 0.25 if the loss did not improve after 50 epochs. We then trained for 800 epochs and saved the model with the lowest validation loss. For scheduled sampling (used during trajectory prediction), we adopted a linear decay that terminated at 300 timesteps.
Additional Experiments.
We performed numerous experi-ments to verify our results and examine the performance ofour model on diverse machine learning tasks. We did not per-form rigorous cross validation for the following experiments.
Experiments without AVA.
Referees of (18) were concerned with behavioral state classification in which AVA neurons were included. In particular, these neurons were used by (7) to define behavioral states through trajectory clustering in latent space. The referees commented that classifying behavioral states with the neurons used to define those states was akin to circular reasoning. We note that (7) verified their assigned behavioral states through recorded videos, minimizing the risk that the assigned behavioral states differ from reality. Nevertheless, we followed (18) and performed an experiment excluding AVA neurons, in which we found no noticeable difference in model performance.
Fig. S1. Time derivatives of calcium traces projected onto each individual organism's principal components. Distinct loops correspond to manifolds in latent space, where colors correspond to behavior assigned in Kato et al. Reproduced with permission from (18).
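In practice, excluding AVA means removing the AVA traces, and therefore the corresponding graph nodes, before training. The minimal Python sketch below assumes a hypothetical data layout: a dictionary that maps each identified neuron name (for example AVAL and AVAR) to its calcium trace. The function name `exclude_neurons` is illustrative and is not taken from the authors' code.

```python
from typing import Dict, List, Sequence

# Hypothetical layout (not the authors' format): one calcium trace per
# identified neuron, keyed by its standard C. elegans name.
Traces = Dict[str, List[float]]


def exclude_neurons(traces: Traces, prefixes: Sequence[str] = ("AVA",)) -> Traces:
    """Return a new dict without neurons whose names start with any of `prefixes`.

    With the default prefix this drops AVAL and AVAR, so a graph built on the
    result has no nodes for the command interneurons used to define states.
    """
    return {
        name: trace
        for name, trace in traces.items()
        if not name.startswith(tuple(prefixes))
    }


if __name__ == "__main__":
    # Toy values for illustration only, not recorded data.
    worm = {
        "AVAL": [0.1, 0.4, 0.9],
        "AVAR": [0.2, 0.5, 0.8],
        "RIML": [0.0, 0.1, 0.0],
        "SMDVR": [0.3, 0.2, 0.1],
    }
    print(sorted(exclude_neurons(worm)))  # ['RIML', 'SMDVR']
```

Filtering by name prefix keeps the remaining neurons and their traces unchanged. The model therefore sees the same data minus the AVA pair.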
One-hot encoding of edges.
To enforce sparsity on the edges, we experimented with one-hot encoding by adding a scaling factor within the softmax. We found that our GNN achieved test accuracies similar to those reported in the main text. However, our GNN failed to generalize well to unseen worms. Following our discussion in the main text, we believe that one-hot encoding was detrimental to generalization because it effectively results in a permutation matrix, which simply permutes node features. This runs counter to previous studies, in which topological structures are related by more general linear transformations.
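The effect of the scaling factor can be illustrated with a small sketch. It assumes, for illustration only, that inferred edges enter the model as a row-normalized weight matrix A, obtained from a softmax over edge logits, and that aggregation is the matrix product A·H with the node-feature matrix H. The scaling factor appears as a hypothetical parameter `scale` (an inverse temperature). These names and the aggregation rule are assumptions, not the authors' implementation. As `scale` grows, each row of A approaches a one-hot vector, so every node simply copies the features of one other node. When those choices are all distinct, A is a permutation matrix.

```python
import math
from typing import List

Matrix = List[List[float]]


def scaled_softmax(row: List[float], scale: float) -> List[float]:
    """Softmax of scale * row; a large scale pushes the result toward one-hot."""
    z = [scale * x for x in row]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def edge_weights(logits: Matrix, scale: float) -> Matrix:
    """Row-normalized weights A[i][j]: how much node i aggregates from node j."""
    return [scaled_softmax(row, scale) for row in logits]


def aggregate(weights: Matrix, features: Matrix) -> Matrix:
    """Aggregation h'_i = sum_j A[i][j] * h_j, i.e. the matrix product A @ H."""
    dim = len(features[0])
    return [
        [sum(w * h[k] for w, h in zip(row, features)) for k in range(dim)]
        for row in weights
    ]


if __name__ == "__main__":
    # Hypothetical edge logits for 3 nodes; the row maxima sit in columns 2, 0, 1.
    logits = [[0.1, 0.2, 0.9], [0.8, 0.1, 0.3], [0.2, 0.7, 0.4]]
    features = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
    soft = aggregate(edge_weights(logits, scale=1.0), features)   # mixes all nodes
    hard = aggregate(edge_weights(logits, scale=50.0), features)  # ~ permuted rows
    print("soft:", [[round(x, 2) for x in row] for row in soft])
    print("hard:", [[round(x, 3) for x in row] for row in hard])  # [[5.0, 5.0], [1.0, 0.0], [0.0, 1.0]]
```

With `scale=1.0`, each output row blends all three nodes. With `scale=50.0`, the output is only a relabeling of the input rows. This is the behavior the paragraph above suggests limits generalization to unseen worms.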
Comparison of inferred edges to known connectome.
Inferring the connectivity between neurons in neural systems remains a key challenge in neuroscience. Because C. elegans is among the few organisms whose connectome is mostly or completely known, we compared the inferred edges of our model to the C. elegans connectome. Ultimately, we found no similarities between our inferred edges and the connectome.
In neuroscience, two types of connectivity are defined: structural and functional/effective. Structural connectivity refers to physical connections between neurons. Functional connectivity refers to statistical correlations between neurons, and effective connectivity refers to validated causal connections between neurons (56). The development of methods for determining functional and, in particular, effective connectivity remains an open challenge and a highly active area of research. Nonetheless, in the context of C. elegans, each worm generally has the same structural connectivity. However, differences in neural activity imply that each individual has a different functional connectivity. Since the connectome describes structural connectivity, we believe that our inferred edges are a poor proxy for the connectome. On a more abstract level, our graph neural network works with a subset of neurons. An inferred edge therefore may not correspond to a direct correlation but may instead represent higher-order correlations with unseen neurons.
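The section does not specify how the inferred edges were compared with the connectome. As one possible illustration, the sketch below builds a correlation-based (functional) edge set from calcium traces, following the definition above. It then measures the overlap of that set with a structural edge set using the Jaccard index. The correlation threshold, the Jaccard metric, the neuron names `n1`, `n2`, `n3`, and all values are hypothetical assumptions chosen for the example. They are not the authors' procedure or data.

```python
import math
from itertools import combinations
from typing import Dict, List, Set, Tuple

Edge = Tuple[str, str]


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation of two equal-length traces (0.0 if either is constant)."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0.0 or sy == 0.0:
        return 0.0
    return cov / (sx * sy)


def functional_edges(traces: Dict[str, List[float]], threshold: float) -> Set[Edge]:
    """Undirected edges (stored as sorted name pairs) with |correlation| >= threshold."""
    return {
        (a, b)
        for a, b in combinations(sorted(traces), 2)
        if abs(pearson(traces[a], traces[b])) >= threshold
    }


def jaccard(e1: Set[Edge], e2: Set[Edge]) -> float:
    """|intersection| / |union| of two undirected edge sets (1.0 if both are empty)."""
    a = {tuple(sorted(e)) for e in e1}
    b = {tuple(sorted(e)) for e in e2}
    union = a | b
    return len(a & b) / len(union) if union else 1.0


if __name__ == "__main__":
    # Toy traces for three hypothetical neurons; not real recordings.
    traces = {
        "n1": [1.0, 2.0, 3.0, 4.0],
        "n2": [2.0, 4.0, 6.0, 8.0],  # perfectly correlated with n1
        "n3": [4.0, 1.0, 3.0, 2.0],  # correlation -0.4 with n1 and n2
    }
    inferred = functional_edges(traces, threshold=0.8)  # {('n1', 'n2')}
    structural = {("n2", "n1"), ("n2", "n3")}           # hypothetical wiring
    print(inferred, jaccard(inferred, structural))      # {('n1', 'n2')} 0.5
```

A correlation edge can appear without a physical synapse, and a synapse can exist without a strong correlation. This is why a low overlap score is consistent with the argument above that functional and structural connectivity need not coincide.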