Improving Graph Property Prediction with Generalized Readout Functions
A Preprint
Eric Alcaide
University of Barcelona, Barcelona, Spain

Abstract
Graph property prediction has drawn increasing attention in recent years. Graphs are among the most general data structures, since they can contain an arbitrary number of nodes and connections between them, and they are the backbone of many classification and regression tasks on such data (networks, molecules, knowledge bases, ...). We introduce a novel generalized global pooling layer to mitigate the information loss that typically occurs at the readout phase in Message-Passing Neural Networks. This layer is parametrized by two values (β and p) which can optionally be learned, and the transformation it performs reverts to several already popular readout functions (mean, max and sum) under specific parameter settings. To showcase the expressiveness and performance of this technique, we test it on a popular graph property prediction task by taking the current best-performing architecture and using our readout layer as a drop-in replacement, and we report new state-of-the-art results. The code to reproduce the experiments can be accessed here: https://github.com/EricAlcaide/generalized-readout-phase

Geometric Deep Learning
There has been substantial progress during the past decades in Machine Learning and Deep Learning methods for analyzing large amounts of data to identify meaningful patterns that can be later used for classification, regression, clustering, data generation, etc. Areas that have experienced such progress include Computer Vision, Speech Recognition and Natural Language Processing.
However, most existing machine and deep learning algorithms require structured data defined on Euclidean domains (voxels, grids, etc.) to work with. Since several fields work with data that is not defined on Euclidean domains (physics, biology, complex systems, etc.), the need for algorithms that can handle graph- and/or manifold-structured data becomes apparent. In recent years, the term Geometric Deep Learning has been coined to refer to the extension of Deep Learning methods to graph/manifold data.
Graphs are one of the most general data structures, since they can represent Euclidean as well as non-Euclidean data. Formally, a graph G = (V, E) is composed of a set of vertices V (also known as the set of nodes N) representing the data points and a set of edges E which model the relations between the vertices. Moreover, graphs are mainly used to capture relational data (social networks, knowledge bases, image entities, maps, etc.) and are thus important for reasoning.

Graph Neural Networks
Graph Neural Networks were initially proposed by Scarselli et al. [1] and are a family of neural networks (universal approximators composed of trainable linear and non-linear transformations [2]) that map graphs to vector representations. This type of neural network makes use of generalized counterparts of traditional neural network functions to work with data with an arbitrary number of nodes and edges.
Recently, Gilmer et al. [3] proposed the Message Passing Neural Network (MPNN) scheme, which is invariant to graph isomorphism and makes use of messages to propagate information across adjacent nodes. The MPNN architecture is comprised of two stages: the message passing phase and the readout phase. Formally:

m_\nu^{t+1} = \sum_{w \in N(\nu)} M_t(h_\nu^t, h_w^t, e_{\nu w})    (1)

h_\nu^{t+1} = U_t(h_\nu^t, m_\nu^{t+1})    (2)

\hat{y} = R(\{ h_\nu^T \mid \nu \in G \})    (3)

In the message passing phase (Eqs. 1 and 2), the network learns an embedding vector for each node of the graph by iteratively aggregating and propagating information across adjacent nodes. This part is done in T sequential stages, where the network performs two functions at each stage: the message function (Eq. 1) and the update function (Eq. 2). Since graphs can have an arbitrary number of nodes and edges, the resulting embedding vectors depend on the original graph, which creates the need for a function that summarizes the arbitrary number of embedding vectors into a fixed vector representation, named the readout function (Eq. 3). This fixed output can later be passed to a machine learning algorithm of choice to perform different downstream tasks such as classification, regression, clustering, etc.

Readout functions.
A highly desired property for readout functions is permutation invariance, since graphs are unordered in nature. Commonly used readout functions are the Mean(·), Max(·) and Sum(·) functions. However, these functions suffer from information loss at the readout phase. Moreover, the choice of readout function is sensitive to the nature of the dataset, which makes it difficult to choose the optimal function for a specific task and often leads to suboptimal performance.
Some attempts to mitigate the information loss at the readout phase have been made, consisting mainly of three types of techniques:
• Fuzzy histograms, which aggregate the arbitrary vector embeddings by calculating membership probabilities to several fuzzy bins, were implemented by [4].
• Hierarchical clustering, which exploits hierarchical structures present in graphs, was used in [5].
• Multi-Layer Perceptrons, which can also be interpreted as weighted sums of node features and are highly expressive, but at the cost of not being order-invariant, were reported in [6].
However, the hitherto attempted methods exploit domain-specific graph attributes (e.g. hierarchical structures, order dependency, ...) and have not been compared to modern state-of-the-art architectures, which use simple Mean(·), Max(·) or Sum(·) functions and exhibit better performance.

Contributions
The main contribution of this work is a novel readout layer composed of generalized mean-max-sum function families, which substantially improves the performance of state-of-the-art GNN architectures and which can be used as a drop-in replacement for existing readout functions with very limited computational overhead. We achieve a state-of-the-art result (improvement by 2%) on the ogb-molhiv dataset [7].

Generalized Functions. A generalized function f_θ(·) is one parametrized by a set of values θ for which there exists a set of special cases S_θ under which the function reverts to another, simpler function. Generalized functions are useful for defining continuous, intermediate transitions between different existing functions, thus becoming more expressive than their special counterparts. Given this increased expressiveness, we hypothesize that a generalized function can help mitigate the information loss in the readout phase from which Message Passing Neural Networks suffer.
We define a Generalized Readout layer parametrized by two values (β and p) that is able to represent the majority of the existing and commonly used readout functions (mean, max, sum, min, ...) while maintaining permutation and ordering invariance. Our generalized layer reverts to commonly used functions under certain simple parameter settings, and it is also fully differentiable, which allows training the parameters via gradient-based algorithms such as the popular backpropagation.

Proposition 1 (Generalized Mean-Max-Sum Aggregators). A generalized Mean-Max-Sum aggregation function f_θ is one for which there exist at least three combinations θ_1, θ_2, θ_3 of its parameters θ such that lim_{θ→θ_1} f_θ = Mean(·), lim_{θ→θ_2} f_θ = Max(·) and lim_{θ→θ_3} f_θ = Sum(·).
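The claim that such a layer is fully differentiable in β and p can be sanity-checked numerically. Below is a sketch for the softmax-based family (Eq. 4) with scalar node features, an assumption made here for brevity; the function name and the toy inputs are illustrative, not from the paper.

```python
import numpy as np

# Softmax-based generalized readout (Eq. 4) for scalar node features,
# plus central finite differences showing it is smooth in beta and p.
def softmax_readout(x, beta, p):
    n = len(x)
    w = np.exp(p * x)
    w /= w.sum()                              # softmax over nodes
    return n / (1.0 + beta * (n - 1)) * np.sum(w * x)

x = np.array([0.5, -1.0, 2.0])                # illustrative node embeddings
eps = 1e-6

# d(readout)/d(beta): increasing beta shrinks the prefactor
d_beta = (softmax_readout(x, 1.0 + eps, 1.0) -
          softmax_readout(x, 1.0 - eps, 1.0)) / (2 * eps)

# d(readout)/d(p): increasing p shifts weight toward larger embeddings
d_p = (softmax_readout(x, 1.0, 1.0 + eps) -
       softmax_readout(x, 1.0, 1.0 - eps)) / (2 * eps)
```

Both gradients are finite and well-behaved, which is what allows β and p to be trained jointly with the rest of the network.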
In order to cover the three most popular readout phase functions (mean, max and sum), we design two mean-max-sum generalized function families (Proposition 1). Inspired by [8], we use two different families of generalized mean-max-sum functions: a softmax-based aggregation and a power mean-based aggregation. Both function families are parametrized by two parameters (β and p), which are optionally learnable and can be set to an initial value of choice. Both families return graph-level outputs by transforming node features with a generalized mean-max-sum aggregation, so that for a single graph G_i the output is computed by Eq. 4 for the Softmax family or by Eq. 5 for the PowerMean family. Let N_i be the number of nodes in graph G_i, r_i the output of the readout function, and x_n the embedding of the n-th node from the message passing phase:

r_i = \frac{N_i}{1 + \beta (N_i - 1)} \sum_{n=1}^{N_i} \mathrm{softmax}(p \cdot x)_n \, x_n    (4)

r_i = \left( \frac{1}{1 + \beta (N_i - 1)} \sum_{n=1}^{N_i} x_n^p \right)^{1/p}    (5)

Here we detail some of the special cases of the two function families.

The softmax-based aggregation family reverts to:
• the Mean function when β = 1 and p = 0
• the Max function when β = 1 and p → +∞
• the Min function when β = 1 and p → −∞
• the Sum function when β = 0 and p = 0

The power mean-based aggregation family reverts to:
• the Mean function when β = 1 and p = 1
• the Max function when β = 0 and p → +∞
• the Min function when β = 0 and p → −∞
• the Sum function when β = 0 and p = 1

Figure 1: Landscape of generalized Mean-Max-Sum functions
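The special cases above can be verified with a minimal sketch of the two families for scalar node features (function names and test values are illustrative; the paper's released implementation is the PyTorch Geometric layer in the linked repository):

```python
import numpy as np

def softmax_agg(x, beta, p):
    """Softmax-based generalized readout (Eq. 4) for scalar features."""
    n = len(x)
    w = np.exp(p * x - np.max(p * x))        # numerically stable softmax
    w /= w.sum()
    return n / (1.0 + beta * (n - 1)) * np.sum(w * x)

def powermean_agg(x, beta, p):
    """Power mean-based generalized readout (Eq. 5), positive features."""
    n = len(x)
    return (np.sum(x ** p) / (1.0 + beta * (n - 1))) ** (1.0 / p)

x = np.array([1.0, 2.0, 4.0])                # illustrative node embeddings

# Softmax family special cases
mean_s = softmax_agg(x, beta=1.0, p=0.0)     # -> Mean(x) = 7/3
sum_s  = softmax_agg(x, beta=0.0, p=0.0)     # -> Sum(x) = 7
max_s  = softmax_agg(x, beta=1.0, p=50.0)    # -> Max(x) ~ 4 (large p)

# PowerMean family special cases
mean_p = powermean_agg(x, beta=1.0, p=1.0)   # -> Mean(x) = 7/3
sum_p  = powermean_agg(x, beta=0.0, p=1.0)   # -> Sum(x) = 7
max_p  = powermean_agg(x, beta=0.0, p=50.0)  # -> Max(x) ~ 4 (large p)
```

The max and min cases are limits, so finite p only approximates them; p = 50 already matches the maximum to several decimal places on this toy input.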
The power mean-based aggregation function family can also revert to different mean functions (quadratic mean, geometric mean, harmonic mean, ...).
Another advantage of the two function families used is that they both remain invariant to permutations of the node ordering and support an arbitrary number of messages to be aggregated, two highly desired properties in readout functions.
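Both properties can be checked directly on a sketch of the power mean family (scalar positive features assumed for brevity; the helper name and inputs are illustrative):

```python
import numpy as np

def powermean_agg(x, beta, p):
    # Power mean-based generalized readout (Eq. 5) for positive scalars.
    return (np.sum(x ** p) / (1.0 + beta * (len(x) - 1))) ** (1.0 / p)

rng = np.random.default_rng(0)
x = rng.uniform(0.5, 2.0, size=7)            # arbitrary number of nodes

# Permutation invariance: shuffling the nodes leaves the readout unchanged
same = np.isclose(powermean_agg(x, 1.0, 3.0),
                  powermean_agg(rng.permutation(x), 1.0, 3.0))

# Variable-size support: the same function handles a different node count
r_small = powermean_agg(x[:3], 1.0, 3.0)
```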
Experimental Setup. The expressiveness and performance of the proposed readout phase function are tested by comparing the best-performing GCNN on a given graph property prediction dataset (ogb-molhiv [7]) with the Mean(·) readout phase, as indicated by [9], against our new functions under a drop-in replacement scheme.
The ogb-molhiv [7, 10] dataset is a graph property prediction dataset comprised of 41,127 molecules, and the goal is binary classification of their experimentally-checked ability to inhibit HIV virus replication. The evaluation metric for this dataset is the ROC-AUC.
The computational resources used for the experiments were:
• Hardware: All the experiments in this work were carried out on a computer with an Intel i7-6700K 4.0 GHz CPU and 16 GB of RAM, with an Nvidia GTX 1060 6 GB graphics card.
• Software: In order to design and train the Deep Learning models, we employed the PyTorch [11] and PyTorch Geometric [12] Python libraries.
Our new technique has been implemented as a PyTorch Geometric [12] layer for ease of use. The code to reproduce the experiments can be accessed here: https://github.com/EricAlcaide/generalized-readout-phase
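The drop-in nature of the replacement can be illustrated with a framework-free sketch (the actual implementation is the PyTorch Geometric layer in the linked repository; the names and the toy head below are illustrative assumptions):

```python
import numpy as np

def softmax_readout(beta, p):
    """Factory returning a generalized softmax readout (Eq. 4)."""
    def readout(h):
        n = len(h)
        w = np.exp(p * h - np.max(p * h))    # stable softmax over nodes
        w /= w.sum()
        return n / (1.0 + beta * (n - 1)) * np.sum(w * h)
    return readout

def predict(node_embeddings, readout=np.mean):
    # Stand-in for a trained GNN head: readout, then a fixed linear layer.
    return 2.0 * readout(node_embeddings) + 1.0

h = np.array([0.2, 0.9, 0.4])                # toy node embeddings

y_mean = predict(h)                                            # Mean(.)
y_gen  = predict(h, readout=softmax_readout(beta=1.0, p=0.0))  # drop-in
```

With β = 1 and p = 0 the generalized readout reproduces the mean, so swapping it in changes nothing; other settings of β and p then expose the rest of the mean-max-sum landscape without touching the surrounding model.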
Experiment 1. In order to assess the performance of our novel technique, we performed a drop-in replacement of the readout phase in the original Hierarchical Inter-Message Passing GNN (HIMP) [9] architecture for the ogb-molhiv dataset [7], which was the best-performing open-source model on the OGB leaderboard [13] at the time of writing. We trained the new model architecture for 10 independent runs of 100 epochs each. We used the Adam optimizer [14] with the following hyperparameters: lr=1e-4, β1=0.9, β2=0.999, ε=1e-8. The hyperparameter combinations that were tried were:
• Softmax (β=1, p=1) and Softmax (β=1e-5, p=1e-5)
• PowerMean (β=1, p=1) and PowerMean (β=1e-5, p=10)
The results from the experiment are showcased in Table 1 and the comparison with the current SoTA is given in Table 2. The final metric for every configuration over each of the 10 independent runs is reported, alongside the mean ± std. For comparison with previous models, we use the best-performing model from Table 1.

Table 1: Results from Experiment 1 (ROC-AUC mean ± std) for the Softmax (β=1, p=1), Softmax (β=1e-5, p=1e-5), PowerMean (β=1, p=1) and PowerMean (β=1e-5, p=10) configurations.
Table 2: Comparison between our model (HIMP+GenReadout) and previous SoTA: GraphNorm [15], HIMP [9], DeeperGCN [8] and GIN+VirtualNodes [13] (ROC-AUC mean ± std).

Our results from Experiment 1 represent a new state-of-the-art result on the ogb-molhiv dataset [7], with an increase of 5.4% over the next top-performing model. Remarkably, all the different hyperparameter combinations tried exhibited better performance than previous state-of-the-art architectures. We attribute this to the superior expressiveness of our new function, which might help mitigate the information loss that usually takes place at the readout phase.
We encourage the adoption of this framework of generalized readout functions for graph property prediction, especially in different GNN architectures and datasets bigger than ogb-molhiv [7], since its superior performance has so far been proven on a small dataset, probably because this novel technique mitigates the information loss typically occurring at the readout phase in MPNNs.
Due to computational limitations, it has not been possible to test this novel technique on bigger and perhaps more challenging datasets, or on other architectures which have recently exhibited good performance. Thus, its superior performance on bigger datasets cannot be guaranteed, only expected by extrapolation. This is likely to be addressed in future work.

In this work, we have introduced a generalized readout function, which can revert to many already popular functions under certain parameter settings, to address the information loss problem that typically occurs in Graph Neural Networks during the readout phase. We have proven its superior expressiveness and performance by achieving new SoTA results on a standardized graph property prediction task from the Open Graph Benchmark [7]. The scalability of the proposed technique to bigger datasets and different model architectures is to be addressed in future work.
References

[1] F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, and G. Monfardini. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
[2] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.
[3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. Neural message passing for quantum chemistry, 2017.
[4] Steven Kearnes, Kevin McCloskey, Marc Berndl, Vijay Pande, and Patrick Riley. Molecular graph convolutions: moving beyond fingerprints. Journal of Computer-Aided Molecular Design, 30(8):595–608, Aug 2016.
[5] Mikael Henaff, Joan Bruna, and Yann LeCun. Deep convolutional networks on graph-structured data, 2015.
[6] Joan Bruna, Wojciech Zaremba, Arthur Szlam, and Yann LeCun. Spectral networks and locally connected networks on graphs, 2013.
[7] Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. Open graph benchmark: Datasets for machine learning on graphs. arXiv preprint arXiv:2005.00687, 2020.
[8] Guohao Li, Chenxin Xiong, Ali Thabet, and Bernard Ghanem. DeeperGCN: All you need to train deeper GCNs, 2020.
[9] M. Fey, J. G. Yuen, and F. Weichert. Hierarchical inter-message passing for learning on molecular graphs. In ICML Graph Representation Learning and Beyond (GRL+) Workshop, 2020.
[10] Zhenqin Wu, Bharath Ramsundar, Evan N. Feinberg, J. Gomes, Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, and V. Pande. MoleculeNet: a benchmark for molecular machine learning. Chemical Science, 9:513–530, 2018.
[11] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, pages 8024–8035. Curran Associates, Inc., 2019.
[12] Matthias Fey and Jan E. Lenssen. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.
[13] Leaderboards for graph property prediction. https://ogb.stanford.edu/docs/leader_graphprop/
[14] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2014.
[15] Tianle Cai, Shengjie Luo, Keyulu Xu, Di He, Tie-Yan Liu, and Liwei Wang. GraphNorm: A principled approach to accelerating graph neural network training, 2020.