CoT: Cooperative Training for Generative Modeling of Discrete Data
Sidi Lu, Lantao Yu, Siyuan Feng, Yaoming Zhu, Weinan Zhang, Yong Yu
Abstract
In this paper, we study generative models of sequential discrete data. To tackle the exposure bias problem inherent in maximum likelihood estimation (MLE), generative adversarial networks (GANs) have been introduced to penalize unrealistic generated samples. To exploit the supervision signal from the discriminator, most previous models leverage REINFORCE to address the non-differentiability of sequential discrete data. However, because the training signal is unstable during the dynamic process of adversarial training, the effectiveness of REINFORCE in this case is hardly guaranteed. To deal with this problem, we propose a novel approach called Cooperative Training (CoT) to improve the training of sequence generative models. CoT transforms the min-max game of GANs into a joint maximization framework and manages to explicitly estimate and optimize the Jensen-Shannon divergence. Moreover, CoT works without the need for pre-training via MLE, which is crucial to the success of previous methods. In the experiments, compared to existing state-of-the-art methods, CoT shows superior or at least competitive performance on sample quality, diversity, as well as training stability.
1. Introduction
Generative modeling is essential in many scenarios, including continuous data modeling (e.g. image generation (Goodfellow et al., 2014; Arjovsky et al., 2017), stylization (Ulyanov et al., 2016), semi-supervised classification (Radford et al., 2015)) and sequential discrete data modeling, typically neural text generation (Bahdanau et al., 2014; Yu et al., 2017; Lu et al., 2018).

APEX Lab, Shanghai Jiao Tong University, Shanghai, China; Stanford University, California, USA. Correspondence to: Sidi Lu, Weinan Zhang. Proceedings of the International Conference on Machine Learning, Long Beach, California, PMLR 97, 2019. Copyright 2019 by the author(s).
For sequential discrete data with tractable density, like natural language, generative models are predominantly optimized through Maximum Likelihood Estimation (MLE), which inevitably introduces exposure bias (Ranzato et al., 2015): given a finite set of observations, the optimal parameters of a model trained via MLE do not correspond to the ones yielding the optimal generative quality. Specifically, the model is trained on the data distribution of inputs and tested on a different distribution of inputs, namely, the learned distribution. This discrepancy implies that in the training stage the model is never exposed to its own errors, and thus in the test stage the errors made along the way quickly accumulate.

On the other hand, for general generative modeling tasks, an effective framework named Generative Adversarial Network (GAN) (Goodfellow et al., 2014) was proposed to train an implicit density model for continuous data. GAN introduces a discriminator $D_\phi$ parametrized by $\phi$ to distinguish the generated samples from the real ones. As proved by Goodfellow et al. (2014), GAN essentially optimizes an approximately estimated Jensen-Shannon divergence (JSD) between the currently learned distribution and the target distribution. GAN shows promising results in many unsupervised and semi-supervised learning tasks, and its success gave birth to a new paradigm of deep generative models, i.e. adversarial networks.

However, since the gradient computation requires back-propagation through the generator's output, i.e. the data, GAN can only model the distribution of continuous variables, making it non-applicable for generating discrete sequences like natural language. Researchers then proposed the Sequence Generative Adversarial Network (SeqGAN) (Yu et al., 2017), which uses a model-free policy gradient algorithm to optimize the original GAN objective.
With SeqGAN, the expected JSD between the current and target discrete data distributions is minimized if the training is perfect. SeqGAN shows observable improvements in many tasks. Since then, many variants of SeqGAN have been proposed to improve its performance. Nonetheless, SeqGAN is not an ideal algorithm for this problem, and current algorithms based on it cannot show stable, reliable and observable improvements that cover all scenarios, according to a previous survey (Lu et al., 2018). The reasons will be discussed in detail in Section 2.

In this paper, we propose Cooperative Training (CoT), a novel algorithm for training likelihood-based generative models on discrete data by directly optimizing a well-estimated Jensen-Shannon divergence. CoT coordinately trains a generative module $G$ and an auxiliary predictive module $M$, called the mediator, which guides $G$ in a cooperative fashion. For theoretical soundness, we derive the proposed algorithm directly from the definition of JSD. We further demonstrate, empirically and theoretically, the superiority of our algorithm over many strong baselines in terms of generative performance, generalization ability and computational performance in both synthetic and real-world scenarios.
2. Background
Notations. $P$ denotes the target data distribution. $\theta$ denotes the parameters of the generative module $G$, and $\phi$ denotes the parameters of the auxiliary predictive mediator module $M$. Any symbol with subscript $g$ or $m$ refers to the generator or the mediator, respectively. $s$ stands for a complete sample from the training dataset or a generated complete sequence, depending on the specific context. $s_t$ means the $t$-length prefix of the original sequence, i.e. an incomplete sequence of length $t$. $x$ denotes a token, and $x_t$ stands for the token that appears in the $t$-th place of a sequence. Thus $s_t = [x_0, x_1, x_2, \ldots, x_{t-1}]$, while the initial case $s_0$ is $\emptyset$.

Maximum likelihood estimation is equivalent to minimizing the KL divergence using the samples from the real distribution:

$$\min_\theta \; \mathbb{E}_{s \sim p_{\text{data}}}\left[-\log G_\theta(s)\right], \tag{1}$$

where $G_\theta(s)$ is the estimated probability of $s$ by $G_\theta$ and $p_{\text{data}}$ is the underlying real distribution.

Limitations of MLE.
MLE is essentially equivalent to optimizing a directed Kullback-Leibler (KL) divergence between the target distribution $p_{\text{data}}$ and the currently learned distribution $G$, denoted as $\mathrm{KL}(p_{\text{data}} \| G)$. However, since KL divergence is asymmetric, this target is actually not ideal given finite observations. As stated in Arjovsky & Bottou (2017), MLE tries to minimize

$$\mathrm{KL}(p_{\text{data}} \| G) = \sum_s p_{\text{data}}(s) \log \frac{p_{\text{data}}(s)}{G(s)}. \tag{2}$$

• When $p_{\text{data}}(s) > 0$ and $G(s) \to 0$, the KL divergence grows to infinity, which means MLE assigns an extremely high cost to "mode dropping" scenarios, where the generator fails to cover some parts of the data.
• When $G(s) > 0$ and $p_{\text{data}}(s) \to 0$, the KL divergence shrinks to 0, which means MLE assigns an extremely low cost to scenarios where the model generates samples that do not lie on the data distribution.

Likewise, optimizing $\mathrm{KL}(G \| p_{\text{data}})$ leads to exactly the reversed problems in these two situations. An ideal solution is to optimize a symmetrized and smoothed version of KL divergence, i.e. the Jensen-Shannon divergence (JSD), which is defined as

$$\mathrm{JSD}(p_{\text{data}} \| G) = \frac{1}{2}\Big(\mathrm{KL}(p_{\text{data}} \| M) + \mathrm{KL}(G \| M)\Big), \tag{3}$$

where $M = \frac{1}{2}(p_{\text{data}} + G)$. However, directly optimizing JSD is conventionally considered an intractable problem: JSD cannot be directly evaluated and optimized, since the equally interpolated distribution $M$ is usually considered unconstructible, as we only have access to the learned model $G$ instead of $P$.

SeqGAN incorporates two modules, i.e. the generator and the discriminator, parametrized by $\theta$ and $\phi$ respectively, as in the settings of GAN. By alternately training these two modules, SeqGAN optimizes the following adversarial target:

$$\min_\theta \max_\phi \; \mathbb{E}_{s \sim p_{\text{data}}}\left[\log D_\phi(s)\right] + \mathbb{E}_{s \sim G_\theta}\left[\log\left(1 - D_\phi(s)\right)\right]. \tag{4}$$

The objectives of generator $G_\theta$ and discriminator $D_\phi$ in SeqGAN can be formulated as:

$$\text{Generator:} \quad \min_\theta \; -\mathbb{E}_{s \sim G_\theta}\left[\sum_{t=1}^{n} Q_t(s_t, x_t) \cdot \log G_\theta(x_t \mid s_t)\right], \tag{5}$$

$$\text{Discriminator:} \quad \max_\phi \; \mathbb{E}_{s \sim p_{\text{data}}}\left[\log D_\phi(s)\right] + \mathbb{E}_{s \sim G_\theta}\left[\log\left(1 - D_\phi(s)\right)\right], \tag{6}$$

where $s \sim G_\theta = [x_0, \ldots, x_n]$ denotes a complete sequence sampled from the generator, and the actually implemented action value $Q_t(s_t, x_t) = \mathbb{E}_{s \sim G_\theta(\cdot \mid s_{t+1})}\left[D_\phi(s)\right]$ is the expectation of the discriminator's evaluation on the completed sequences sampled from the prefix $s_{t+1} = [s_t, x_t]$, which can be approximated via Monte Carlo search.

Limitations of SeqGAN & its Variants.
SeqGAN is an algorithm of high variance, which relies on pre-training via Maximum Likelihood Estimation as a variance reduction procedure. During the adversarial epochs, even with variance reduction techniques such as Actor-Critic methods (Sutton, 1984), the fact that SeqGAN is essentially based on model-free reinforcement learning makes it non-trivial for SeqGAN to converge well. One consequence is the "mode collapse" problem, which is similar to that of the original GAN but more severe here. In this case, the learned distribution "collapses" towards the minimization of the reverse KL divergence, i.e. $\mathrm{KL}(G \| p_{\text{data}})$, which leads to the loss of diversity of generated samples. In other words, SeqGAN trains the model for better generative quality at the cost of diversity.

Algorithm 1 Cooperative Training
Require: Generator $G_\theta$; mediator $M_\phi$; samples from real data distribution $p_{\text{data}}$; hyper-parameter $N_m$.
1: Initialize $G_\theta$, $M_\phi$ with random weights $\theta$, $\phi$.
2: repeat
3:   for $N_m$ steps do
4:     Collect two equal-sized mini-batches of samples $\{s_g\}$ and $\{s_p\}$ from $G_\theta$ and $p_{\text{data}}$, respectively
5:     Mix $\{s_g\}$ and $\{s_p\}$ as $\{s\}$
6:     Update mediator $M_\phi$ with $\{s\}$ via Eq. (9)
7:   end for
8:   Generate a mini-batch of sequences $\{s\} \sim G_\theta$
9:   Update generator $G_\theta$ with $\{s\}$ by applying Eq. (14)
10: until CoT converges
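Algorithm 1 can be exercised end to end on a deliberately tiny problem. The sketch below is a toy under stated assumptions, not the authors' implementation: every sequence has a single token (so the only prefix is $s_0 = \emptyset$ and the expectation over prefixes in Eq. (14) is exact), both $G_\theta$ and $M_\phi$ are free softmax logit vectors, the mediator takes plain SGD steps on the MLE loss of Eq. (9) over the mixed mini-batch, and the generator takes gradient-ascent steps along the $\gamma = 0$ direction of Eq. (14), both derived in Section 3. All function names, learning rates and the target distribution are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mediator_step(u, samples, lr):
    """One SGD step on the mediator's MLE loss (Eq. 9) over the mixed samples."""
    freq = [0.0] * len(u)
    for token in samples:
        freq[token] += 1.0 / len(samples)
    pi_m = softmax(u)
    # Gradient of the mean NLL of a softmax w.r.t. its logits: pi_m - freq.
    return [ui - lr * (pm - f) for ui, pm, f in zip(u, pi_m, freq)]


def generator_step(z, u, lr):
    """Ascend pi_g^T (log pi_m - log pi_g) along the gamma = 0 direction of Eq. (14)."""
    pi_g, pi_m = softmax(z), softmax(u)
    a = [math.log(pm / pg) for pm, pg in zip(pi_m, pi_g)]
    a_bar = sum(pg * ai for pg, ai in zip(pi_g, a))
    # Chain rule through the softmax: d pi_i / d z_j = pi_i * (delta_ij - pi_j).
    grad = [pg * (ai - a_bar) for pg, ai in zip(pi_g, a)]
    return [zj + lr * gj for zj, gj in zip(z, grad)]


def train_cot(p_data, iterations=2000, batch=256, n_m=1, lr_m=0.5, lr_g=0.5, seed=0):
    rng = random.Random(seed)
    tokens = list(range(len(p_data)))
    z = [0.0] * len(p_data)  # generator logits (uniform start)
    u = [0.0] * len(p_data)  # mediator logits
    for _ in range(iterations):
        for _ in range(n_m):  # lines 3-7 of Algorithm 1
            s_g = rng.choices(tokens, weights=softmax(z), k=batch)
            s_p = rng.choices(tokens, weights=p_data, k=batch)
            u = mediator_step(u, s_g + s_p, lr_m)
        # Lines 8-9: with one-token sequences the expectation is exact, so no samples are needed.
        z = generator_step(z, u, lr_g)
    return softmax(z), softmax(u)


p_data = [0.1, 0.2, 0.3, 0.4]
g, m = train_cot(p_data)
print("KL(G || P) =", kl(g, p_data))
```

Because the mediator is fit to equal-sized batches from both sources, its fixed point is $\frac{1}{2}(p_{\text{data}} + G_\theta)$, and the generator's only fixed point under this update is $G_\theta = p_{\text{data}}$.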
3. Methodology
To be consistent with the goal that the target distribution should be well estimated in both the quality and the diversity sense, an ideal algorithm for such models should be able to optimize a symmetric divergence or distance.

For sequential discrete data modeling, since the data distribution is decomposed into a sequential product of finite-dimensional multinomial distributions (always based on the softmax form), the failures of effectively optimizing JSD when the generated and real data distributions are distant, as discussed in Arjovsky et al. (2017), will not appear. As such, optimizing JSD is feasible. However, to our knowledge, no previous algorithm provides a direct, low-variance optimization of JSD. In this paper, we propose Cooperative Training (CoT), as shown in Algorithm 1, to directly optimize a well-estimated JSD for training such models. Figure 1 illustrates the whole Cooperative Training process.
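The contrast between the asymmetric KL divergence of Eq. (2) and the symmetric JSD of Eq. (3) is easy to check numerically. The sketch below uses hypothetical toy distributions and natural logarithms. It also checks the relation between the mediator loss and JSD reported in Section 4, under the assumption that the balanced NLL is the sum of the two NLL terms (twice Eq. (9)) evaluated at the ideal mediator $M^* = \frac{1}{2}(P + G)$; under that reading it equals $2\,\mathrm{JSD}(G \| P) + H(G) + H(P)$.

```python
import math


def kl(p, q):
    """KL(p || q); infinite if p puts mass where q has none."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def mixture(p, q):
    return [0.5 * (pi + qi) for pi, qi in zip(p, q)]


def jsd(p, q):
    """Eq. (3): average KL of each distribution to their equal mixture."""
    m = mixture(p, q)
    return 0.5 * (kl(p, m) + kl(q, m))


def balanced_nll(g, p, mediator):
    """Sum of the two NLL terms of Eq. (9) (i.e. twice J_m), evaluated exactly."""
    nll_g = -sum(gi * math.log(mi) for gi, mi in zip(g, mediator) if gi > 0)
    nll_p = -sum(pi * math.log(mi) for pi, mi in zip(p, mediator) if pi > 0)
    return nll_g + nll_p


p_data = [0.5, 0.5, 0.0]   # toy target distribution
g_model = [0.6, 0.3, 0.1]  # toy generator that leaks mass onto an off-data token

print(kl(p_data, g_model), kl(g_model, p_data))    # finite vs. inf: KL is asymmetric
print(jsd(p_data, g_model), jsd(g_model, p_data))  # equal, and at most log 2
m_star = mixture(g_model, p_data)
print(balanced_nll(g_model, p_data, m_star),
      2 * jsd(g_model, p_data) + entropy(g_model) + entropy(p_data))
```

The toy generator is penalized only mildly by $\mathrm{KL}(p_{\text{data}} \| G)$ for its off-data token (the second case in the list above), infinitely by the reverse KL, and by a finite, bounded amount under JSD.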
Figure 1. Process of Cooperative Training. (Diagram labels: Data, Generator, Mediator; Samples; Maximum Likelihood Estimation; Minimize.)

3.1.1. The Objective for Mediator
Each iteration of Cooperative Training mainly consists of two parts. The first part is to train a mediator $M_\phi$, which is a density function that estimates a mixture distribution of the learned generative distribution $G_\theta$ and the target latent distribution $p_{\text{data}}$ as

$$M_\phi \simeq \frac{1}{2}\left(p_{\text{data}} + G_\theta\right). \tag{7}$$

Since the mediator is only used as a density prediction module during training, the directed KL divergence is now greatly relieved from the so-called exposure bias for the optimization of $M_\phi$. Denoting $\frac{1}{2}(p_{\text{data}} + G_\theta)$ as $M^*$, we have:

Lemma 1 (Mixture Density Decomposition)

$$\begin{aligned}
\nabla_\phi J_m(\phi) &= \nabla_\phi \mathrm{KL}(M^* \| M_\phi) = \nabla_\phi\, \mathbb{E}_{s \sim M^*}\left[\log \frac{M^*(s)}{M_\phi(s)}\right] = \nabla_\phi\left(-\mathbb{E}_{s \sim M^*}\left[\log M_\phi(s)\right]\right) \\
&= \nabla_\phi\, \frac{1}{2}\Big(\mathbb{E}_{s \sim G_\theta}\left[-\log M_\phi(s)\right] + \mathbb{E}_{s \sim p_{\text{data}}}\left[-\log M_\phi(s)\right]\Big).
\end{aligned} \tag{8}$$

By Lemma 1, for each step, we can simply mix balanced samples from the training data and the generator, then train the mediator via Maximum Likelihood Estimation with the mixed samples. The objective $J_m(\phi)$ for the mediator $M$ parameterized by $\phi$ therefore becomes

$$J_m(\phi) = \frac{1}{2}\Big(\mathbb{E}_{s \sim G_\theta}\left[-\log M_\phi(s)\right] + \mathbb{E}_{s \sim p_{\text{data}}}\left[-\log M_\phi(s)\right]\Big). \tag{9}$$

The training techniques and details will be discussed in Section 4.

After each iteration, the mediator is exploited to optimize an estimated Jensen-Shannon divergence for $G_\theta$:

$$\begin{aligned}
J_g(\theta) &= -\widehat{\mathrm{JSD}}(G_\theta \| p_{\text{data}}) = -\frac{1}{2}\Big[\mathrm{KL}(G_\theta \| M_\phi) + \mathrm{KL}(p_{\text{data}} \| M_\phi)\Big] \\
&= -\frac{1}{2}\,\mathbb{E}_{s \sim G_\theta}\left[\log \frac{G_\theta(s)}{M_\phi(s)}\right] - \frac{1}{2}\,\mathbb{E}_{s \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(s)}{M_\phi(s)}\right].
\end{aligned} \tag{10}$$

When calculating $\nabla_\theta J_g(\theta)$, the second term has no effect on the final result, since it does not depend on $\theta$. Thus, we can use this objective instead:

$$J_g(\theta) = -\frac{1}{2}\,\mathbb{E}_{s \sim G_\theta}\left[\log \frac{G_\theta(s)}{M_\phi(s)}\right]. \tag{11}$$

3.1.2. Generator Objective and Markov Backward Reduction
For any sequence or prefix of length $t$, we have:

Lemma 2 (Markov Backward Reduction)

$$-\mathbb{E}_{s_t \sim G_\theta}\left[\log \frac{G_\theta(s_t)}{M_\phi(s_t)}\right] = -\mathbb{E}_{s_{t-1} \sim G_\theta}\left[\sum_{s_t} G_\theta(s_t \mid s_{t-1}) \log \frac{G_\theta(s_t \mid s_{t-1})}{M_\phi(s_t \mid s_{t-1})}\right] - \mathbb{E}_{s_{t-1} \sim G_\theta}\left[\log \frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})}\right]. \tag{12}$$

The detailed derivations can be found in the supplementary material. Note that Lemma 2 can be applied recursively. That is to say, given any sequence $s_t$ of arbitrary length $t$, optimizing $s_t$'s contribution to the expected JSD can be decomposed into optimizing the first term of Eq. (12) and solving an isomorphic problem for $s_{t-1}$, which is the longest proper prefix of $s_t$. When $t = 1$, since in a Markov decision process the probability of the initial state $s_0$ is always 1.0, it is trivial to prove that the final second term becomes 0.

Therefore, Eq. (11) can be reduced by recursively applying Lemma 2. After removing the constant multipliers and denoting the predicted probability distributions over the action space, i.e. $G_\theta(\cdot \mid s_t)$ and $M_\phi(\cdot \mid s_t)$, as $\pi_g(s_t)$ and $\pi_m(s_t)$ respectively, the objective whose gradient $\nabla_\theta J_g(\theta)$ trains the generator via Cooperative Training can be formulated as

$$J_g(\theta) = \sum_{t=0}^{n-1} \mathbb{E}_{s_t \sim G_\theta}\left[\pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)\right]. \tag{13}$$

For tractable density models with a finite discrete action space in each step, the practical availability of this objective's gradient is well guaranteed for the following reasons. First, with a random initialization of the model, the supports of distributions $G_\theta$ and $P$ are hardly disjoint. Second, the first term of Eq. (13) is to minimize the cross entropy between $G$ and $M^*$, which tries to enlarge the overlap of the two distributions. Third, since the second term of Eq. (13) is equivalent to maximizing the entropy of $G$, it encourages the support of $G$ to cover the whole action space, which avoids the case of disjoint supports between $G$ and $P$.

3.1.3. Factorizing the Cumulative Gradient Through Time for Improved Training
Up to now, we are still not free from REINFORCE, as the objective Eq. (13) incorporates an expectation over the learned distribution $G_\theta$. In this part, we propose an effective way to eventually avoid using REINFORCE. The gradient of the whole objective is

$$\nabla_\theta J_g(\theta) = \nabla_\theta \left(\sum_{t=0}^{n-1} \mathbb{E}_{s_t \sim G_\theta}\left[\pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)\right]\right).$$

For time step $t$, the gradient of Eq. (13) can be calculated as

$$\begin{aligned}
\nabla_\theta J_{g,t}(\theta) &= \nabla_\theta \left[\mathbb{E}_{s_t \sim G_\theta}\, \pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)\right] \\
&= \nabla_\theta \left[\sum_{s_t} G_\theta(s_t)\left(\pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)\right)\right] \\
&= \sum_{s_t} \nabla_\theta \left[G_\theta(s_t)\left(\pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)\right)\right].
\end{aligned}$$

Let $L(s_t) = \pi_g(s_t)^\top \left(\log \pi_m(s_t) - \log \pi_g(s_t)\right)$; then

$$\begin{aligned}
\nabla_\theta J_{g,t}(\theta) &= \sum_{s_t} \left(\nabla_\theta G_\theta(s_t)\, L(s_t) + G_\theta(s_t)\, \nabla_\theta L(s_t)\right) \\
&= \sum_{s_t} G_\theta(s_t) \left(\nabla_\theta \log G_\theta(s_t)\, L(s_t) + \nabla_\theta L(s_t)\right) \\
&= \mathbb{E}_{s_t \sim G_\theta}\, \nabla_\theta \left[\mathrm{stop\_gradient}(L(s_t)) \log G_\theta(s_t) + L(s_t)\right].
\end{aligned}$$

The total gradient in each step consists of two terms. The first term, $\mathrm{stop\_gradient}(L(s_t)) \log G_\theta(s_t)$, behaves like REINFORCE and makes the main contribution to the variance of the optimization process. The second, non-REINFORCE term is comparatively less noisy, though at first sight it may seem unable to work alone.

Considering the effects of the two terms, we argue that they have similar optimization directions (towards the minimization of $\mathrm{KL}(G_\theta \| M_\phi)$). To study and control the balance between the high-variance first term and the low-variance second term, we introduce an extra hyper-parameter $\gamma \in [0, 1]$. The objective in each time step thus becomes

$$\nabla_\theta J^{\gamma}_{g,t}(\theta) = \mathbb{E}_{s_t \sim G_\theta}\, \nabla_\theta \left[\gamma \left(\mathrm{stop\_gradient}(L(s_t)) \log G_\theta(s_t)\right) + L(s_t)\right].$$

In the experiment part, we will show that the algorithm works fine when $\gamma = 0.0$ and that the bias of the finally adopted term is acceptable.
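The two-term split of $\nabla_\theta J_{g,t}(\theta)$ can be checked numerically. The sketch below is an illustrative toy under stated assumptions: length-2 sequences over a 3-token vocabulary, a tabular generator with first-step logits `z0` and per-prefix second-step logits `z1[x0]`, and a fixed mediator whose tables `PI_M0` and `PI_M1` are hypothetical numbers. It compares a finite-difference gradient of $J_{g,1}$ with the analytic REINFORCE-plus-$L$ decomposition, shows that the $\gamma = 0$ estimator keeps the gradient through $\pi_g(s_1)$ exactly while dropping the gradient through $G_\theta(s_1)$, and checks that Eq. (13) summed over $t = 0, 1$ equals the sequence-level $-\mathrm{KL}(G_\theta \| M_\phi)$ (Eq. (11) without its constant factor, as Lemma 2 states).

```python
import math

V = 3  # toy vocabulary size
# Hypothetical fixed mediator: pi_m(s_0) and pi_m([x0]) for each first token x0.
PI_M0 = [0.5, 0.3, 0.2]
PI_M1 = [[0.6, 0.2, 0.2], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]


def softmax(v):
    top = max(v)
    e = [math.exp(x - top) for x in v]
    total = sum(e)
    return [x / total for x in e]


def step_objective(pi_g, pi_m):
    """L(s_t) = pi_g^T (log pi_m - log pi_g)."""
    return sum(g * (math.log(m) - math.log(g)) for g, m in zip(pi_g, pi_m))


def unpack(flat):
    """Split flat parameters into first-step logits z0 and per-prefix logits z1[x0]."""
    return flat[:V], [flat[V * (r + 1): V * (r + 2)] for r in range(V)]


def j_g1(flat):
    """J_{g,1} = E_{s_1 ~ G}[L(s_1)], computed exactly by enumerating s_1 = [x0]."""
    z0, z1 = unpack(flat)
    pi0 = softmax(z0)
    return sum(pi0[x] * step_objective(softmax(z1[x]), PI_M1[x]) for x in range(V))


def j_g_total(flat):
    """Eq. (13) for n = 2: L(s_0) + E_{s_1 ~ G}[L(s_1)]."""
    z0, _ = unpack(flat)
    return step_objective(softmax(z0), PI_M0) + j_g1(flat)


def neg_seq_kl(flat):
    """-KL(G || M) over complete length-2 sequences."""
    z0, z1 = unpack(flat)
    pi0, total = softmax(z0), 0.0
    for x0 in range(V):
        pi1 = softmax(z1[x0])
        for x1 in range(V):
            g, m = pi0[x0] * pi1[x1], PI_M0[x0] * PI_M1[x0][x1]
            total -= g * math.log(g / m)
    return total


def grad_step_objective(z, pi_m):
    """Analytic dL/dz for pi_g = softmax(z): pi_j * (a_j - sum_i pi_i a_i)."""
    pi_g = softmax(z)
    a = [math.log(m / g) for m, g in zip(pi_m, pi_g)]
    a_bar = sum(g * ai for g, ai in zip(pi_g, a))
    return [g * (ai - a_bar) for g, ai in zip(pi_g, a)]


def grad_j_g1(flat, gamma):
    """gamma * (REINFORCE term) + (L term) for J_{g,1}, summed exactly over s_1."""
    z0, z1 = unpack(flat)
    pi0 = softmax(z0)
    g0, g1 = [0.0] * V, []
    for x in range(V):
        l_t = step_objective(softmax(z1[x]), PI_M1[x])  # plays the role of stop_gradient(L)
        for j in range(V):  # G(s_1) * d log G(s_1) / d z0_j = pi0[x] * (1[j == x] - pi0[j])
            g0[j] += gamma * pi0[x] * ((1.0 if j == x else 0.0) - pi0[j]) * l_t
        g1.extend(pi0[x] * d for d in grad_step_objective(z1[x], PI_M1[x]))
    return g0 + g1


def finite_diff(f, flat, eps=1e-6):
    grad = []
    for i in range(len(flat)):
        up, down = list(flat), list(flat)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


flat = [0.2, -0.1, 0.4, 0.3, 0.0, -0.5, -0.2, 0.6, 0.1, 0.5, -0.3, 0.2]
exact = finite_diff(j_g1, flat)
full = grad_j_g1(flat, gamma=1.0)    # REINFORCE term + L term
biased = grad_j_g1(flat, gamma=0.0)  # L term only, as adopted in Eq. (14)
print("max |finite diff - analytic|:", max(abs(a - b) for a, b in zip(exact, full)))
print("cosine(full, gamma = 0):", cosine(full, biased))
print("Eq. (13) vs -KL(G || M):", j_g_total(flat), neg_seq_kl(flat))
```

In this toy, the cosine between the full and the $\gamma = 0$ gradients equals the share of the gradient norm carried by the per-step term, which is the same kind of bias statistic examined empirically in Section 4.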
In practice, we could directly drop the REINFORCE term, and the total gradient would thus become

$$\nabla_\theta J^{0.0}_g(\theta) = \sum_{t=0}^{n-1} \mathbb{E}_{s_t \sim G_\theta}\left[\nabla_\theta\, \pi_g(s_t)^\top \left(\log \frac{\pi_m(s_t)}{\pi_g(s_t)}\right)\right]. \tag{14}$$

3.2.1. Connection with Adversarial Training
The overall objective of CoT can be regarded as finding a solution of

$$\max_\theta \max_\phi \; \mathbb{E}_{s \sim p_{\text{data}}}\left[\log M_\phi(s)\right] + \mathbb{E}_{s \sim G_\theta}\left[\log M_\phi(s)\right]. \tag{15}$$

Note that the strong connections and differences between the optimization objectives of CoT (15) and GAN (4) lie in the max-max versus minimax operations of the joint objective.

3.2.2. Advantages over Previous Methods
CoT has several practical advantages over previous methods, including MLE, Scheduled Sampling (SS) (Bengio et al., 2015) and adversarial methods like SeqGAN (Yu et al., 2017).

First, although CoT and GAN both aim to optimize an estimated JSD, CoT is much more stable than GAN. This is because the two modules, namely the generator and the mediator, have similar tasks, i.e. to approach the same data distribution as generative and predictive models, respectively. The superiority of CoT over inconsistent methods like Scheduled Sampling is solid, since CoT has a systematic theoretical explanation of its behavior. Compared with methods that require pre-training in order to reduce variance, like SeqGAN (Yu et al., 2017), CoT is computationally cheaper. More specifically, under recommended settings, CoT has the same order of computational complexity as MLE.

Besides, CoT works independently. In practice, it does not require model pre-training via conventional methods like MLE. This is an important property of an unsupervised learning algorithm for sequential discrete data, which then needs neither supervised approximation for variance reduction nor sophisticated smoothing as in Wasserstein GAN with gradient penalty (WGAN-GP) (Gulrajani et al., 2017).

3.2.3. The Necessity of the Mediator
An interesting question is why we need to train a mediator by mixing the samples from both sources $G$ and $P$, instead of directly training a predictive model $\hat{P}$ on the training set via MLE. There are basically two points that explain this.

First, to apply the efficient training objective Eq. (13), one needs to obtain not only the mixture density model $M = \frac{1}{2}(P + G)$ but also its decomposed form in each time step, i.e. $M_\phi(s) = \prod_{t=1}^{n} M_\phi(s_t \mid s_{t-1})$, without which the term $\pi_m(s_t)$ in Eq. (13) cannot be computed efficiently. This indicates that if we directly estimate $P$ and compute $M = \frac{1}{2}(G + P)$, the obtained $M$ will actually be useless, since its decomposed form is not available.

Second, as a derivative problem of "exposure bias", the model $\hat{P}$ would have to generalize to work well on the generated samples, i.e. $s \sim G_\theta$, to guide the generator towards the target distribution. Given finite observations, the learned distribution $\hat{P}$ is trained to provide correct predictions for samples from the target distribution $P$. There is no guarantee that $\hat{P}$ can stably provide correct predictions for guiding the generator. An ablation study is provided in the supplementary material.
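The first point can be made concrete with a two-step toy whose numbers are all hypothetical. The next-token distribution of the mixture $M = \frac{1}{2}(P + G)$ is a prefix-weighted average of the two models' next-token distributions, not their plain average, so $\pi_m(s_t)$ cannot be obtained by averaging $\hat{P}(\cdot \mid s_t)$ and $G_\theta(\cdot \mid s_t)$ step by step; recovering it would require reweighting by each model's probability of the prefix. A mediator trained by MLE on mixed samples, as in Eq. (9), targets the correctly weighted conditionals directly.

```python
import itertools


def joint(first, cond):
    """Probabilities of length-2 sequences from a first-step distribution and per-prefix conditionals."""
    return {(x0, x1): first[x0] * cond[x0][x1]
            for x0, x1 in itertools.product(range(2), repeat=2)}


# Hypothetical target P and generator G over a binary vocabulary.
p_first, p_cond = [0.9, 0.1], [[0.8, 0.2], [0.3, 0.7]]
g_first, g_cond = [0.2, 0.8], [[0.1, 0.9], [0.6, 0.4]]

p_joint, g_joint = joint(p_first, p_cond), joint(g_first, g_cond)
m_joint = {s: 0.5 * (p_joint[s] + g_joint[s]) for s in p_joint}


def mixture_conditional(x0):
    """M(x1 | x0) from the sequence-level mixture, via M(x0, x1) / M(x0)."""
    m_prefix = sum(m_joint[(x0, x1)] for x1 in range(2))
    return [m_joint[(x0, x1)] / m_prefix for x1 in range(2)]


def prefix_weighted(x0):
    """The same conditional written as a prefix-weighted average of P(. | x0) and G(. | x0)."""
    wp, wg = p_first[x0], g_first[x0]
    return [(wp * p_cond[x0][x1] + wg * g_cond[x0][x1]) / (wp + wg) for x1 in range(2)]


def naive_average(x0):
    """Per-step average of the two conditionals (generally NOT the mixture's conditional)."""
    return [0.5 * (p_cond[x0][x1] + g_cond[x0][x1]) for x1 in range(2)]


for x0 in range(2):
    print(x0, mixture_conditional(x0), prefix_weighted(x0), naive_average(x0))
```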
4. Experiments
Following the synthetic data experiment setting in Yu et al. (2017) and Zhu et al. (2018), we design a synthetic Turing test, in which the negative log-likelihood $\mathrm{NLL}_{\text{oracle}}$ from an oracle LSTM is calculated for evaluating the quality of samples from the generator.

Particularly, to support our claim that our method causes little mode collapse, we calculate $\mathrm{NLL}_{\text{test}}$ by sampling an extra batch of samples from the oracle and computing their negative log-likelihood as measured by the generator.

We show that under this more reasonable setting, our proposed algorithm reaches state-of-the-art performance with exactly the same network architecture. Note that models like LeakGAN (Guo et al., 2017) contain architecture-level modifications, which are orthogonal to our approach, and thus are not included in this part. The results are shown in Table 1. Code for repeatable experiments of this subsection is provided in the supplementary materials.

4.1.1. Empirical Analysis of Estimated Gradients
As a part of the synthetic experiment, we demonstrate the empirical effectiveness of the estimated gradient. During the training of the CoT model, we record the statistics of the gradient with respect to the model parameters estimated by back-propagating $\nabla_\theta J^{1.0}_g(\theta)$ and $\nabla_\theta J^{0.0}_g(\theta)$, including the mean and log variance of such gradients.

We are mainly interested in two properties of the estimated gradients, which can be summarized as:

• Bias. Obviously, $\nabla_\theta J^{1.0}_g(\theta)$ is exactly the original gradient, which is unbiased for the optimization of Eq. (13). If the estimated gradient $\nabla_\theta J^{0.0}_g(\theta)$ is highly biased, the cosine similarity of the averages of $\nabla_\theta J^{1.0}_g(\theta)$ and $\nabla_\theta J^{0.0}_g(\theta)$ would be close to 0.0; otherwise it would be close to 1.0. To investigate this, we calculate the cosine similarity of the expected $\nabla_\theta J^{1.0}_g(\theta)$ and $\nabla_\theta J^{0.0}_g(\theta)$.
• Variance. We calculate the log variance of $\nabla_\theta J^{1.0}_g(\theta)$ and $\nabla_\theta J^{0.0}_g(\theta)$ in each dimension, and compute the average log variance of each. In the figure, to better illustrate the comparison, we plot the advantage of the mean log variance of $\nabla_\theta J^{1.0}_g(\theta)$ over that of $\nabla_\theta J^{0.0}_g(\theta)$. If the variance of the estimated gradient is lower, this statistic would be steadily positive.

To calculate these statistics, we sample 3,000 sequences from the generator and calculate the average gradient under each setting every 100 iterations during the training of the model. The results are shown in Figure 3. The estimated gradient of our approach shows both properties: low bias and effectively reduced variance.

Table 1. Likelihood-based benchmark and time statistics for the synthetic Turing test. '-(MLE)' means the best performance is acquired during MLE pre-training. Cells marked "…" were lost in extraction.

Model | NLL_oracle | NLL_test (final/best) | Best NLL_oracle + NLL_test | Time/Epoch
MLE | 9.08 | 8.97/7.60 | 9.43 + 7.67 | …
SeqGAN (Yu et al., 2017) | 8.68 | 10.10/-(MLE) | -(MLE) | …
RankGAN (Lin et al., 2017) | 8.37 | 11.19/-(MLE) | -(MLE) | …
MaliGAN (Che et al., 2017) | 8.73 | 10.07/-(MLE) | -(MLE) | …
Scheduled Sampling (Bengio et al., 2015) | … | … | … | …
Professor Forcing (Lamb et al., 2016) | … | … | … | …
CoT (ours) | … | … | … | …

Figure 2. Curves of evaluation on JSD and NLL_oracle during iterations of CoT under different training settings. (a) JSD of SeqGAN over epochs (pre-training, then adversarial training) for g-steps/d-steps of 3/1, 2/1, 1/1, 1/2, 1/3 and 1/4. (b) NLL_oracle of CoT over iterations for learning rates 1e-4, 1e-3, 5e-3, 1e-2 and 1e-1. (c) JSD of CoT over epochs for g-steps/m-steps of 3/1, 2/1, 1/1, 1/2, 1/3 and 1/4. To show the hyper-parameter robustness of CoT, we compared it with a typical language GAN, i.e. SeqGAN (Yu et al., 2017).

4.1.2. Discussion
Computational Efficiency
Although CoT does not achieve the state of the art in terms of time cost per epoch, we do observe that CoT is remarkably faster than previous language GANs. Besides, considering that CoT is a sample-based optimization algorithm, which involves the time cost of sampling from the generator, this result is acceptable. The result also verifies our claim that CoT has the same order of computational complexity as MLE (i.e. the time cost only differs by a constant multiplier or an extra lower-order term).
Hyper-parameter Robustness
We perform a hyper-parameter robustness experiment on the synthetic data. Compared with the results of similar experiments for SeqGAN (Yu et al., 2017), our approach shows less sensitivity to hyper-parameter choices, as shown in Figure 2. Note that in all our attempts, the curves of the evaluated JSD of SeqGAN fail to converge.
Self-estimated Training Progress Indicator
Like the critic loss (i.e. the estimated Earth Mover Distance) in WGANs, the training loss of the mediator (9), namely the balanced NLL, can serve as a real-time training progress indicator, as shown in Figure 4. Specifically, in a wide range, the balanced NLL is a good estimation of the real $\mathrm{JSD}(G \| P)$ up to a steady translation, namely,

$$\mathrm{NLL}_{\text{balanced}} = 2\,\mathrm{JSD}(G \| P) + H(G) + H(P).$$

Figure 3. Empirical study on bias and variance comparison. (a) Curves of the cosine similarity of averaged $\nabla_\theta J^{1.0}_g(\theta)$ and $\nabla_\theta J^{0.0}_g(\theta)$ during training. (b) Curves of the log variance reduction per dimension of $\nabla_\theta J^{0.0}_g(\theta)$ compared to $\nabla_\theta J^{1.0}_g(\theta)$.

Figure 4. Training progress curves indicated by different values. (a) Curves of the oracle $\mathrm{JSD}(G \| P)$ during training for MLE, SeqGAN and CoT. (b) Curves of balanced NLL and real JSD. Both results are from synthetic data experiments.

As an important sequential data modeling task, zero-prior text generation, especially long and diversified text generation, is a good testbed for evaluating the performance of a generative model.

Following the experiment proposed in LeakGAN (Guo et al., 2017), we choose the EMNLP 2017 WMT News Section as our dataset, with the maximal sentence length limited to 51. We pay major attention to both quality and diversity. To keep the comparison fair, we present two implementations of CoT, namely CoT-basic and CoT-strong. For CoT-basic, the generator follows the settings of that in MLE, SeqGAN, RankGAN and MaliGAN. For CoT-strong, the generator is implemented with an architecture similar to that of LeakGAN.

For quality evaluation, we evaluated BLEU on a small batch of test data separated from the original dataset. For diversity evaluation, we evaluated the estimated Word Mover Distance (eWMD) (Kusner et al., 2015), which is calculated by training a discriminative model between generated samples and real samples with a 1-Lipschitz constraint via gradient penalty, as in WGAN-GP (Gulrajani et al., 2017). To keep the comparison fair, the architecture and other training settings of the discriminative models are kept the same for all evaluated models.

The results are shown in Table 2 and Table 3. In terms of generative quality, CoT-basic achieves state-of-the-art performance over all the baselines with the same architecture-level capacity, especially in long-term robustness at the n-gram level. CoT-strong, using a conservative generation strategy (i.e. setting the inverse temperature parameter $\alpha$ higher than 1, as in Guo et al. (2017)), achieves the best performance over all compared models. In terms of generative diversity, the results show that our model achieves state-of-the-art performance on all metrics, including $\mathrm{NLL}_{\text{test}}$, which is the optimization target of MLE.

Implementation Details of eWMD.
To calculate eWMD, we adopted a multi-layer convolutional neural network as the feature extractor. We calculate the gradient w.r.t. the one-hot representation $O_s$ of the sequence $s$ for the gradient penalty. The training loss of the Wasserstein critic $f_\omega$ can be formulated as

$$L_c(\omega, \lambda) = \mathbb{E}_{s \sim G_\theta}\left[f_\omega(O_s)\right] - \mathbb{E}_{s \sim p_{\text{data}}}\left[f_\omega(O_s)\right] + \lambda \max\left(0, \left\|\nabla f_\omega(\hat{O})\right\| - 1\right),$$

where

$$\hat{O} = (1 - \mu)\, O_{s_p} + \mu\, O_{s_q}, \quad \mu \sim \mathrm{Uniform}(0, 1), \quad s_q \sim G_\theta, \quad s_p \sim p_{\text{data}}.$$

We use Adam (Kingma & Ba, 2014) as the optimizer, with hyper-parameter settings of α = 1e−…, β₁ = 0.…, β₂ = 0.…. For each evaluated generator, we train the critic $f_\omega$ for 100,000 iterations, and calculate

$$\mathrm{eWMD}(p_{\text{data}}, G_\theta) = \mathbb{E}_{s \sim p_{\text{data}}}\left[f_\omega(O_s)\right] - \mathbb{E}_{s \sim G_\theta}\left[f_\omega(O_s)\right].$$

The network architecture for $f_\omega$ is shown in Table 4.

Table 2. N-gram-level quality benchmark: BLEU on test data of EMNLP2017 WMT News. *: Results under the conservative generation settings as described in LeakGAN's paper. Cells marked "…" were lost in extraction.

Model | BLEU2 | BLEU3 | BLEU4 | BLEU5
MLE | 0.781 | 0.482 | 0.225 | 0.105
SeqGAN | 0.731 | 0.426 | 0.181 | 0.096
RankGAN | 0.691 | 0.387 | 0.178 | 0.095
MaliGAN | 0.755 | 0.456 | 0.179 | 0.088
LeakGAN* | 0.835 | 0.648 | 0.437 | 0.271
CoT-basic | … | … | … | …
CoT-strong | … | … | … | …
CoT-strong* | … | … | … | …

Table 3. Diversity benchmark: estimated Word Mover Distance (eWMD) and NLL_test. Cells marked "…" were lost in extraction.

Model | eWMD_test | eWMD_train | NLL_test
MLE | 1.015 (σ=0.023) | … (σ=0.019) | …
SeqGAN | 2.900 (σ=0.025) | … (σ=0.018) | …
RankGAN | 4.451 (σ=0.083) | … (σ=0.021) | …
MaliGAN | 4.891 (σ=0.061) | … (σ=0.020) | …
LeakGAN | 1.803 (σ=0.027) | … (σ=0.023) | …
CoT-basic | … (σ=0.031) | … (σ=0.019) | …
CoT-strong | … (σ=0.018) | … (σ=0.016) | …
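The critic loss $L_c(\omega, \lambda)$ and the eWMD estimate can be sketched without a deep-learning framework. This is an illustration under stated assumptions, not the evaluation code: the critic `f` is any Python callable on a flattened one-hot vector, its input gradient is approximated by central finite differences (a framework would use automatic differentiation instead), the one-sided penalty is averaged over one interpolate per (real, fake) pair, and `lam=10.0` and the linear critic in the usage lines are hypothetical.

```python
import math
import random


def one_hot(seq, vocab):
    """Flattened one-hot representation O_s of a token sequence."""
    vec = [0.0] * (len(seq) * vocab)
    for t, tok in enumerate(seq):
        vec[t * vocab + tok] = 1.0
    return vec


def input_grad_norm(f, x, eps=1e-5):
    """||grad_x f(x)||, approximated by central finite differences."""
    sq = 0.0
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        sq += ((f(up) - f(down)) / (2 * eps)) ** 2
    return math.sqrt(sq)


def batch_mean(f, batch):
    return sum(f(o) for o in batch) / len(batch)


def critic_loss(f, fake, real, lam=10.0, seed=0):
    """L_c = E_G[f(O)] - E_data[f(O)] + lam * mean_i max(0, ||grad f(O_hat_i)|| - 1)."""
    rng = random.Random(seed)
    penalty = 0.0
    for o_p, o_q in zip(real, fake):  # pairs (s_p ~ p_data, s_q ~ G)
        mu = rng.random()  # mu ~ Uniform(0, 1)
        o_hat = [(1 - mu) * a + mu * b for a, b in zip(o_p, o_q)]
        penalty += max(0.0, input_grad_norm(f, o_hat) - 1.0)
    return batch_mean(f, fake) - batch_mean(f, real) + lam * penalty / len(real)


def ewmd(f, real, fake):
    """eWMD(p_data, G) = E_data[f(O)] - E_G[f(O)] for a trained critic f."""
    return batch_mean(f, real) - batch_mean(f, fake)


# Usage with a hypothetical linear critic, whose input gradient is `weights` everywhere.
VOCAB = 3
real = [one_hot(s, VOCAB) for s in ([0, 1], [0, 2], [1, 1])]
fake = [one_hot(s, VOCAB) for s in ([2, 2], [2, 1], [1, 2])]
weights = [0.3, 0.1, -0.2, 0.1, 0.2, -0.4]


def critic(o):
    return sum(w * x for w, x in zip(weights, o))


print(critic_loss(critic, fake, real), ewmd(critic, real, fake))
```

Since `weights` has norm below 1, the penalty vanishes here and the loss is exactly the negated eWMD; scaling the critic up makes the penalty term active.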
5. Future Work & Conclusion
We proposed Cooperative Training, a novel algorithm for training generative models of discrete data. CoT achieves independent success without the necessity of pre-training via maximum likelihood estimation or involving REINFORCE. In our experiments, CoT achieves superior performance on sample quality, diversity, as well as training stability.

As for future work, one direction is to explore whether there is a better way to factorize the dropped term of Eq. (14) into some low-variance term plus another high-variance residual term. This could further improve the performance of models trained via CoT. Another interesting direction is to investigate whether there are feasible factorization solutions for the optimization of other distances or divergences, such as the Wasserstein distance, total variation and other task-specific measurements.

Table 4. Detailed implementation of the eWMD network architecture. Values marked "…" were lost in extraction.
Word Embedding Layer, hidden dim = 128
Conv1d, window size = 2, strides = 1, channels = 64
Leaky ReLU Nonlinearity (α = 0.…)
Conv1d, window size = 3, strides = 2, channels = 64
Leaky ReLU Nonlinearity (α = 0.…)
Conv1d, window size = 3, strides = 2, channels = 128
Leaky ReLU Nonlinearity (α = 0.…)
Conv1d, window size = 4, strides = 2, channels = 128
Leaky ReLU Nonlinearity (α = 0.…)
Flatten
Fully Connected, output dimension = 512
Leaky ReLU Nonlinearity (α = 0.…)
Fully Connected, output dimension = 1
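As a quick sanity check on the Table 4 architecture, the sequence length after each Conv1d layer can be computed for sentences padded to the maximal length of 51 tokens used in the text generation experiment. The padding mode is not stated, so the sketch below assumes either 'valid' (no padding) or 'same' padding; the helper names are hypothetical.

```python
import math

# (window size, stride, channels) of the four Conv1d layers in Table 4
CONV_LAYERS = [(2, 1, 64), (3, 2, 64), (3, 2, 128), (4, 2, 128)]


def conv1d_length(length, window, stride, padding):
    """Output length of a 1-D convolution under an assumed padding mode."""
    if padding == "valid":
        return (length - window) // stride + 1
    if padding == "same":
        return math.ceil(length / stride)
    raise ValueError(f"unknown padding: {padding}")


def flatten_size(seq_len, padding):
    """Per-layer sequence lengths and the size of the Flatten output."""
    lengths = [seq_len]
    for window, stride, _ in CONV_LAYERS:
        lengths.append(conv1d_length(lengths[-1], window, stride, padding))
    return lengths, lengths[-1] * CONV_LAYERS[-1][2]


print(flatten_size(51, "valid"))
print(flatten_size(51, "same"))
```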
6. Acknowledgement
The corresponding authors Sidi Lu and Weinan Zhang thank the National Natural Science Foundation of China (61702327, 61772333, 61632017) and the Shanghai Sailing Program (17YF1428200) for their support.
References
Arjovsky, M. and Bottou, L. Towards principled methods for training generative adversarial networks. arXiv preprint arXiv:1701.04862, 2017.
Arjovsky, M., Chintala, S., and Bottou, L. Wasserstein GAN. arXiv:1701.07875, 2017.
Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. arXiv:1409.0473, 2014.
Bengio, S., Vinyals, O., Jaitly, N., and Shazeer, N. Scheduled sampling for sequence prediction with recurrent neural networks. In NIPS, pp. 1171–1179, 2015.
Che, T., Li, Y., Zhang, R., Hjelm, R. D., Li, W., Song, Y., and Bengio, Y. Maximum-likelihood augmented discrete generative adversarial networks. arXiv:1702.07983, 2017.
Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial nets. In NIPS, pp. 2672–2680, 2014.
Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., and Courville, A. C. Improved training of Wasserstein GANs. In NIPS, pp. 5769–5779, 2017.
Guo, J., Lu, S., Cai, H., Zhang, W., Yu, Y., and Wang, J. Long text generation via adversarial training with leaked information. arXiv:1709.08624, 2017.
Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
Kusner, M., Sun, Y., Kolkin, N., and Weinberger, K. From word embeddings to document distances. In International Conference on Machine Learning, pp. 957–966, 2015.
Lamb, A. M., Goyal, A. G. A. P., Zhang, Y., Zhang, S., Courville, A. C., and Bengio, Y. Professor forcing: A new algorithm for training recurrent networks. In NIPS, pp. 4601–4609, 2016.
Lin, K., Li, D., He, X., Zhang, Z., and Sun, M.-T. Adversarial ranking for language generation. In NIPS, pp. 3155–3165, 2017.
Lu, S., Zhu, Y., Zhang, W., Wang, J., and Yu, Y. Neural text generation: Past, present and beyond. arXiv preprint arXiv:1803.07133, 2018.
Radford, A., Metz, L., and Chintala, S. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434, 2015.
Ranzato, M., Chopra, S., Auli, M., and Zaremba, W. Sequence level training with recurrent neural networks. arXiv preprint arXiv:1511.06732, 2015.
Sutton, R. S. Temporal credit assignment in reinforcement learning. 1984.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
Yu, L., Zhang, W., Wang, J., and Yu, Y. SeqGAN: Sequence generative adversarial nets with policy gradient. In AAAI, pp. 2852–2858, 2017.
Zhu, Y., Lu, S., Zheng, L., Guo, J., Zhang, W., Wang, J., and Yu, Y. Texygen: A benchmarking platform for text generation models. arXiv:1802.01886, 2018.

A DETAILED DERIVATION OF THE ALGORITHM

Throughout, we use the prefix factorization G_θ(s_t) = G_θ(s_{t−1}) G_θ(s_t | s_{t−1}) (and likewise for M_φ), so the ratio inserted in the second line below equals 1.

$$
\begin{aligned}
\text{Eq. (11)} &= -\mathbb{E}_{s_t \sim G_\theta}\left[\log G_\theta(s_t) - \log M_\phi(s_t)\right] \\
&= -\mathbb{E}_{s_t \sim G_\theta}\left[\frac{G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})}{G_\theta(s_t)}\bigl(\log G_\theta(s_t) - \log M_\phi(s_t)\bigr)\right] \\
&= -\mathbb{E}_{s_t \sim G_\theta}\left[\frac{G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})}{G_\theta(s_t)}\Bigl(\log \bigl(G_\theta(s_t \mid s_{t-1})\,G_\theta(s_{t-1})\bigr) - \log \bigl(M_\phi(s_t \mid s_{t-1})\,M_\phi(s_{t-1})\bigr)\Bigr)\right] \\
&= -\sum_{s_t} G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})\bigl(\log G_\theta(s_t \mid s_{t-1}) - \log M_\phi(s_t \mid s_{t-1})\bigr) - \sum_{s_t} G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1}) \log\frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})} \\
&= -\sum_{s_t} G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})\bigl(\log G_\theta(s_t \mid s_{t-1}) - \log M_\phi(s_t \mid s_{t-1})\bigr) - \sum_{s_{t-1}} \Bigl(G_\theta(s_{t-1}) \log\frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})}\Bigr) \sum_{s_t} G_\theta(s_t \mid s_{t-1})
\end{aligned}
$$

Here s_{t−1} iterates over all prefixes of the sequences in {s_t}, and the inner sum runs over the continuations s_t of each prefix, so it equals 1:

$$
\begin{aligned}
&= -\sum_{s_t} G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})\bigl(\log G_\theta(s_t \mid s_{t-1}) - \log M_\phi(s_t \mid s_{t-1})\bigr) - \sum_{s_{t-1}} G_\theta(s_{t-1}) \log\frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})} \\
&= -\sum_{s_t} G_\theta(s_{t-1})\,G_\theta(s_t \mid s_{t-1})\bigl(\log G_\theta(s_t \mid s_{t-1}) - \log M_\phi(s_t \mid s_{t-1})\bigr) - \mathbb{E}_{s_{t-1} \sim G_\theta}\left[\log\frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})}\right] \\
&= -\sum_{s_{t-1}} G_\theta(s_{t-1}) \sum_{s_t} G_\theta(s_t \mid s_{t-1})\bigl(\log G_\theta(s_t \mid s_{t-1}) - \log M_\phi(s_t \mid s_{t-1})\bigr) - \mathbb{E}_{s_{t-1} \sim G_\theta}\left[\log\frac{G_\theta(s_{t-1})}{M_\phi(s_{t-1})}\right] \\
&= \text{Eq. (12)}
\end{aligned}
$$

B SAMPLE COMPARISON AND DISCUSSION
Table 1 shows samples from the strongest baseline models and from our model. Observation of the model samples indicates that:
• CoT produces remarkably more diverse and meaningful samples when compared to LeakGAN.
• The consistency of CoT is significantly improved when compared to MLE.
C FURTHER DISCUSSIONS ABOUT THE EXPERIMENT RESULTS
The Optimal Balance for Cooperative Training
We find that using the same learning rate and the same number of iterations for the generator and the mediator seems to be the most competitive choice. As for the architecture, we find that the mediator needs to be slightly stronger than the generator. For the best result in the synthetic experiment, we adopt exactly the same generator as the other compared models and a mediator whose hidden state size is twice that of the generator (64 hidden units).

Theoretically, we can and should sample more batches from G_θ and P, respectively, for training the mediator in each iteration. However, if no regularization is used when training the mediator, it can easily over-fit, leading to the generator's quick convergence in terms of KL(G_θ ‖ P) or NLL_oracle but divergence in terms of JSD(G_θ ‖ P). Empirically, this can be alleviated by applying dropout (Srivastava et al., 2014) with a 50% keep ratio before the output layer of the RNN. After applying dropout, the empirical results are consistent with our theory that more training batches for the mediator in each iteration are always helpful.

However, applying regularization is not an ultimate solution, and we look forward to further theoretical investigation of better solutions to this problem in the future.

Table 1: WMT News Samples from Different Models

LeakGAN
(1) It’s a big advocate for therapy is a second thing to do, and I’m creating a relationship with a nation.
(2) It’s probably for a fantastic footage of the game, but in the United States is already time to be taken to live.
(3) It’s a sad House we have a way to get the right because we have to go to see that, ” she said.
(4) I’m not sure if I thank a little bit easier to get to my future commitment in work, ” he said.
(5) “ I think it was alone because I can do that, when you’re a lot of reasons, ” he said.
(6) It’s the only thing we do, we spent 26 and $35(see how you do is we lose it,” said both sides in the summer.

CoT
(1) We focus the plans to put aside either now, and which doesn’t mean it is to earn the impact to the government rejected.
(2) The argument would be very doing work on the 2014 campaign to pursue the firm and immigration officials, the new review that’s taken up for parking.
(3) This method is true to available we make up drink with that all they were willing to pay down smoking.
(4) The number of people who are on the streaming boat would study if the children had a bottle - but meant to be much easier, having serious ties to the outside of the nation.
(5) However, they have to wait to get the plant in federal fees and the housing market’s most valuable in tourism.

MLE
(1) after the possible cost of military regulatory scientists, chancellor angela merkel’s business share together a conflict of major operators and interest as they said it is unknown for those probably 100 percent as a missile for britain.
(2) but which have yet to involve the right climb that took in melbourne somewhere else with the rams even a second running mate and kansas.
(3) “ la la la la 30 who appeared that themselves is in the room when they were shot her until the end ” that jose mourinho could risen from the individual .
(4) when aaron you has died, it is thought if you took your room at the prison fines of radical controls by everybody, if it’s a digital plan at an future of the next time.
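The mediator/generator alternation discussed under "The Optimal Balance for Cooperative Training" can be illustrated on a toy problem. The sketch below is not the authors' implementation: it assumes a single categorical variable in place of sequences, exact expectations in place of sampled batches, and plain gradient descent in place of Adam. The mediator M_φ is fit by maximum likelihood to the balanced mixture of P and G_θ, and the generator then takes a gradient step on KL(G_θ ‖ M_φ) with M_φ held fixed; with one token, the per-step factorization reduces to a single term. `toy_cot`, `mediator_steps`, `lr_g` and `lr_m` are hypothetical names.

```python
import math


def softmax(logits):
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    mix = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, mix) + 0.5 * kl(q, mix)


def toy_cot(target, iterations=1000, mediator_steps=10, lr_g=1.0, lr_m=1.0):
    """Alternate mediator and generator updates on one categorical variable.

    Exact expectations mean no sampling noise and no mediator over-fitting;
    mediator_steps stands in for "more mediator batches per iteration".
    """
    k = len(target)
    theta = [0.0] * k  # generator logits: G = softmax(theta)
    phi = [0.0] * k    # mediator logits:  M = softmax(phi)
    for _ in range(iterations):
        g = softmax(theta)
        # Mediator: gradient steps on -sum_x 0.5 * (P(x) + G(x)) * log M(x),
        # i.e. maximum likelihood on the balanced mixture of P and G.
        for _ in range(mediator_steps):
            m = softmax(phi)
            phi = [f - lr_m * (mi - 0.5 * (pi + gi))
                   for f, mi, pi, gi in zip(phi, m, target, g)]
        m = softmax(phi)
        # Generator: one gradient step on KL(G || M) with M held fixed;
        # dKL/dtheta_j = G_j * (log G_j - log M_j - KL(G || M)).
        div = kl(g, m)
        theta = [t - lr_g * gj * (math.log(gj) - math.log(mj) - div)
                 for t, gj, mj in zip(theta, g, m)]
    return softmax(theta), softmax(phi)


P = [0.1, 0.2, 0.3, 0.4]
G, M = toy_cot(P)
print(f"JSD(P, G) before: {jsd(P, [0.25] * 4):.4f}, after: {jsd(P, G):.2e}")
```

Because the mediator here sees exact expectations, it cannot over-fit, so increasing `mediator_steps` only makes M_φ track the mixture more closely; the over-fitting and dropout issue described above arises only with finite sampled batches.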
Possible Derivatives of CoT
The form of Eq. (11) can be modified to optimize other objectives. One example is the backward KLD (a.k.a. reverse KLD), i.e., KL(G ‖ P). In this case, the objectives of the so-called "Mediator" and "Generator" become:

"Mediator": it now becomes a direct estimator P̂_φ of the target distribution P:
$$ J_{\hat p}(\phi) = \mathbb{E}_{s \sim P}\left[-\log \hat P_\phi(s)\right]. \tag{1} $$

Generator:
$$ J_g(\theta) = \sum_{t=0}^{n-1} \mathbb{E}_{s_t \sim G_\theta}\left[\pi_g(s_t)^\top \bigl(\log \pi_{\hat p}(s_t) - \log \pi_g(s_t)\bigr)\right]. \tag{2} $$

Such a model suffers from the so-called mode-collapse problem, as analyzed in Goodfellow's GAN tutorial (Goodfellow, 2016). Moreover, since the distribution estimator P̂_φ inevitably behaves unpredictably on unseen samples, i.e., samples from the generator, the algorithm sometimes fails (numerical error) or diverges.

In our successful attempts, the algorithm produces results similar to (not significantly better than) those of CoT. The quantitative results are as follows:

Table 2: N-gram-level quality benchmark: BLEU on test data of EMNLP2017 WMT News (New Split)
Model/Algorithm      BLEU-2   BLEU-3   BLEU-4   BLEU-5   eWMD
CoT-basic (ours)     0.850    0.571    0.316    0.169    … (σ = 0.…)
Reverse KL (ours)    …        …        …        …        … (σ = 0.…)

e.g. completely covering the data mode.

REFERENCES
Ian Goodfellow. NIPS 2016 tutorial: Generative adversarial networks. arXiv preprint arXiv:1701.00160, 2016.
Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting.