Learning Dynamic BERT via Trainable Gate Variables and a Bi-modal Regularizer
Seohyeong Jeong
Seoul National University [email protected]
Nojun Kwak
Seoul National University [email protected]
Abstract
The BERT model has shown significant success on various natural language processing tasks. However, due to its heavy model size and high computational cost, the model suffers from high latency, which is fatal to its deployment on resource-limited devices. To tackle this problem, we propose a dynamic inference method for BERT via trainable gate variables applied to input tokens and a regularizer that has a bi-modal property. Our method reduces computational cost on the GLUE dataset with a minimal performance drop. Moreover, the model can trade off performance against computational cost through a user-specified hyperparameter.
1 Introduction

BERT (Devlin et al., 2018), the large-scale pre-trained language model, has shown significant improvements on natural language processing tasks (Dai and Le, 2015; Rajpurkar et al., 2016; McCann et al., 2017; Peters et al., 2018; Howard and Ruder, 2018). However, the model suffers from a heavy model size and high computational cost, which hinders it from being applicable in real-time scenarios on resource-limited devices.

Khetan and Karnin (2020) have shown that a "tall and narrow" architecture provides better performance than a "wide and shallow" architecture when obtaining a computationally lighter model. Inspired by this finding, we propose a task-agnostic method that dynamically learns masks on the input word vectors for each BERT layer during fine-tuning. To do this, we propose a gating module, which we call a mask matching module, that sorts and matches each of the input tokens to corresponding learned masks. Note that we use the words "gate" and "mask" interchangeably. Inspired by Srinivas et al. (2017), we train our model with the original task loss plus an additional regularizer that has a bi-modal property on top of the $\ell_1$-variant regularizer that we suggest in this work. Using a bi-modal regularizer allows the model to learn a downstream task and search the model architecture simultaneously, without requiring any further fine-tuning stage.

Figure 1: The overview of the two-stage method (a) and the one-stage method (b) for dynamic inference models.

In this paper, we conduct experiments with BERT-base on the GLUE (Wang et al., 2018) dataset and show that the mask matching module and the bi-modal regularizer enable the model to search the architecture and fine-tune on a downstream dataset simultaneously. Compared to previous works on compressing BERT or accelerating its inference (Sanh et al., 2019; Jiao et al., 2019; Sun et al., 2019; Liu et al., 2020), our method has three main differences. First, our method allows task-agnostic dynamic inference rather than producing a single reduced-size model. Second, our method does not require any additional stage of fine-tuning or knowledge distillation (KD) (Hinton et al., 2015). Lastly, our method provides a hyperparameter that can be specified by a user to control the trade-off between computational complexity and performance.

Figure 2: Comparison of an original encoder block of BERT and our model with the mask matching module.

2 Related Work

There have been numerous works on compressing BERT and accelerating its inference. Adopting KD, Sanh et al. (2019) distill the heavy teacher model into a lighter student model. Pruning the pre-trained model is another way to handle the issue of heavy model sizes and high computational cost (Michel et al., 2019; Gordon et al., 2020). Sajjad et al. (2020) prune BERT by dropping unnecessary blocks, and Goyal et al. (2020) do so by dropping semantically redundant word vectors. Other works have introduced dynamic inference to accelerate BERT: Xin et al. (2020) allow early exiting, and Liu et al. (2020) adjust the number of executed blocks dynamically. Our work is mainly inspired by Goyal et al. (2020), and we integrate dynamic inference into sequence pruning.
The main difference of our method compared to theirs and to other pruning methods (Michel et al., 2019; Gordon et al., 2020; Sajjad et al., 2020) is that our model allows task-agnostic dynamic inference without requiring additional fine-tuning after the model architecture search, as illustrated in Figure 1 (b). There exist other works (Hou et al., 2020; Goyal et al., 2020; Liu et al., 2020; Fan et al., 2019; Elbayad et al., 2019) that dynamically adjust the size and latency of language models. However, these approaches either work in a two-stage setting where further fine-tuning or knowledge distillation is required, as shown in Figure 1 (a), or consider depth-wise rather than width-wise compression. We experimentally show that the computational cost can be reduced with a minimal performance drop on GLUE (Wang et al., 2018).
3 Proposed Method

In this section, we introduce our proposed method, which mainly consists of a mask matching module and an additional regularizer to induce polarization on the mask variables.
3.1 Mask Matching Module

As presented in Figure 2 (a), the original encoder block of BERT consists of multi-head attention and feed-forward networks. The intuition behind the mask matching module is to filter out input tokens that do not contribute much to solving a given task, so that the model can benefit from a reduced computational burden during multi-head attention. Since multi-head attention on a sequence of length $L$ has $O(L^2)$ computational complexity, we expect to reduce this cost by masking out unnecessary tokens for each encoder block. In order to learn the important tokens during training, we introduce the mask matching module, which is placed before the original encoder block of BERT, as shown in Figure 2 (b). Figure 2 (c) shows the detailed process of the mask matching module. The superscript $l$ denotes the $l$-th block; we omit it in the remainder of this section. The module consists of sorting input tokens according to importance scores, matching each input token to a mask, and thresholding the computed tokens with a certain value.

We first compute the importance score $s \in \mathbb{R}^I$ of each token in the input sequence as $s_i = \sum_{j=1}^{J} |X_{ij}|$, where $X \in \mathbb{R}^{I \times J}$ is the matrix representation of the input, with $I$ being the length of the input sequence and $J$ being the size of the hidden dimension. As each token has a corresponding $s_i$, we sort the input matrix $X$ sequence-wise according to the importance score of each token. Then, the sorted input matrix and the sorted masks are multiplied element-wise to perform mask matching and obtain a mask-matched input matrix $S \in \mathbb{R}^{I \times J}$:

$$S = \mathrm{sort}(X) \odot \mathrm{expand}(\mathrm{sort}(\sigma(m))) \qquad (1)$$

where $\sigma(\cdot)$ is the sigmoid function and $m \in \mathbb{R}^I$ is a learnable parameter. Note that since $\sigma(m) \in \mathbb{R}^I$ and $X \in \mathbb{R}^{I \times J}$, we expand $\sigma(m)$ to match the shape of $X$ by replicating it $J$ times and stacking the copies.

Then, we introduce a thresholding scheme on the masked tokens as follows:

$$\mathrm{th}(S_{i,:}) = \begin{cases} S_{i,:} & \sigma(m_i) \geq \alpha \\ 0 & \sigma(m_i) < \alpha \end{cases} \qquad (2)$$

where $\alpha$ is a hyperparameter and $m_i$ is the learned mask value for the $i$-th token in the input. The thresholded output is unsorted into the original order of the input tokens and passed to the following encoder block as its input. The final output of the mask matching module is written as follows:

$$X_m = \mathrm{unsort}(\mathrm{th}(S)) \qquad (3)$$
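For concreteness, the mask matching step of Eqs. (1)–(3) can be sketched in PyTorch-style code as follows. This is a minimal illustration rather than the exact implementation: the module name, the per-sequence [I, J] shapes, the default threshold value, and the omission of the batch dimension and of the special handling of [CLS] are simplifying assumptions.

```python
import torch
import torch.nn as nn

class MaskMatchingModule(nn.Module):
    """Sketch of Eqs. (1)-(3): sort tokens by importance, match them to
    learned (sorted) gate values, threshold, and restore the original order."""

    def __init__(self, seq_len: int, alpha: float = 0.5):
        super().__init__()
        # One learnable mask logit per token position (m in R^I),
        # initialized uniformly on [0, 1); alpha here is a placeholder value.
        self.m = nn.Parameter(torch.rand(seq_len))
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [I, J] token representations of a single sequence.
        I, J = x.shape
        s = x.abs().sum(dim=-1)                       # importance score s_i = sum_j |X_ij|
        order = torch.argsort(s, descending=True)     # sort tokens by importance
        gates = torch.sigmoid(self.m)                 # sigma(m) in (0, 1)
        sorted_gates, _ = torch.sort(gates, descending=True)
        # Eq. (1): element-wise product of the sorted input and the expanded sorted gates.
        S = x[order] * sorted_gates.unsqueeze(-1).expand(I, J)
        # Eq. (2): zero out token rows whose gate value falls below the threshold alpha.
        S = S * (sorted_gates >= self.alpha).to(x.dtype).unsqueeze(-1)
        # Eq. (3): unsort back to the original token order.
        return S[torch.argsort(order)]
```

The gate values themselves stay continuous during training; the regularizer of Sec. 3.2 is what pushes them toward 0 or 1.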
3.2 Polarization Regularizer

Traditional $\ell_1$ and $\ell_2$ regularizers do not guarantee well-polarized values for a gate (mask) variable. In order to induce polarization on our masks, we utilize a bi-modal regularizer, proposed by Murray and Ng (2010) and Srinivas et al. (2017), to learn binary values for parameters. Srinivas et al. (2017) used an overall regularizer which is a combination of the bi-modal regularizer and a traditional $\ell_1$ or $\ell_2$ regularizer. In this work, we use a customized regularizer, a variant of $\ell_1$ denoted as $\ell_{\text{filter}}$, to dynamically adjust the level of sparsity according to a user-specified hyperparameter:

$$\ell_{\text{filter}} = \frac{1}{L}\sum_{l=1}^{L} \left| v_{\text{filter},l} \right|, \qquad v_{\text{filter}} = w \odot (v_{\text{masks}} - v_{\text{user}}). \qquad (4)$$

Here $v_{\text{masks}}, v_{\text{user}}, w \in \mathbb{R}^L$ are the mass of the masks, the user-specified mass of the masks, and the filtering weights, respectively, with $L$ being the number of blocks in the model:

$$v_{\text{masks},l} = \sum_{i=1}^{I} \sigma(m_i^l), \qquad v_{\text{user},l} = I \times L \times \gamma, \qquad w_l = 1 - \frac{\sum_{i=1}^{I} \sigma(m_i^l)}{I}, \qquad (5)$$

where $I$ is the length of the input token sequence and $0 \leq \gamma \leq 1$ is a hyperparameter that enforces the user-specified level of token filtering in the model. The polarization regularizer is then written as a linear combination of $\ell_{\text{filter}}$ and $\ell_{\text{bi-modal}}$, which has the form $w \times (1 - w)$:

$$\mathcal{L}_{\text{polar}} = \lambda_{\text{filter}} \cdot \ell_{\text{filter}} + \lambda_{\text{bi}} \cdot \sum_{l=1}^{L}\sum_{i=1}^{I} \sigma(m_i^l)\left(1 - \sigma(m_i^l)\right). \qquad (6)$$

Our total objective function is stated as follows:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{polar}} \qquad (7)$$

where $\mathcal{L}_{\text{task}}$ is the loss for the downstream task. We show the effect of the bi-modal regularizer in Sec. 4.3.
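A corresponding sketch of the polarization regularizer of Eqs. (4)–(7) is given below, again as an illustration only; the function name, the [L, I] layout of the mask logits, and the way the result is combined with the task loss are assumptions rather than the exact implementation.

```python
import torch

def polarization_loss(mask_logits: torch.Tensor, gamma: float,
                      lambda_filter: float, lambda_bi: float) -> torch.Tensor:
    """Sketch of Eqs. (4)-(6): l_filter plus the bi-modal term.

    mask_logits: [L, I] mask parameters m_i^l for L gated blocks and I tokens.
    gamma:       user-specified filtering level, 0 <= gamma <= 1.
    """
    L, I = mask_logits.shape
    gates = torch.sigmoid(mask_logits)                  # sigma(m), in (0, 1)

    # Eq. (5): per-block mask mass, user-specified mass, and filtering weights.
    v_masks = gates.sum(dim=1)                          # [L]
    v_user = torch.full_like(v_masks, I * L * gamma)    # constant target, as written in Eq. (5)
    w = 1.0 - v_masks / I                               # [L]

    # Eq. (4): l_filter, an l1-style penalty weighted per block.
    l_filter = (w * (v_masks - v_user)).abs().mean()

    # Bi-modal term sigma(m)(1 - sigma(m)), minimized when gates are near 0 or 1.
    l_bimodal = (gates * (1.0 - gates)).sum()

    return lambda_filter * l_filter + lambda_bi * l_bimodal

# Eq. (7): the regularizer is simply added to the downstream task loss, e.g.
# total_loss = task_loss + polarization_loss(mask_logits, gamma, lambda_filter, lambda_bi)
```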
4 Experiments

We evaluate the proposed method on eight datasets of the GLUE (Wang et al., 2018) benchmark.

4.1 Experimental Settings

We fine-tune the pre-trained BERT-base model on the 8 GLUE datasets for 3 epochs with a batch size of 128. The hidden dimension is set to $J = 768$ and the length of the input token sequence is set to $I = 128$. For the remaining details, we follow the original settings of BERT. We set $\alpha = 0.$, $\lambda_{\text{filter}} = 0.$, and $\lambda_{\text{bi}} = 2.$ We use a separate Adam (Kingma and Ba, 2014) optimizer for training the mask variables; it is set with an initial learning rate of $.$, two momentum parameters $\beta_1 = 0.$ and $\beta_2 = 0.$, and $\epsilon = 1 \times 10^{-}$. Mask variables are initialized with random values drawn from a uniform distribution on the interval [0, 1). We do not introduce mask variables for the very first block in the model. Additionally, we never filter (never mask) the first token of each input, the special token [CLS]. Introducing the mask variables adds "the length of tokens × the number of blocks" additional parameters.
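The separate optimizer for the mask variables can be set up as in the following toy sketch; the model here is a stand-in, and the learning rate, betas, and epsilon are placeholder values rather than the settings used in the experiments.

```python
import torch
import torch.nn as nn

# Toy stand-in: "BERT" weights (a single linear layer here) plus per-block mask logits.
bert_like = nn.Linear(768, 768)
mask_logits = nn.Parameter(torch.rand(12, 128))   # [blocks, tokens], initialized on [0, 1)

# Two optimizers: the usual one for the model weights and a separate Adam
# instance dedicated to the mask variables (all hyperparameter values are placeholders).
weight_opt = torch.optim.AdamW(bert_like.parameters(), lr=2e-5)
mask_opt = torch.optim.Adam([mask_logits], lr=1e-2, betas=(0.9, 0.999), eps=1e-8)

# One illustrative update step with a dummy loss in place of L_total.
x = torch.randn(4, 768)
gates = torch.sigmoid(mask_logits)
loss = bert_like(x).pow(2).mean() + (gates * (1 - gates)).sum()
weight_opt.zero_grad(); mask_opt.zero_grad()
loss.backward()
weight_opt.step(); mask_opt.step()
```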
| Models | MNLI-(m/mm) | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B |
|---|---|---|---|---|---|---|---|---|
| BERT-base (Devlin et al., 2018) | 84.6 / 83.4 | 90.5 | 71.2 | 66.4 | 93.5 | 88.9 | 52.1 | 85.8 |
| BERT-base-ours | 84.5 / 83.7 | 90.7 | 71.8 | 62.4 | 93.9 | 83.7 | 51.2 | 78.9 |
| (FLOPs) | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M |
| Ours (γ = 0.) (FLOPs) | 3357M | 3915M | 3766M | 4629M | 3887M | 4417M | 2629M | 3371M |
| (speed-up vs. BERT-base) | (3.23×) | (2.77×) | (2.88×) | (2.35×) | (2.80×) | (2.46×) | (4.13×) | (3.23×) |

Table 1: Comparison of GLUE test results, scored by the official evaluation server. BERT-base-ours is our implementation of the baseline BERT model. Performance for Ours is reported with γ = 0.. The last row shows the computational improvement compared to the FLOPs of the original BERT-base.
| Models | MNLI-(m/mm) | QNLI | SST-2 |
|---|---|---|---|
| BERT-base-ours | 84.3 / 84.9 | 91.7 | 92.5 |
| (FLOPs) | 10872M | 10872M | 10872M |

Table 2: Performance and FLOPs on the GLUE evaluation set with different values of γ.

4.2 Results

We compare our model with the BERT-base baseline. Table 1 summarizes the results of these models. The performance numbers in the first row are taken from Devlin et al. (2018), and the second row shows the performance of our implementation. The last row shows the improvement over the FLOPs of the baseline model. Our dynamic inference method with $\gamma = 0.$ shows minimal degradation on the GLUE datasets with an average of 3 times fewer FLOPs. Furthermore, our model works in a task-agnostic manner and outputs an optimal architecture for each given downstream dataset, instead of a single reduced-size model. Table 2 shows that our model is capable of dynamically adjusting the computational cost with a trade-off between FLOPs and performance.
It shows that the hyperparameter $\gamma$ works properly, yielding FLOPs proportional to its given value. The results present a generally consistent trade-off between FLOPs and performance.

4.3 Ablation Study

To analyze the effect of the bi-modal regularizer, we conduct an ablation study by removing it from the training process. Table 3 shows the effect of the bi-modal regularizer, and we claim that employing this regularizer during training plays a major role in learning to perform well on a downstream task as well as in searching the optimal model structure with the help of well-polarized mask variables. Further analysis of the behavior of mask variables with and without the bi-modal regularizer is given in Appendix C.

Table 3: Ablation study of the bi-modal regularizer on the GLUE evaluation set (MNLI-(m/mm), QNLI, and SST-2), comparing training with and without the bi-modal term ($\lambda_{\text{bi}} = 2.$ vs. $\lambda_{\text{bi}} = 0.$).
5 Conclusion

In this work, we explore a task-agnostic dynamic inference method for BERT that works by masking out the input sequence for each block. To do this, we propose a mask matching module and a variant of the $\ell_1$ regularizer, which we call $\ell_{\text{filter}}$. Our method yields models at various levels of performance and computational complexity, depending on the hyperparameter value that the user specifies. Through experiments on the GLUE dataset, we show that BERT, used with our method, enjoys lighter computation with minimal performance degradation.

References

Andrew M Dai and Quoc V Le. 2015. Semi-supervised sequence learning. In Advances in Neural Information Processing Systems, pages 3079–3087.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.

Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. 2019. Depth-adaptive transformer. arXiv preprint arXiv:1910.10073.

Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing transformer depth on demand with structured dropout. arXiv preprint arXiv:1909.11556.

Mitchell A Gordon, Kevin Duh, and Nicholas Andrews. 2020. Compressing BERT: Studying the effects of weight pruning on transfer learning. arXiv preprint arXiv:2002.08307.

Saurabh Goyal, Anamitra Roy Choudhury, Saurabh Raje, Venkatesan Chakaravarthy, Yogish Sabharwal, and Ashish Verma. 2020. PoWER-BERT: Accelerating BERT inference via progressive word-vector elimination. In International Conference on Machine Learning, pages 3690–3699. PMLR.

Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.

Lu Hou, Lifeng Shang, Xin Jiang, and Qun Liu. 2020. DynaBERT: Dynamic BERT with adaptive width and depth. arXiv preprint arXiv:2004.04037.

Jeremy Howard and Sebastian Ruder. 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.

Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2019. TinyBERT: Distilling BERT for natural language understanding. arXiv preprint arXiv:1909.10351.

Ashish Khetan and Zohar Karnin. 2020. schuBERT: Optimizing elements of BERT. arXiv preprint arXiv:2005.06628.

Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.

Weijie Liu, Peng Zhou, Zhe Zhao, Zhiruo Wang, Haotang Deng, and Qi Ju. 2020. FastBERT: A self-distilling BERT with adaptive inference time. arXiv preprint arXiv:2004.02178.

Bryan McCann, James Bradbury, Caiming Xiong, and Richard Socher. 2017. Learned in translation: Contextualized word vectors. In Advances in Neural Information Processing Systems, pages 6294–6305.

Paul Michel, Omer Levy, and Graham Neubig. 2019. Are sixteen heads really better than one? arXiv preprint arXiv:1905.10650.

Walter Murray and Kien-Ming Ng. 2010. An algorithm for nonlinear optimization problems with binary variables. Computational Optimization and Applications, 47(2):257–288.

Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. arXiv preprint arXiv:1802.05365.

Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250.

Hassan Sajjad, Fahim Dalvi, Nadir Durrani, and Preslav Nakov. 2020. Poor man's BERT: Smaller and faster transformer models. arXiv preprint arXiv:2004.03844.

Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. DistilBERT, a distilled version of BERT: Smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108.

Suraj Srinivas, Akshayvarun Subramanya, and R Venkatesh Babu. 2017. Training sparse neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pages 138–145.

Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for BERT model compression. arXiv preprint arXiv:1908.09355.

Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. 2018. GLUE: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.

Ji Xin, Raphael Tang, Jaejun Lee, Yaoliang Yu, and Jimmy Lin. 2020. DeeBERT: Dynamic early exiting for accelerating BERT inference. arXiv preprint arXiv:2004.12993.

A Additional Details
A.1 Experimental Details
The outputs of the multi-head attention, the feed-forward network, and the layer normalization in Figure 2 further need to be masked, since these computations contain bias terms. Our goal is to mask out the input matrix of each encoder block token-wise. Therefore, after the computations mentioned above, we apply hard masking on the input-matrix dimensions (token rows) that were masked out by the mask matching module.
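A minimal sketch of this re-masking step, assuming a boolean keep-vector produced by the mask matching module (the function and variable names below are illustrative, not from the implementation):

```python
import torch

def reapply_hard_mask(block_output: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Zero out token rows again after attention / feed-forward / layer norm,
    since their bias terms re-introduce non-zero values at filtered positions.

    block_output: [I, J] output of the sub-layer.
    keep:         [I] boolean vector, True for tokens that survived thresholding.
    """
    return block_output * keep.unsqueeze(-1).to(block_output.dtype)

# Example: a 4-token output where tokens 1 and 3 were filtered out earlier.
out = torch.randn(4, 8) + 0.1                      # bias makes every row non-zero again
keep = torch.tensor([True, False, True, False])
out = reapply_hard_mask(out, keep)                 # rows 1 and 3 are zeroed once more
```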
A.2 Reported Measures for GLUE

QQP and MRPC are reported with F1 scores, STS-B is reported with Spearman correlations, and the other tasks are reported with accuracy.
B Interpretation of the $\ell_{\text{filter}}$ Regularizer
We propose a variant of the $\ell_1$ regularizer, called $\ell_{\text{filter}}$, as shown in Eqs. 4 and 5. As $\ell_{\text{filter}}$ may come across as somewhat heuristic, we explain the intuition and interpretation behind the regularizer. Consider an extreme case with $\sum_{i=1}^{I} \sigma(m_i^a) = I$ and $\sum_{i=1}^{I} \sigma(m_i^b) = 0$. Then, from the last part of Eq. 5, $w_a = 0.0$ and $w_b = 1.0$. This means that input tokens are needed more in the $a$-th block than in the $b$-th block of the model, since the former uses more tokens (masks out fewer tokens). As shown in the second part of Eq. 4, $w$ acts as a weight on $v_{\text{masks}} - v_{\text{user}}$. Instead of applying the same weight to each block, we apply weights according to the number of masks used in each block. In other words, we wish to impose a heavier loss on the $b$-th block than on the $a$-th block of the model.
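The extreme case above can be made concrete with a small numerical illustration (the toy sizes here are chosen for exposition, not taken from the experiments):

```python
import torch

I, L, gamma = 4, 2, 0.5                      # toy sequence length, block count, filtering level

# Extreme case from the text: block a keeps all tokens, block b keeps none.
v_masks = torch.tensor([float(I), 0.0])      # [sum_i sigma(m_i^a), sum_i sigma(m_i^b)]
w = 1.0 - v_masks / I                        # tensor([0., 1.]): w_a = 0.0, w_b = 1.0
v_user = I * L * gamma                       # user-specified mass, as written in Eq. (5)

v_filter = w * (v_masks - v_user)            # block a contributes nothing to l_filter,
print(w, v_filter.abs())                     # block b carries the full (heavier) penalty
```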
C Analysis on Mask Variables

Figure 3 shows the learned values of the mask variables after the sigmoid function. Each histogram has the post-sigmoid mask values on the x-axis. Since we conduct experiments on BERT-base, we show results for every gated block in the model, from the 2nd block to the 12th block, from bottom to top in the figure. It shows that the $\ell_{\text{bi-modal}}$ regularizer drives the mask values toward well-polarized values close to 0 or 1.

Figure 3: Histograms of $\sigma(m)$ for each encoder block without and with the $\ell_{\text{bi-modal}}$ regularizer. For the MNLI-(m/mm), QNLI, and SST-2 datasets, we show the histogram of $\sigma(m)$ values without the $\ell_{\text{bi-modal}}$ regularizer on the left and with the regularizer on the right. The performance in each scenario is written in the top left corner of each plot. It shows that the $\ell_{\text{bi-modal}}$ regularizer not only trains the mask variables in a well-polarized manner but also plays an important role in learning to perform well on a given task.