Proceedings of Machine Learning Research 1–7
Adversarial Contrastive Pre-training for Protein Sequences
Matthew B.A. McDermott [email protected]
Brendan Yap [email protected]
Tzu Ming Harry Hsu [email protected]
Di Jin [email protected]
Peter Szolovits [email protected]
CSAIL, MIT
Abstract
Recent developments in Natural Language Processing (NLP) demonstrate that large-scale, self-supervised pre-training can be extremely beneficial for downstream tasks. These ideas have been adapted to other domains, including the analysis of the amino acid sequences of proteins. However, to date most attempts on protein sequences rely on direct masked language model style pre-training. In this work, we design a new, adversarial pre-training method for proteins, extending and specializing similar advances in NLP. We show compelling results in comparison to traditional MLM pre-training, though further development is needed to ensure the gains are worth the significant computational cost.
Keywords: Protein pre-training, adversarial methods, pre-training, contrastive estimation, transformers
1. Introduction
Pre-training, particularly using a self-supervised masked language model (MLM) task over a large corpus, has recently emerged as a powerful tool to improve various prediction and generation tasks, first in natural language processing (NLP) via systems like BERT (Devlin et al., 2019), and later in other domains, including biomedical domains such as protein sequences¹ (Rao et al., 2019; Alley et al., 2019; Conneau et al., 2019; Lu et al., 2020). However, unlike NLP, where newer methods have brought improved performance, the state of the art for pre-training on protein sequences remains simple MLM-style pre-training. For many tasks of interest, this offers only small benefits over prior methods, or even fails to match methods based on human-constructed, alignment-based features.

In this work, we design a new pre-training method for protein sequences, adapting the traditional MLM task by replacing the random masking scheme with a fully differentiable, adversarial masking model trained to choose which tokens to mask and how, in order to make the pre-training model's recovery task most difficult subject to a budget constraint. This method is a form of adversarial contrastive estimation (Bose et al., 2018), building on the ideas explored in ELECTRA within NLP (Clark et al., 2019), but fully differentiable and trained adversarially, a task that is made more feasible given the protein domain's smaller vocabulary. Using an adversarial mask proposal distribution has also been recently explored in connection with reducing gradient variance (Chen et al., 2020).
1. Proteins are biological macromolecules responsible for the majority of functions within living cells, represented by linear chains of "amino acids," over which we perform MLM-style pre-training.

We test our system on the TAPE protein pre-training benchmark (Rao et al., 2019), achieving modest improvements over comparably trained random pre-training benchmarks, though further development will be needed to ensure this method is worth the increased computational cost. All our code will also be made public after review.
2. Methods
Traditionally, to learn a masked language model M_PT, we begin with a large, unlabeled corpus of sequences and form training examples by randomly choosing a fraction (e.g., 15% in the case of BERT (Devlin et al., 2019)) of tokens to "mask." The "masked" tokens are traditionally noised according to three masking strategies in an 80-10-10 ratio: [MASK] masking, in which the masked tokens are replaced with a sentinel, out-of-vocabulary token [MASK]; keep-original masking, in which the masked tokens are kept as the original token, but the model is still scored on its ability to form a correct language model prediction for these tokens; and replace masking, in which the masked tokens are replaced with another random valid token from the vocabulary. Given a sentence from the input corpus masked according to this process, the pre-training model M_PT is then tasked to recover the original sentence. This random scheme is sketched in code below.
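For concreteness, a minimal PyTorch sketch of this random scheme follows. The function name and the assumption that x is a (batch, length) tensor of token IDs are ours for illustration, not part of the BERT or TAPE codebases:

import torch

def random_mlm_mask(x, mask_id, vocab_size, mask_rate=0.15):
    """BERT-style random masking: choose `mask_rate` of tokens, then noise
    them 80% as [MASK], 10% kept as-is, 10% replaced with a random token."""
    selected = torch.rand(x.shape) < mask_rate      # which tokens to "mask"
    strategy = torch.rand(x.shape)                  # how to noise each one
    x_noised = x.clone()
    x_noised[selected & (strategy < 0.8)] = mask_id           # [MASK] masking
    # keep-original masking (0.8 <= strategy < 0.9): token left unchanged,
    # but still a scored prediction target via `selected`
    replace = selected & (strategy >= 0.9)                    # replace masking
    x_noised[replace] = torch.randint(vocab_size, x.shape)[replace]
    return x_noised, selected    # `selected` marks the scored positions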
In this work, we generalize this paradigm by introducing a model M_noiser to decide which tokens to mask and how. M_noiser is broken down into two parts: a sequence-to-sequence model yielding masking probabilities, and a budgeted differentiable sampling module to actually choose the mask. The whole model is trained end to end via an adversarial approach, such that it learns to mask tokens in a way that makes it most difficult for M_PT to correct. This setup is shown pictorially in Figure 1. Our full adversarial masker M_noiser algorithm is given in pseudocode in the Appendix, Section A, Algorithms 1 and 2, but we will detail several points further here.

Sequence-to-sequence model
The sequence-to-sequence model M^(seq)_noiser (in this work implemented with a gated recurrent unit (GRU) (Cho et al., 2014) architecture) ingests a sequence of (unmasked) amino acids x and returns a sequence of unnormalized, multi-dimensional masking scores, s = M^(seq)_noiser(x). s contains, per token, a score for (1) whether or not to mask that token at all (we will denote this the any-mask score), and (2) a conditional score for each masking option, including [MASK] masking, keep-original masking, and replace masking, with s taking on a score for each possible token that could be used as the replacement mask within the vocabulary (we will refer to these collectively as the mask-options scores). These are then passed into our differentiable sampler to obtain a hard, masked sample.
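A minimal PyTorch sketch of one such scoring network is below. The layer sizes mirror those reported in our experiments, but the class name, head structure, and exact score layout are illustrative assumptions rather than our exact implementation:

import torch.nn as nn

class MaskScorer(nn.Module):
    """Maps a token sequence to per-token masking scores: channel 0 is the
    any-mask score; the remaining channels are the mask-options scores
    ([MASK], keep-original, and one replace score per vocabulary token)."""
    def __init__(self, vocab_size, embed_dim=1024, hidden_dim=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=3, batch_first=True)
        # 1 any-mask score + 2 ([MASK], keep) + vocab_size replace scores
        self.score_head = nn.Linear(hidden_dim, 1 + 2 + vocab_size)

    def forward(self, x):                  # x: (batch, seq_len) token IDs
        h, _ = self.gru(self.embed(x))     # h: (batch, seq_len, hidden_dim)
        return self.score_head(h)          # s: (batch, seq_len, 3 + vocab_size)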
Differentiable Sampling Options
We use a slight adaptation of the relaxed subset selection algorithm of Xie and Ermon (2019) to transform the any-mask scores of s into a vector of normalized probabilities that have two important properties: first, they will average to our masking budget constraint, µ, and, second, they will, in expectation, converge to their sampled, one-hot approximations (i.e., these probabilities will be a member of a Concrete distribution). This algorithm frames the sampling problem as choosing a fixed-size subset of items from the valid tokens in each sentence, guided by the provided unnormalized scores, through repeated application of the Gumbel-Softmax (GS) trick (Maddison et al., 2016; Jang et al., 2016). This process is outlined in pseudocode in Appendix Algorithm 1. Next, to decide how to mask these tokens, conditioned on the token being masked, we use GS normalization directly on the mask-options scores. Finally, to obtain differentiable one-hot outputs, we use the straight-through estimator (Bengio et al., 2013), which simply sets the gradient of the output hard sample to be equal to the gradient of its source probability, directly, and is especially well suited for probabilities from the Concrete distribution, which converge in expectation to their sampled values.

Figure 1: Left: Traditional random masked language model pre-training. Right: Our architecture, adversarial contrastive language model pre-training. Note one can use various options for the Language Model, Masking Sequence Model, or Differentiable Sampler components.
Stabilizing the learning process
To stabilize the learning process and prevent the masker and MLM model from getting stuck in a local regime of masking, we additionally add in a small fraction of random masking, which enables the MLM model to constantly make general progress and, in turn, forces the masking model to constantly adapt its task to be more difficult than random masking. In order to ensure that some tokens were consistently masked from both varieties, we increased the general masking rate from 15% total to 20% total, distributed as 10% random masking and 10% adversarial masking in our system.
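In code, combining the two mask sources amounts to a per-batch union (a sketch; the names and the Boolean-mask representation are assumptions):

import torch

def combined_mask(adv_mask, random_rate=0.10):
    """Union of the 10% adversarial mask with an independent 10% random
    mask, giving roughly 20% total masking (slightly less where they overlap)."""
    rand_mask = torch.rand(adv_mask.shape) < random_rate
    return adv_mask | rand_mask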
Overall Training Algorithm
We then train the system in a traditional adversarial pattern, alternating between several iterations of training the masker to maximize the MLM loss, followed by several iterations of training the protein encoder to minimize said loss. We train via iterated stochastic gradient descent (using the AdamW optimizer (Loshchilov and Hutter, 2018)), with 10 iterations of noiser training followed by 10 iterations of encoder training; these values were determined after a very brief search over possible alternates, using MLM training curve metrics and apparent learning stability to motivate that choice. Additional details about our overall training algorithm are present in supplementary materials Section A.
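In outline, the alternation can be sketched as below, assuming noiser and encoder modules and an mlm_loss function (all illustrative names); the full procedure is given in Appendix Algorithms 4 and 5:

import itertools
import torch

def adversarial_pretrain(noiser, encoder, batches, mlm_loss, n_iters=10):
    """Alternate n_iters of noiser updates (gradient ascent on the MLM loss)
    with n_iters of encoder updates (gradient descent on the same loss)."""
    opt_noiser = torch.optim.AdamW(noiser.parameters())
    opt_encoder = torch.optim.AdamW(encoder.parameters())
    for phase in itertools.cycle(["noiser", "encoder"]):  # run until converged
        for _ in range(n_iters):
            x, valid = next(batches)
            loss = mlm_loss(encoder(noiser(x, valid), valid), x, valid)
            opt = opt_noiser if phase == "noiser" else opt_encoder
            opt.zero_grad()
            (-loss if phase == "noiser" else loss).backward()
            opt.step()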
Data & Tasks
We use 4 tasks from the TAPE benchmarking datasets (Rao et al., 2019), profiled in Table 1. For full details of these datasets and tasks, we encourage readers to refer to Rao et al. (2019).
Experiments
We compare our adversarial MLM model to a random MLM model (both at 20% total masking). For our adversarial MLM, M^(seq)_noiser is a 3-layer GRU with a 1024-dimensional input embedding layer and a 512-dimensional output layer, and our transformer sequence model is a transformer architecture matching the size of that profiled in the TAPE system. Our random MLM system encoder is an identical transformer architecture.

Table 1: A numerical summary of the datasets & tasks used in this work (Rao et al., 2019).

Task                 Train   Val.   Test
Language Modeling    32.2M   N/A    2.1M
Secondary Structure  8678    2170   513
Remote Homology      12312   736    718
Fluorescence         21446   5362   27217
Stability            53679   2447   12839

Models were trained on 4 NVIDIA V100 GPUs, ranging in time per model but on the order of roughly 1M encoder iterations at a batch size of 128 for approximately 80% of training, followed by specialization at a larger batch size of 256 using gradient accumulation. Neither memory-saving gradients nor mixed-precision training were used. During pre-training, the system was optimized via the AdamW optimizer (Loshchilov and Hutter, 2018), with weight decay and learning rate following the conventions of TAPE.

Hyperparameter tuning was performed in a limited, manual manner on train-set MLM results for the pre-training system, and using a grid search on validation-set results over batch size, learning rate, weight decays, and early-stopping parameters for fine-tuning (using fixed pre-trained models).
3. Results & Discussion
All of our results across all data settings are shown in Table 2. One can see we obtain improvements over random pre-training on 3 of 4 tasks, though in all cases changes are mild. Given the increased computational cost of this style of training (approximately 2x due to masker and encoder training), these gains are likely not currently worth these minor improvements. However, they do establish that this direction of research may be a viable vehicle to improve protein pre-training, with further improvements. Several directions stand out to offer such improvements. Firstly, we believe it is likely that higher-capacity noising models would work better, though early attempts with this architecture proved too unstable to train effectively. Second, we feel we could more effectively train this system with larger batch sizes (which has been shown to offer improvements in other pre-training contexts) through the use of mixed-precision training and gradient checkpointing. Finally, the incorporation of an importance-sampling re-weighting penalty on the learning objective, rather than use of partial random masking, in the style of Chen et al. (2020), may offer further improvements here.

Table 2: Final results for a random MLM and our adversarial MLM, reported in accuracy per amino acid for secondary structure (S.S.), accuracy per sequence for remote homology (R.H.), and Spearman correlation coefficient for fluorescence and stability. Variance is sourced over repeated fine-tuning (FT) runs only, not pre-training (PT) runs, as the latter would be computationally intractable.

Task          Random   Adv. (Ours)
S.S.          . ± .    . ± .
R.H.          . ± .    . ± .
Fluorescence  . ± .    . ± .
Stability     . ± .    . ± .
4. Conclusion
In this work, we design a novel adversarial, contrastive contextual embedding system for protein sequences, attaining improvements over comparable random pre-training runs on three of four tasks in the TAPE benchmark, though these improvements are minor and further work will be needed to ensure realized gains are worth the increased computational cost.

References
Ethan C. Alley, Grigory Khimulya, Surojit Biswas, Mohammed AlQuraishi, and George M. Church. Unified rational protein engineering with sequence-based deep representation learning. Nature Methods, 16(12):1315–1322, December 2019. doi: 10.1038/s41592-019-0598-1.

Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv:1308.3432 [cs], August 2013. URL http://arxiv.org/abs/1308.3432.

Avishek Joey Bose, Huan Ling, and Yanshuai Cao. Adversarial contrastive estimation. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1021–1032, Melbourne, Australia, July 2018. Association for Computational Linguistics. doi: 10.18653/v1/P18-1094.

Liang Chen, Tianyuan Zhang, Di He, Guolin Ke, Liwei Wang, and Tie-Yan Liu. Variance-reduced language pretraining via a mask proposal network. arXiv preprint arXiv:2008.05333, 2020.

Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv:1406.1078 [cs, stat], September 2014. URL http://arxiv.org/abs/1406.1078.

Kevin Clark, Minh-Thang Luong, Quoc V. Le, and Christopher D. Manning. ELECTRA: Pre-training text encoders as discriminators rather than generators. September 2019. URL https://openreview.net/forum?id=r1xMH1BtvB.

Alexis Conneau, Kartikay Khandelwal, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer, and Veselin Stoyanov. Unsupervised cross-lingual representation learning at scale. arXiv:1911.02116 [cs], November 2019. URL http://arxiv.org/abs/1911.02116.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics. doi: 10.18653/v1/N19-1423.

Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-Softmax. November 2016. URL https://openreview.net/forum?id=rkE3y85ee.

Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. September 2018. URL https://openreview.net/forum?id=Bkg6RiCqY7.

Amy X. Lu, Haoran Zhang, Marzyeh Ghassemi, and Alan Moses. Self-supervised contrastive learning of protein representations by mutual information maximization. bioRxiv, 2020.

Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete distribution: A continuous relaxation of discrete random variables. November 2016. URL https://openreview.net/forum?id=S1jE5L5gl.

Roshan Rao, Nicholas Bhattacharya, Neil Thomas, Yan Duan, Peter Chen, John Canny, Pieter Abbeel, and Yun Song. Evaluating protein transfer learning with TAPE. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, pages 9689–9701. Curran Associates, Inc., 2019. URL http://papers.nips.cc/paper/9163-evaluating-protein-transfer-learning-with-tape.pdf.

Sang Michael Xie and Stefano Ermon. Reparameterizable subset sampling via continuous relaxations. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, IJCAI-19, pages 3919–3925, 2019.

Appendix A. Full Methods
Pseudocode algorithms for the full masker, the straight-through sampler, and the full adversarial system, including masker and protein encoder, are shown in the Algorithms below.
Procedure rss_sampler(s, valid_tokens_mask, ρ, t)
Input: s: per-element selection scores; valid_tokens_mask: binary mask highlighting valid inputs within the batch; ρ: desired masking fraction; t > 0: sampling temperature.
Result: p: relaxed per-token selection probabilities.
  Initialize ε to a small positive constant, y_soft ← 0
  g ← s − log(−log(Uniform(0, 1)))
  seq_lens ← valid_tokens_mask.sum(axis=1)
  subset_sizes ← round(seq_lens * ρ)
  subset_sizes_left ← subset_sizes.expand_as(valid_tokens_mask)
  while subset_sizes_left > 0 do
    khot_mask ← max((1 − y_soft) * valid_tokens_mask, ε)
    g += log(khot_mask)
    y_soft += softmax(g / t, dim = −1)
    subset_sizes_left −= 1
    valid_tokens_mask *= (subset_sizes_left > 0).float()
  end
  return y_soft

Algorithm 1: Our variable masking rate per-batch variant on the relaxed subset selection algorithm of Xie and Ermon (2019).
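For reference, a PyTorch rendering of Algorithm 1 might look as follows. The eps value is an assumption, and we gate finished sequences with an active mask rather than zeroing valid_tokens_mask, a minor restructuring of the pseudocode above:

import torch
import torch.nn.functional as F

def rss_sampler(scores, valid_mask, rho, t, eps=1e-8):
    """Relaxed subset selection: soft per-token selection probabilities whose
    per-sequence sums approximate round(num_valid_tokens * rho)."""
    u = torch.rand_like(scores).clamp(min=eps, max=1 - eps)
    g = scores - torch.log(-torch.log(u))            # Gumbel-perturbed scores
    y_soft = torch.zeros_like(scores)
    valid = valid_mask.float()
    sizes_left = torch.round(valid.sum(dim=1) * rho) # per-sequence budgets
    while (sizes_left > 0).any():
        active = (sizes_left > 0).float().unsqueeze(-1)  # still-sampling rows
        khot_mask = torch.clamp((1.0 - y_soft) * valid, min=eps)
        g = g + torch.log(khot_mask)                 # damp already-selected tokens
        y_soft = y_soft + F.softmax(g / t, dim=-1) * active
        sizes_left = sizes_left - 1
    return y_soft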
Input: ρ: masking rate; t: mask-option temperature; mask_id: [MASK] token ID in vocabulary.
Procedure run_noiser(θ, x, valid_tokens_mask)
  s ← M^(seq)_noiser(x, valid_tokens_mask; θ)
  p_mask_overall ← rss_sampler(s[:,:,0], valid_tokens_mask, ρ, t)
  g_mask_type ← s[:,:,1:] − log(−log(Uniform(0, 1)))
  p_mask_type ← softmax(g_mask_type / t, dim = −1)
  x̃ ← straight_through(x, p_mask_overall, p_mask_type)
  return x̃

Algorithm 2: The full noising process, taking in the raw, un-masked input vector x and producing a noised version of the input suitable for pre-training an MLM. rss_sampler is shown in Algorithm 1 and straight_through is shown in Algorithm 3.

Input: x: one-hot encoding of the input, which can be multiplied by an embedding layer to produce distributed embeddings of each token; p_mask_overall: probabilities of masking each token in any form; p_mask_type: probabilities of each kind of masking of each token, or None for simple masking; mask_id: [MASK] token ID in vocabulary; v_idx: the start index of the vocabulary, after skipping past all control tokens (e.g., [MASK]).
Result: x̃
Procedure straight_through(x, p_mask_overall, p_mask_type)
  mask_ANY ← where(p_mask_overall > 0.5, 1, 0)
  x_masked ← 0
  if p_mask_type = None then
    x_masked += one_hot(mask_id) + p_mask_overall − detach(p_mask_overall)
  else
    M ← argmax_mask(p_mask_type) + p_mask_type − detach(p_mask_type)
    m_[MASK] ← M[:,:,0]
    m_keep ← M[:,:,1]
    m_replace ← M[:,:,2:]
    x_masked += one_hot(mask_id) * m_[MASK]
    x_masked += x * m_keep
    x_masked[:,:,v_idx:] += m_replace
  end
  x̃ ← (1 − mask_ANY) * x + mask_ANY * x_masked
  return x̃

Algorithm 3: Straight-through sampler, accounting for all three modes of masking.
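A PyTorch rendering of Algorithm 3's three-mode branch could look like the following; the tensor shapes noted in the docstring are our assumptions:

import torch
import torch.nn.functional as F

def straight_through(x, p_mask_overall, p_mask_type, mask_id, v_idx):
    """x: (B, L, V) one-hot inputs; p_mask_overall: (B, L) any-mask
    probabilities; p_mask_type: (B, L, 2 + V - v_idx) mask-option probs."""
    V = x.shape[-1]
    mask_any = (p_mask_overall > 0.5).float().unsqueeze(-1)   # hard any-mask decision
    # straight-through one-hot over mask options: hard value, soft gradient
    hard = F.one_hot(p_mask_type.argmax(-1), p_mask_type.shape[-1]).to(x)
    M = hard + p_mask_type - p_mask_type.detach()
    x_masked = torch.zeros_like(x)
    x_masked += F.one_hot(torch.tensor(mask_id), V).to(x) * M[:, :, 0:1]  # [MASK]
    x_masked += x * M[:, :, 1:2]                                          # keep-original
    x_masked[:, :, v_idx:] += M[:, :, 2:]                                 # replace
    return (1 - mask_any) * x + mask_any * x_masked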
Procedure run_encoder(θ, φ, x, valid_tokens_mask)
  x̃ ← run_noiser(θ, x, valid_tokens_mask)
  p_reconst. ← M_PT(x̃, valid_tokens_mask; φ_PT)
  return L(p_reconst., x, valid_tokens_mask)

Procedure update_noiser(θ, φ, x, valid_tokens_mask)
  return SGD(−run_encoder(θ, φ, x, valid_tokens_mask))

Procedure update_encoder(θ, φ, x, valid_tokens_mask)
  return SGD(run_encoder(θ, φ, x, valid_tokens_mask))

Algorithm 4: Noiser & Encoder Update Steps. In practice we use the more advanced AdamW SGD variant.

Result: θ*, φ*_PT
Input: n_noiser: number of consecutive noiser update iterations; n_encoder: number of consecutive encoder update iterations.
  Initialize θ^(0), φ^(0)_PT randomly, i ← 0, mode ← PRE_TRAINING_NOISING
  while Not Converged do
    x, valid_tokens_mask ← get_batch(i)
    if mode = PRE_TRAINING_NOISING then
      θ^(i+1) ← update_noiser(θ^(i), φ^(i)_PT, x, valid_tokens_mask)
      if i mod n_noiser = 0 then
        mode ← PRE_TRAINING_ENCODING
      end
    else
      φ^(i+1)_PT ← update_encoder(θ^(i+1), φ^(i)_PT, x, valid_tokens_mask)
      if i mod n_encoder = 0 then
        mode ← PRE_TRAINING_NOISING
      end
    end
    i ← i + 1
  end
  return θ^(i), φ^(i)_PT

Algorithm 5: The overall adversarial pre-training loop, alternating between noiser and encoder update phases.