Domain Invariant Representation Learning with Domain Density Transformations
A. Tuan Nguyen, Toan Tran, Yarin Gal, Atılım Güneş Baydin
Abstract
Domain generalization refers to the problem where we aim to train a model on data from a set of source domains so that the model can generalize to unseen target domains. Naively training a model on the aggregate set of data (pooled from all source domains) has been shown to perform suboptimally, since the information learned by that model might be domain-specific and generalize imperfectly to target domains. To tackle this problem, a predominant approach is to find and learn some domain-invariant information in order to use it for the prediction task. In this paper, we propose a theoretically grounded method to learn a domain-invariant representation by enforcing the representation network to be invariant under all transformation functions among domains. We also show how to use generative adversarial networks to learn such domain transformations to implement our method in practice. We demonstrate the effectiveness of our method on several widely used datasets for the domain generalization problem, on all of which we achieve competitive results with state-of-the-art models.
1. Introduction
Domain generalization refers to the machine learning scenario where the model is trained on multiple source domains so that it is expected to generalize well to unseen target domains. The key difference between domain generalization (Khosla et al., 2012; Muandet et al., 2013; Ghifary et al., 2015) and domain adaptation (Zhao et al., 2019; Zhang et al., 2019; Combes et al., 2020; Tanwani, 2020) is that, in domain generalization, the learner does not have access to (even a small amount of) data of the target domain, making the problem much more challenging.
One of the most common domain generalization approaches is to learn an invariant representation across domains, aiming at a good generalization performance on target domains.
University of Oxford; VinAI Research. Correspondence to: A. Tuan Nguyen
[Figure 1: two panels, Domain 1 and Domain 2, each showing $z = x/\|x\|$ with color indicating $y$.]
Figure 1. An example of two domains. For each domain, $x$ is uniformly distributed on the outer circle (radius 2 for domain 1 and radius 3 for domain 2), with the color indicating class label $y$. After the transformation $z = x/\|x\|$, the marginal of $z$ is aligned (uniformly distributed on the unit circle for both domains), but the conditional $p(y|z)$ is not aligned. Thus, using this representation for predicting $y$ would not generalize well across domains.
In the representation learning framework, the prediction $y = f(x)$, where $x$ is data and $y$ is a label, is obtained as a composition $y = h \circ g(x)$ of a deep representation network $z = g(x)$, where $z$ is a learned representation of data $x$, and a smaller classifier $y = h(z)$, predicting label $y$ given representation $z$, both of which are shared across domains.
Current "domain-invariance"-based methods in domain generalization focus on either the marginal distribution alignment (Muandet et al., 2013) or the conditional distribution alignment (Li et al., 2018b;c), which are still prone to distributional shifts if the conditional or marginal (respectively) data distribution is not stable. In particular, marginal alignment refers to making the representation distribution $p(z)$ the same across domains. This is essential since if $p(z)$ for the target domain is different from that of the source domains, the classification network $h(z)$ would face out-of-distribution data, because the representation $z$ it receives as input at test time would be different from the ones it was trained with in the source domains. Conditional alignment refers to aligning $p(y|z)$, the conditional distribution of the label given the representation, since if this conditional for the target domain is different from that of the source domains, the classification network (trained on the source domains) would give inaccurate predictions at test time. The formal definition of the two alignments is discussed in Section 3.
In Figure 1 we illustrate an example where the representation $z$ satisfies the marginal alignment but not the conditional alignment. Specifically, $x$ is distributed uniformly on the circle with radius 2 (centered at the origin) for domain 1 and distributed uniformly on the circle with radius 3 (centered at the origin) for domain 2. The representation $z$ defined by the mapping $z = g(x) = x/\|x\|$ will align the marginal distribution $p(z)$, i.e., $z$ is now distributed uniformly on the unit circle for both domains. However, the conditional distribution $p(y|z)$ is not aligned between the two domains ($y$ is represented by color), which means using this representation for classification is suboptimal, and in this extreme case would lead to 0% accuracy in the target domain 2. This is an extreme case of misalignment, but it does illustrate the importance of the conditional alignment.
Therefore, we need to align both the marginal and the conditional distributions for a domain-invariant representation. There have been several attempts recently to align both the marginal and conditional distribution in a domain adaptation problem, for example, (Tanwani, 2020), by leveraging a small set of labeled data of the target domain. However, it is challenging to apply this approach directly to domain generalization because we do not have access to data in the target domain.
In this paper, we focus on learning a domain-invariant representation that aligns both the marginal and the conditional distribution in domain generalization problems. We present theoretical results regarding the conditions for the existence of domain-invariant representations, and subsequently propose a method to learn such representations based on domain density transformation functions.
A simple intuition for our approach is that if we enforce the representation to be invariant under the transformations among source domains, the representation will become more robust under other domain transformations.
Furthermore, we introduce an implementation of our method in practice, in which the domain transformation functions are learned through the training process of generative adversarial networks (GANs) (Goodfellow et al., 2014; Choi et al., 2018). We conduct extensive experiments on several widely used datasets and observe a significant improvement over the naive baseline of training a model normally on the aggregate dataset from all domains. We also compare our method against other state-of-the-art models and show that our method achieves competitive results.
Our contribution in this work is threefold:
• We shed light on the domain generalization problem by providing several theoretical observations: a necessary and sufficient condition for the existence of a domain-invariant representation and a connection between a domain-independent representation and a marginally-aligned representation.
• We propose a theoretically grounded method for learning a domain-invariant representation based on domain density transformation functions. We also demonstrate that we can learn the domain transformation functions by GANs in order to implement our approach in practice.
• We show the effectiveness of our method by performing experiments on widely used domain generalization datasets (e.g., Rotated MNIST, PACS and OfficeHome) and compare with relevant baselines (especially DGER (Zhao et al., 2020), a main baseline that also aims to learn domain-invariant representations).
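The two-circle example of Figure 1 can be simulated in a few lines. The following is a minimal sketch, not the authors' code: the labeling rule (upper half of the circle is class 1, and the rule is flipped in domain 2) is an assumption chosen to reproduce the 0% accuracy case described above, and all function names are hypothetical.

```python
import math
import random


def sample_domain(radius, flip_labels, n, rng):
    """Sample n points uniformly on a circle of the given radius.

    Assumed labeling rule (for illustration only): y = 1 on the upper
    half of the circle, y = 0 on the lower half; flip_labels inverts it.
    """
    data = []
    for _ in range(n):
        angle = rng.uniform(0.0, 2.0 * math.pi)
        x = (radius * math.cos(angle), radius * math.sin(angle))
        y = 1 if x[1] >= 0 else 0
        if flip_labels:
            y = 1 - y
        data.append((x, y))
    return data


def represent(x):
    """The representation z = x / ||x|| from the example."""
    norm = math.hypot(x[0], x[1])
    return (x[0] / norm, x[1] / norm)


def classify(z):
    """A classifier h(z) that is perfect on domain 1."""
    return 1 if z[1] >= 0 else 0


def accuracy(data):
    """Fraction of samples whose label is predicted correctly by h(g(x))."""
    correct = sum(classify(represent(x)) == y for x, y in data)
    return correct / len(data)


rng = random.Random(0)
domain1 = sample_domain(2.0, False, 1000, rng)  # radius 2, original labels
domain2 = sample_domain(3.0, True, 1000, rng)   # radius 3, flipped labels
acc1 = accuracy(domain1)
acc2 = accuracy(domain2)
```

Under these assumptions every representation lies on the unit circle in both domains, so the marginal of z is aligned, yet the classifier that fits domain 1 is wrong on every point of domain 2 because p(y|z) differs between the two domains.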
2. Related Work
Domain Generalization:
Domain generalization is an important task in real-world machine learning problems, since the data distribution of a target domain might vary from that of the source domains on which a model is trained. Therefore, extensive research has focused on learning a model that generalizes well to unseen target domains. While the literature is vast, here we cover the most important works that are related to ours. A predominant approach for domain generalization is domain invariance (Muandet et al., 2013; Li et al., 2018b;c; Arjovsky et al., 2019; Wang et al., 2020; Akuzawa et al., 2019; Ilse et al., 2020; Zhao et al., 2020). Our method falls into this category since we propose a method that learns a domain-invariant representation (which we define as one that aligns both the marginal distribution of the representation and the conditional distribution of the output given the representation). We consider DGER (Zhao et al., 2020), which also learns a representation that aligns both the marginal and conditional distribution via an adversarial loss and an entropy regularizer, to be one of our main baselines. It should be noted that Zhao et al. (2020) assume the label is distributed uniformly on all domains, which is stronger than our assumption that the distribution of the label is stable across domains (and not necessarily uniform). We also show later in our paper that the invariance of the distribution of the class label across domains is indeed the necessary and sufficient condition for the existence of a domain-invariant representation.
We provide a unified theoretical discussion about the two alignments and a method to learn a representation that aligns both the marginal and conditional distributions via domain density transformation functions for the domain generalization problem.
Another line of methods that received a recent surge in interest is applying the idea of meta-learning to domain generalization problems (Du et al., 2020; Balaji et al., 2018; Li et al., 2018a; Behl et al., 2019). The core idea behind these works is that if we train a model that can adapt among source domains well, it would be more likely to adapt to unseen target domains.
[Figure 2: graphical model with nodes $d$, $y$, $x$, $z$.]
Figure 2. Graphical model. Each domain $d$ defines a data distribution $p(x, y|d)$. We want to learn a representation $z$ with a mapping from $x$ so that $p(z|x)$ can be generalized between domains.
Finally, there are approaches (Ding & Fu, 2017; Chattopadhyay et al., 2020; Seo et al., 2019) that make use of domain specificity, together with domain invariance, for the prediction problem. The argument here is that domain invariance, while generalizing well between domains, might be insufficient for the prediction of each specific domain, and thus domain specificity is necessary.
We would like to emphasize that our method is not a direct competitor of meta-learning based and domain-specificity based methods. In fact, we expect that our method can be used in conjunction with these methods to get the best of both worlds for better performance.
Density transformation between domains:
Since our method is based on domain density transformations, we briefly review some related work here. To transform the data density between domains, one can use several types of generative models. Two common methods are based on GANs (Zhu et al., 2017; Choi et al., 2018; 2020) and normalizing flows (Grover et al., 2020). Although our method is not limited to the choice of the generative model used for learning the domain transformation functions, we opt to use a GAN, specifically StarGAN (Choi et al., 2018), for its rich network capacity. This is just an implementation choice to demonstrate the use and effectiveness of our method in practice, and it is unrelated to our theoretical results.
Connection to contrastive learning:
Our method can be interpreted intuitively as a way to learn a representation network that is invariant (robust) under domain transformation functions. On the other hand, contrastive learning (Chen et al., 2020a;b; Misra & Maaten, 2020) is also a representation learning paradigm, in which the model learns similarity between images. In particular, contrastive learning encourages the representation of an input to be similar under different transformations (usually image augmentations). However, the transformations in contrastive learning are not learned and do not serve the purpose of making the representation robust under domain transformations. Our method first learns the transformations between domains and then uses them to learn a representation that is invariant under domain shifts.
3. Theoretical Approach
Let us define the data distribution for a domain $d \in D$ by $p(x, y|d)$, where the variable $x \in X$ represents the data and $y \in Y$ is the corresponding label. The graphical model for our domain generalization framework is depicted in Figure 2, in which the joint distribution is presented as follows:

$$p(d, x, y, z) = p(d)\, p(y)\, p(x|y, d)\, p(z|x). \tag{1}$$

In the domain generalization problem, the data distribution $p(x, y|d)$ varies between domains, thus we expect changes in the marginal data distribution $p(x|d)$ or the conditional data distribution $p(y|x, d)$ or both. In this paper, we assume that $p(y|d)$ is invariant across domains, i.e., $y$ is not dependent on $d$. This assumption is shown to be the key condition for the existence of a domain-invariant representation (see Theorem 1). This is practically reasonable since in many classification datasets, the class distribution can be assumed to be unchanged across domains (usually a uniform distribution among the classes, e.g., balanced datasets).
Our aim is to find a domain-invariant representation $z$, represented by the mapping $p(z|x)$, that can be used for the classification of label $y$ and be generalized among domains. In practice, this mapping can be deterministic (in that case, $p(z|x) = \delta_{g_\theta(x)}(z)$ with some function $g_\theta$, where $\delta$ is the Dirac delta distribution) or probabilistic (e.g., a normal distribution with the mean and standard deviation outputted by a network parameterized by $\theta$). For all of our experiments, we use a deterministic mapping for efficient inference at test time, while in this section, we present our theoretical results for the general case of a distribution $p(z|x)$.
In most existing domain generalization approaches, the domain-invariant representation $z$ is defined using one of the two following definitions:
Definition 1. (Marginal Distribution Alignment)
The representation $z$ is said to satisfy the marginal distribution alignment condition if $p(z|d)$ is invariant w.r.t. $d$.
Definition 2. (Conditional Distribution Alignment)
The representation $z$ is said to satisfy the conditional distribution alignment condition if $p(y|z, d)$ is invariant w.r.t. $d$.
However, when the data distribution varies between domains, it is crucial to align both the marginal and the conditional distribution of the representation $z$. To this end, this paper aims to learn a representation $z$ that satisfies both the marginal and conditional alignment conditions. We justify our assumption of independence between $y$ and $d$ (thus $p(y|d) = p(y)$) by the following theorem, which shows that this assumption turns out to be the necessary and sufficient condition for learning a domain-invariant representation.
Theorem 1.
The invariance of $p(y|d)$ across domains is the necessary and sufficient condition for the existence of a domain-invariant representation (that aligns both the marginal and conditional distribution).
Proof.
i) If there exists a representation $z$ defined by the mapping $p(z|x)$ that aligns both the marginal and conditional distribution, then $\forall d, d', y$ we have:

$$p(y, z|d) = p(z|d)\, p(y|z, d) = p(z|d')\, p(y|z, d') = p(y, z|d'). \tag{2}$$

By marginalizing both sides of Eq 2 over $z$, we get $p(y|d) = p(y|d')$.
ii) If $p(y|d)$ is unchanged w.r.t. the domain $d$, then we can always find a domain-invariant representation, for example, $p(z|x) = \delta_0(z)$ for the deterministic case (that maps all $x$ to 0), or $p(z|x) = \mathcal{N}(z; 0, 1)$ for the probabilistic case.
These representations are trivial and not of interest to us since they are uninformative of the input $x$. However, the reader can verify that they do align both the marginal and conditional distribution of data.
It is also worth noting that methods which learn a domain-independent representation, for example, Ilse et al. (2020), only align the marginal distribution. This comes directly from the following remark:
Remark 1.
A representation $z$ satisfies the marginal distribution alignment condition if and only if $I(z, d) = 0$, where $I(z, d)$ is the mutual information between $z$ and $d$.
Proof.
• If $I(z, d) = 0$, then $p(z|d) = p(z)$, which means $p(z|d)$ is invariant w.r.t. $d$.
• If $p(z|d)$ is invariant w.r.t. $d$, then $\forall z, d$:

$$\begin{aligned}
p(z) &= \int p(z|d')\, p(d')\, \mathrm{d}d' \\
&= \int p(z|d)\, p(d')\, \mathrm{d}d' \quad (\text{since } p(z|d') = p(z|d)\ \forall d') \\
&= p(z|d) \int p(d')\, \mathrm{d}d' = p(z|d) \\
&\Longrightarrow I(z, d) = 0
\end{aligned} \tag{3}$$

The question remains of how we can learn a non-trivial domain-invariant representation that satisfies both of the distribution alignment conditions. This will be discussed in the following subsection.
To present our method, we will make some assumptions about the data distribution. Specifically, for any two domains $d, d'$, we assume that there exists an invertible and differentiable function denoted by $f_{d,d'}$ that transforms the density $p(x|y, d)$ to $p(x'|y, d')$ $\forall y$. Let $f_{d,d'}$ be the inverse of $f_{d',d}$, i.e., $f_{d',d} := (f_{d,d'})^{-1}$.
Due to the invertibility and differentiability of the $f$'s, we can apply the change of variables theorem (Rudin, 2006; Bogachev, 2007). In particular, with $x' = f_{d,d'}(x)$ (and thus $x = f_{d',d}(x')$), we have

$$p(x|y, d) = p(x'|y, d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1} \tag{4}$$

where $J_{f_{d',d}}(x')$ is the Jacobian matrix of the function $f_{d',d}$ evaluated at $x'$.
Multiplying both sides of Eq 4 with $p(y|d) = p(y|d')$, we get

$$p(x, y|d) = p(x', y|d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1} \tag{5}$$

and marginalizing both sides of the above equation over $y$ gives us

$$p(x|d) = p(x'|d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1} \tag{6}$$

By using Eq 4 and Eq 6, we can prove the following theorem, which offers a way to learn a domain-invariant representation, given the transformation functions $f$'s between domains.
Theorem 2.
Given an invertible and differentiable function $f_{d,d'}$ (with the inverse $f_{d',d}$) that transforms the data density from domain $d$ to $d'$ (as described above). Assuming that the representation $z$ satisfies:

$$p(z|x) = p(z|f_{d,d'}(x)), \quad \forall x \tag{7}$$

Then it aligns both the marginal and the conditional of the data distribution for domain $d$ and $d'$.
Proof.
i) Marginal alignment: $\forall z$ we have:

$$\begin{aligned}
p(z|d) &= \int p(x|d)\, p(z|x)\, \mathrm{d}x \\
&= \int p(f_{d',d}(x')|d)\, p(z|f_{d',d}(x'))\, \left|\det J_{f_{d',d}}(x')\right| \mathrm{d}x' \\
&= \int p(x'|d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1} p(z|x')\, \left|\det J_{f_{d',d}}(x')\right| \mathrm{d}x' \\
&= \int p(x'|d')\, p(z|x')\, \mathrm{d}x' = p(z|d')
\end{aligned} \tag{8}$$

The second line applies variable substitution in the multiple integral, $x' = f_{d,d'}(x)$. The third line holds since $p(f_{d',d}(x')|d) = p(x'|d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1}$ due to Eq 6 and $p(z|f_{d',d}(x')) = p(z|x')$ due to the definition of $z$ in Eq 7.
ii) Conditional alignment: $\forall z, y$ we have:

$$\begin{aligned}
p(z|y, d) &= \int p(x|y, d)\, p(z|x)\, \mathrm{d}x \\
&= \int p(f_{d',d}(x')|y, d)\, p(z|f_{d',d}(x'))\, \left|\det J_{f_{d',d}}(x')\right| \mathrm{d}x' \\
&= \int p(x'|y, d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1} p(z|x')\, \left|\det J_{f_{d',d}}(x')\right| \mathrm{d}x' \\
&= \int p(x'|y, d')\, p(z|x')\, \mathrm{d}x' = p(z|y, d')
\end{aligned} \tag{9}$$

The second line again applies the variable substitution $x' = f_{d,d'}(x)$. The third line holds since $p(f_{d',d}(x')|y, d) = p(x'|y, d')\, \left|\det J_{f_{d',d}}(x')\right|^{-1}$ due to Eq 4 and $p(z|f_{d',d}(x')) = p(z|x')$ due to the definition of $z$ in Eq 7.
Note that

$$p(y|z, d) = \frac{p(y, z|d)}{p(z|d)} = \frac{p(y|d)\, p(z|y, d)}{p(z|d)} \tag{10}$$

Since $p(y|d) = p(y) = p(y|d')$, $p(z|y, d) = p(z|y, d')$ and $p(z|d) = p(z|d')$, we have:

$$p(y|z, d) = \frac{p(y|d')\, p(z|y, d')}{p(z|d')} = p(y|z, d') \tag{11}$$

[Figure 3: $f_{1,2}$ transforms the data density from domain 1 to domain 2 (with the inverse $f_{2,1}$), $x_2 = f_{1,2}(x_1)$; both $g_\theta(x_1)$ and $g_\theta(x_2)$ map to the same $z$.]
Figure 3. Domain density transformation. If we know the function $f_{1,2}$ that transforms the data density from domain 1 to domain 2, we can learn a domain-invariant representation network $g_\theta(x)$ by enforcing it to be invariant under $f_{1,2}$, i.e., $g_\theta(x_1) = g_\theta(x_2)$ for any $x_2 = f_{1,2}(x_1)$.
This theorem indicates that, if we can find the functions $f$'s that transform the data densities among the domains, we can learn a domain-invariant representation $z$ by encouraging the representation to be invariant under all the transformations $f$'s. This idea is illustrated in Figure 3. We can therefore use the following learning objective to learn a domain-invariant representation $z = g_\theta(x)$:

$$\mathbb{E}_{d}\Big[\mathbb{E}_{p(x,y|d)}\big[\, l(y, g_\theta(x)) + \mathbb{E}_{d'}\big[\|g_\theta(x) - g_\theta(f_{d,d'}(x))\|\big]\big]\Big] \tag{12}$$

where $l(y, g_\theta(x))$ is the prediction loss of a network that predicts $y$ given $z = g_\theta(x)$, and the second term enforces the invariance condition in Eq 7.
Assume that we have a set of $K$ source domains $D_s = \{d_1, d_2, ..., d_K\}$; the objective function in Eq 12 becomes:

$$\mathbb{E}_{d, d' \in D_s,\ p(x,y|d)}\big[\, l(y, g_\theta(x)) + \|g_\theta(x) - g_\theta(f_{d,d'}(x))\|\big] \tag{13}$$

In the next section, we show how one can incorporate this idea into real-world domain generalization problems with generative adversarial networks.
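The objective in Eq 13 can be illustrated on a toy problem. The sketch below is only an illustration under stated assumptions: the domain transformation is taken to be a known 2D rotation (a hypothetical stand-in for a learned $f_{d,d'}$), the representation is scalar-valued so the norm reduces to an absolute value, and the prediction loss is a squared error applied directly to the representation. None of these names or choices come from the paper.

```python
import math


def rotate(x, degrees):
    """Hypothetical domain transformation f_{d,d'}: a 2D rotation."""
    t = math.radians(degrees)
    c, s = math.cos(t), math.sin(t)
    return (c * x[0] - s * x[1], s * x[0] + c * x[1])


def invariance_penalty(g, f, xs):
    """Monte Carlo estimate of E[ |g(x) - g(f(x))| ] for a scalar-valued g."""
    return sum(abs(g(x) - g(f(x))) for x in xs) / len(xs)


def objective(g, pred_loss, f, batch):
    """Estimate of Eq 13 for one domain pair: prediction loss plus penalty."""
    prediction = sum(pred_loss(y, g(x)) for x, y in batch) / len(batch)
    xs = [x for x, _ in batch]
    return prediction + invariance_penalty(g, f, xs)


def quarter_turn(x):
    """The assumed transformation between the two toy domains."""
    return rotate(x, 90.0)


def radius(x):
    """A representation that is invariant under rotations."""
    return math.hypot(x[0], x[1])


def first_coord(x):
    """A representation that is not invariant under rotations."""
    return x[0]


def squared_error(y, z):
    """Assumed prediction loss l(y, z), applied directly to z."""
    return (y - z) ** 2


# Toy batch: inputs x with labels y equal to the distance from the origin.
batch = [((1.0, 0.0), 1.0), ((0.0, 2.0), 2.0), ((3.0, 4.0), 5.0)]
points = [x for x, _ in batch]

penalty_radius = invariance_penalty(radius, quarter_turn, points)
penalty_first = invariance_penalty(first_coord, quarter_turn, points)
objective_radius = objective(radius, squared_error, quarter_turn, batch)
objective_first = objective(first_coord, squared_error, quarter_turn, batch)
```

In this toy setting the rotation-invariant representation incurs (numerically) zero penalty, while the first-coordinate representation is penalized, which is the behavior the second term of Eq 13 is meant to encourage.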
4. Domain Generalization with Generative Adversarial Networks
In practice, we will learn the functions $f$'s that transform the data distributions between domains, and one can use several generative modeling frameworks, e.g., normalizing flows (Grover et al., 2020) or GANs (Zhu et al., 2017; Choi et al., 2018; 2020), to learn such functions. One advantage of normalizing flows is that the transformation is naturally invertible by design of the neural network. In addition, the determinant of the Jacobian of that transformation can be efficiently computed. However, since we do not need access to the Jacobian once the training process of the generative model is completed, we propose the use of GANs to inherit their rich network capacity. In particular, we use the StarGAN (Choi et al., 2018) model, which is designed for image domain transformations.
The goal of StarGAN is to learn a unified network $G$ that transforms the data density among multiple domains. In particular, the network $G(x, d, d')$ (i.e., $G$ is conditioned on the image $x$ and the two different domains $d, d'$) transforms an image $x$ from domain $d$ to domain $d'$. Different from the original StarGAN model, which only takes the image $x$ and the desired destination domain $d'$ as its input, in our implementation we feed both the original domain $d$ and the desired destination domain $d'$ together with the original image $x$ to the generator $G$.
The generator's goal is to fool a discriminator $D$ into thinking that the transformed image belongs to the destination domain $d'$. In other words, the equilibrium state of StarGAN, in which $G$ completely fools $D$, is when $G$ successfully transforms the data density of the original domain to that of the destination domain. After training, we use $G(\cdot, d, d')$ as the function $f_{d,d'}(\cdot)$ described in the previous section and perform the representation learning via the objective function in Eq 13.
Three important loss functions of the StarGAN architecture are:
• Domain classification loss $L_{cls}$, which encourages the generator $G$ to generate images that correctly belong to the desired destination domain $d'$.
• The adversarial loss $L_{adv}$, which is the classification loss of a discriminator $D$ that tries to distinguish between real images and the fake images generated by $G$. The equilibrium state of StarGAN is when $G$ completely fools $D$, which means the distribution of the generated images (via $G(x, d, d')$, $x \sim p(x|d)$) becomes the distribution of the real images of the destination domain $p(x'|d')$. This is our objective, i.e., to learn a function that transforms domains' densities.
• Reconstruction loss $L_{rec} = \mathbb{E}_{x,d,d'}[\|x - G(x', d', d)\|]$, where $x' = G(x, d, d')$, to ensure that the transformations preserve the image's content. Note that this also aligns with our interest since we want $G(\cdot, d', d)$ to be the inverse of $G(\cdot, d, d')$, which will minimize $L_{rec}$ to zero.
We can enforce the generator $G$ to transform the data distribution within the class $y$ (e.g., $p(x|y, d)$ to $p(x'|y, d')$ $\forall y$) by sampling each minibatch with data from the same class $y$, so that the discriminator will distinguish the transformed images from the real images of class $y$ and domain $d'$. However, we found that this constraint can be relaxed in practice, and the generator almost always transforms the image within the original class $y$.
As mentioned earlier, after training the StarGAN model, we can use the generator $G(\cdot, d, d')$ as our $f_{d,d'}(\cdot)$ function and learn a domain-invariant representation via the learning objective in Eq 13. We name this implementation of our method DIR-GAN (domain-invariant representation learning with generative adversarial networks).
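The class-wise minibatch sampling mentioned above (each minibatch drawn from a single class $y$) can be sketched as follows. This is a hypothetical helper written for illustration, not part of the StarGAN code base; the `(x, y, d)` triple layout and the function name are assumptions.

```python
import random
from collections import defaultdict


def class_wise_minibatches(samples, batch_size, rng):
    """Return a shuffled list of minibatches that each contain one class only.

    samples: list of (x, y, d) triples (input, class label, domain label).
    """
    by_class = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    batches = []
    for items in by_class.values():
        rng.shuffle(items)  # mix the domains within each class
        for start in range(0, len(items), batch_size):
            batches.append(items[start:start + batch_size])

    rng.shuffle(batches)  # interleave the classes across training steps
    return batches


# Toy data: 3 classes, 2 domains, 5 samples per (class, domain) pair.
toy_samples = [
    (f"img_{y}_{d}_{i}", y, d)
    for y in range(3)
    for d in ("source_a", "source_b")
    for i in range(5)
]
toy_batches = class_wise_minibatches(toy_samples, batch_size=4, rng=random.Random(0))
```

Because every batch shares one label, a discriminator trained on such batches compares transformed images only with real images of the same class. As noted above, the authors report that this constraint can be relaxed in practice.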
5. Experiments
To evaluate our method, we perform experiments on three datasets that are commonly used in the literature for domain generalization.
Rotated MNIST.
In this dataset by Ghifary et al. (2015), 1,000 MNIST images (100 per class) (LeCun & Cortes, 2010) are chosen to form the first domain (denoted $M_0$), then rotations of 15°, 30°, 45°, 60° and 75° are applied to create five additional domains, denoted $M_{15}$, $M_{30}$, $M_{45}$, $M_{60}$ and $M_{75}$. The task is classification with ten classes (digits 0 to 9).
PACS (Li et al., 2017) contains 9,991 images from four different domains: art painting, cartoon, photo, sketch. The task is classification with seven classes.
OfficeHome (Venkateswara et al., 2017) has 15,500 images of daily objects from four domains: art, clipart, product and real. There are 65 classes in this classification dataset.
Table 1. Rotated MNIST leave-one-domain-out experiment. Reported numbers are mean accuracy and standard deviation among 5 runs.

Model                      M_0       M_15   M_30   M_45   M_60   M_75   Average
HIR (Wang et al., 2020)    90.34     99.75  99.40  96.17  99.25  91.26  96.03
DIVA (Ilse et al., 2020)   93.5      99.3   99.1   99.2   99.3   93.0   97.2
DGER (Zhao et al., 2020)   90.09     99.24  99.27  99.31  99.45  90.81  96.36
DA (Ganin et al., 2016)    86.7      98.0   97.8   97.4   96.9   89.1   94.3
LG (Shankar et al., 2018)  89.7      97.8   98.0   97.1   96.6   92.1   95.3
HEX (Wang et al., 2019)    90.1      98.9   98.9   98.8   98.3   90.0   95.8
ADV (Wang et al., 2019)    89.9      98.6   98.8   98.7   98.6   90.4   95.2
DIR-GAN (ours)             97.2 (±)  (±)    (±)    (±)    (±)    (±)

Table 2.
PACS leave-one-domain-out experiment. Reported numbers are mean accuracy and standard deviation among 5 runs.

Model                              Backbone  Art Painting  Cartoon  Photo  Sketch  Average
DGER (Zhao et al., 2020)           Resnet18  80.70         76.40    96.65  71.77   81.38
JiGen (Carlucci et al., 2019)      Resnet18  79.42         75.25    96.03  71.35   79.14
MLDG (Li et al., 2018a)            Resnet18  79.50         77.30    94.30  71.50   80.70
MetaReg (Balaji et al., 2018)      Resnet18  83.70         77.20    95.50  70.40   81.70
CSD (Piratla et al., 2020)         Resnet18  78.90         75.80    94.10  76.70   81.40
DMG (Chattopadhyay et al., 2020)   Resnet18  76.90         80.38    93.35  75.21   81.46
DIR-GAN (ours)                     Resnet18  82.56 (±)     (±)      (±)    (±)

Table 3.
OfficeHome leave-one-domain-out experiment. Reported numbers are mean accuracy and standard deviation among 5 runs.

Model                                 Backbone  Art        ClipArt  Product  Real   Average
D-SAM (D'Innocente & Caputo, 2018)    Resnet18  58.03      44.37    69.22    71.45  60.77
JiGen (Carlucci et al., 2019)         Resnet18  53.04      47.51    71.47    72.79  61.20
DIR-GAN (ours)                        Resnet18  56.69 (±)  (±)      (±)      (±)

For all datasets, we perform "leave-one-domain-out" experiments, where we choose one domain as the target domain, train the model on all remaining domains and evaluate it on the chosen domain. Following standard practice, we use 90% of available data as training data and 10% as validation data, except for the Rotated MNIST experiment, where we do not use a validation set and just report the performance of the last epoch.
For the
Rotated MNIST dataset, we use a network of two 3x3 convolutional layers and a fully connected layer as the representation network $g_\theta$ to get a representation $z$ of 64 dimensions. A single linear layer is then used to map the representation $z$ to the ten output classes. This architecture is the deterministic version of the network used by Ilse et al. (2020). We train our network for 500 epochs with the Adam optimizer (Kingma & Ba, 2014), using the learning rate 0.001 and minibatch size 64, and report performance on the test domain after the last epoch.
For the PACS and
OfficeHome datasets, we use a Resnet18 (He et al., 2016) network as the representation network $g_\theta$. As a standard practice, the Resnet18 backbone is pre-trained on ImageNet. We replace the last fully connected layer of the Resnet with a linear layer of dimensions (512, 256) so that our representation has 256 dimensions. As with the Rotated MNIST experiment, we use a single layer to map from the representation $z$ to the output. We train the network for 100 epochs with plain stochastic gradient descent (SGD) using learning rate 0.001, momentum 0.9, minibatch size 64, and weight decay 0.001. Data augmentation is also standard practice for real-world computer vision datasets like PACS and OfficeHome, and during training we augment our data as follows: crops of random size and aspect ratio, resizing to 224 × 224 pixels, random horizontal flips, random color jitter, randomly converting the image tile to grayscale with 10% probability, and normalization using the ImageNet channel means and standard deviations.
Figure 4.
Visualization of the representation space. Each point indicates a representation $z$ of an image $x$ in the two-dimensional space and its color indicates the label $y$. The two left panels are for our method DIR-GAN and the two right panels are for the naive model DeepAll.
The StarGAN (Choi et al., 2018) model implementation is taken from the authors' original source code with no significant modifications. For each set of source domains, we train the StarGAN model for 100,000 iterations with a minibatch of 16 images per iteration.
The code for all of our experiments will be released for reproducibility. Please also refer to the source code for any other architecture and implementation details.
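The leave-one-domain-out protocol with a 90%/10% train/validation split described above can be sketched as follows. This is a minimal illustration with hypothetical names; in particular, pooling all source domains before splitting is an assumption, since the text does not say whether the split is made per domain or on the pooled data.

```python
import random


def leave_one_domain_out(data_by_domain, val_fraction=0.1, seed=0):
    """Yield (target, train, val, test) for each choice of target domain.

    data_by_domain: dict mapping a domain name to a list of samples.
    The source domains are pooled and then split (an assumption).
    """
    for target in data_by_domain:
        rng = random.Random(seed)
        source = [
            sample
            for domain, items in data_by_domain.items()
            if domain != target
            for sample in items
        ]
        rng.shuffle(source)
        n_val = int(round(val_fraction * len(source)))
        val, train = source[:n_val], source[n_val:]
        yield target, train, val, list(data_by_domain[target])


# Toy stand-in for PACS: four domains with ten dummy samples each.
toy_pacs = {
    "art_painting": list(range(0, 10)),
    "cartoon": list(range(10, 20)),
    "photo": list(range(20, 30)),
    "sketch": list(range(30, 40)),
}
splits = list(leave_one_domain_out(toy_pacs))
```

Each split holds out one domain entirely for testing, so no sample of the target domain is seen during training or model selection.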
Table 1 shows the performance of our model on the Rotated MNIST dataset. The main baselines we consider in this experiment are HIR (Wang et al., 2020), DIVA (Ilse et al., 2020) and DGER (Zhao et al., 2020), which are domain-invariance-based methods. Our method clearly outperforms these baselines, illustrating its effectiveness at learning a domain-invariant representation compared with existing work. We also include other best-performing models for this dataset in the second half of the table. To the best of our knowledge, we set a new state-of-the-art performance on the Rotated MNIST dataset.

We further analyze the distribution of the representation z by performing principal component analysis to reduce the dimension of z from 64 to two principal components. We visualize the representation space for two domains, with each point indicating the representation z of an image x in the two-dimensional space and its color indicating the label y. Figures 4a and 4b show the representation space of our method (one panel per domain). Both the marginal distribution (judged by the overall distribution of the points) and the conditional distribution (judged by the positions of the colors) are relatively well aligned across the two domains. Meanwhile, Figures 4c and 4d show the representation space with naive training (in the same two domains), revealing misalignment in both the marginal distribution (judged by the overall distribution of the points) and the conditional distribution (for example, the distributions of blue points and green points).

PACS and OfficeHome. To the best of our knowledge, domain-invariant representation learning methods have not been applied widely and successfully to real-world computer vision datasets (e.g., PACS and OfficeHome) with very deep neural networks such as Resnet, so the only relevant baseline to ours is DGER (Zhao et al., 2020) for the PACS experiment. Therefore, we include more baselines from other approaches (e.g., meta-learning-based or domain-specificity-based methods) for comparison. Tables 2 and 3 show that DIR-GAN outperforms DGER significantly and achieves competitive performance compared to other state-of-the-art baselines.
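The PCA projection used for the Figure 4 visualization (reducing z to two principal components) can be sketched as below. This is a minimal, standard-library-only illustration, not the authors' code: it uses power iteration with deflation instead of a full eigendecomposition, the function names are hypothetical, and a small 4-dimensional toy grid stands in for the 64-dimensional representations.

```python
import math
import random


def center(rows):
    """Subtract the per-feature mean from every row."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    return [[r[j] - means[j] for j in range(d)] for r in rows]


def covariance(centered):
    """Sample covariance matrix of already-centered rows."""
    n, d = len(centered), len(centered[0])
    return [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
            for i in range(d)]


def matvec(mat, v):
    """Matrix-vector product for a list-of-lists matrix."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in mat]


def top_eigenpair(mat, iters=500, seed=0):
    """Power iteration: dominant eigenvector and eigenvalue of a symmetric PSD matrix."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(mat))]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    for _ in range(iters):
        w = matvec(mat, v)
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:  # matrix annihilates v; keep the current unit vector
            break
        v = [x / norm for x in w]
    eigval = sum(a * b for a, b in zip(v, matvec(mat, v)))  # Rayleigh quotient
    return v, eigval


def pca_2d(rows):
    """Project rows onto their first two principal components."""
    centered = center(rows)
    cov = covariance(centered)
    d = len(cov)
    v1, lam1 = top_eigenpair(cov, seed=0)
    # Deflation: remove the first component so power iteration finds the second.
    deflated = [[cov[i][j] - lam1 * v1[i] * v1[j] for j in range(d)]
                for i in range(d)]
    v2, _ = top_eigenpair(deflated, seed=1)
    coords = [(sum(a * b for a, b in zip(r, v1)),
               sum(a * b for a, b in zip(r, v2))) for r in centered]
    return coords, (v1, v2)


# Toy stand-in for representations: variance is largest along feature 0,
# smaller along feature 1, and zero along features 2 and 3.
rows = [[3.0 * a, 1.0 * b, 0.0, 0.0] for a in range(-3, 4) for b in range(-2, 3)]
coords, (v1, v2) = pca_2d(rows)
```

To reproduce the kind of comparison described above, one would presumably fit the projection once and apply it to the representations of both domains, then scatter-plot `coords` colored by label. Fitting a shared projection is an assumption here, since the text does not say how the projection was fitted. The sign of each component is arbitrary, so plots from different runs may be mirrored.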
6. Conclusion
To conclude, in this work we propose a theoretically grounded approach to learn a domain-invariant representation for the domain generalization problem by using domain transformation functions. We also provide some insights into domain-invariant representation learning with several theoretical observations. We then introduce a practical implementation of our method, with the domain transformations learned by a StarGAN architecture, and empirically show that our approach outperforms other domain-invariance-based methods. Our method also achieves competitive results on several datasets when compared to other state-of-the-art models. In the future, we plan to incorporate our method into meta-learning-based and domain-specificity-based approaches for improved performance. We also plan to extend the domain-invariant representation learning framework to more challenging scenarios, for example, where domain information is not available (i.e., we have a dataset pooled from multiple source domains but do not know the domain identity of each data instance).
References
Akuzawa, K., Iwasawa, Y., and Matsuo, Y. Adversarial invariant feature learning with accuracy constraint for domain generalization. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases, pp. 315–331. Springer, 2019.
Arjovsky, M., Bottou, L., Gulrajani, I., and Lopez-Paz, D. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
Balaji, Y., Sankaranarayanan, S., and Chellappa, R. Metareg: Towards domain generalization using meta-regularization. Advances in Neural Information Processing Systems, 31:998–1008, 2018.
Behl, H., Baydin, A. G., and Torr, P. H. Alpha maml: Adaptive model-agnostic meta-learning. In , 2019.
Bogachev, V. I. Measure theory, volume 1. Springer Science & Business Media, 2007.
Carlucci, F. M., D’Innocente, A., Bucci, S., Caputo, B., and Tommasi, T. Domain generalization by solving jigsaw puzzles. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2229–2238, 2019.
Chattopadhyay, P., Balaji, Y., and Hoffman, J. Learning to balance specificity and invariance for in and out of domain generalization. In European Conference on Computer Vision, pp. 301–318. Springer, 2020.
Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pp. 1597–1607. PMLR, 2020a.
Chen, T., Kornblith, S., Swersky, K., Norouzi, M., and Hinton, G. Big self-supervised models are strong semi-supervised learners. arXiv preprint arXiv:2006.10029, 2020b.
Choi, Y., Choi, M., Kim, M., Ha, J.-W., Kim, S., and Choo, J. Stargan: Unified generative adversarial networks for multi-domain image-to-image translation. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 8789–8797, 2018.
Choi, Y., Uh, Y., Yoo, J., and Ha, J.-W. Stargan v2: Diverse image synthesis for multiple domains. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8188–8197, 2020.
Combes, R. T. d., Zhao, H., Wang, Y.-X., and Gordon, G. Domain adaptation with conditional distribution matching and generalized label shift. arXiv preprint arXiv:2003.04475, 2020.
Ding, Z. and Fu, Y. Deep domain generalization with structured low-rank constraint. IEEE Transactions on Image Processing, 27(1):304–313, 2017.
Du, Y., Xu, J., Xiong, H., Qiu, Q., Zhen, X., Snoek, C. G., and Shao, L. Learning to learn with variational information bottleneck for domain generalization. In European Conference on Computer Vision, pp. 200–216. Springer, 2020.
D’Innocente, A. and Caputo, B. Domain generalization with domain-specific aggregation modules. In German Conference on Pattern Recognition, pp. 187–198. Springer, 2018.
Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17(1):2096–2030, 2016.
Ghifary, M., Kleijn, W. B., Zhang, M., and Balduzzi, D. Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2551–2559, 2015.
Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial networks. arXiv preprint arXiv:1406.2661, 2014.
Grover, A., Chute, C., Shu, R., Cao, Z., and Ermon, S. Alignflow: Cycle consistent learning from multiple domains via normalizing flows. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 4028–4035, 2020.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Ilse, M., Tomczak, J. M., Louizos, C., and Welling, M. Diva: Domain invariant variational autoencoders. In Medical Imaging with Deep Learning, pp. 322–348. PMLR, 2020.
Khosla, A., Zhou, T., Malisiewicz, T., Efros, A. A., and Torralba, A. Undoing the damage of dataset bias. In European Conference on Computer Vision, pp. 158–171. Springer, 2012.
Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
LeCun, Y. and Cortes, C. MNIST handwritten digit database. 2010. URL http://yann.lecun.com/exdb/mnist/.
Li, D., Yang, Y., Song, Y.-Z., and Hospedales, T. Deeper, broader and artier domain generalization. In International Conference on Computer Vision, 2017.
Li, D., Yang, Y., Song, Y.-Z., and Hospedales, T. Learning to generalize: Meta-learning for domain generalization. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 32, 2018a.
Li, Y., Gong, M., Tian, X., Liu, T., and Tao, D. Domain generalization via conditional invariant representations. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 32, 2018b.
Li, Y., Tian, X., Gong, M., Liu, Y., Liu, T., Zhang, K., and Tao, D. Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 624–639, 2018c.
Misra, I. and Maaten, L. v. d. Self-supervised learning of pretext-invariant representations. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6707–6717, 2020.
Muandet, K., Balduzzi, D., and Schölkopf, B. Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18. PMLR, 2013.
Piratla, V., Netrapalli, P., and Sarawagi, S. Efficient domain generalization via common-specific low-rank decomposition. In International Conference on Machine Learning, pp. 7728–7738. PMLR, 2020.
Rudin, W. Real and complex analysis. Tata McGraw-hill education, 2006.
Seo, S., Suh, Y., Kim, D., Han, J., and Han, B. Learning to optimize domain specific normalization for domain generalization. arXiv preprint arXiv:1907.04275, 3(6):7, 2019.
Shankar, S., Piratla, V., Chakrabarti, S., Chaudhuri, S., Jyothi, P., and Sarawagi, S. Generalizing across domains via cross-gradient training. arXiv preprint arXiv:1804.10745, 2018.
Tanwani, A. K. Domain-invariant representation learning for sim-to-real transfer. arXiv preprint arXiv:2011.07589, 2020.
Venkateswara, H., Eusebio, J., Chakraborty, S., and Panchanathan, S. Deep hashing network for unsupervised domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5018–5027, 2017.
Wang, H., He, Z., Lipton, Z. C., and Xing, E. P. Learning robust representations by projecting superficial statistics out. arXiv preprint arXiv:1903.06256, 2019.
Wang, Z., Loog, M., and van Gemert, J. Respecting domain relations: Hypothesis invariance for domain generalization. arXiv preprint arXiv:2010.07591, 2020.
Zhang, Y., Liu, T., Long, M., and Jordan, M. Bridging theory and algorithm for domain adaptation. In International Conference on Machine Learning, pp. 7404–7413. PMLR, 2019.
Zhao, H., Des Combes, R. T., Zhang, K., and Gordon, G. On learning invariant representations for domain adaptation. In International Conference on Machine Learning, pp. 7523–7532. PMLR, 2019.
Zhao, S., Gong, M., Liu, T., Fu, H., and Tao, D. Domain generalization via entropy regularization. Advances in Neural Information Processing Systems, 33, 2020.
Zhu, J.-Y., Park, T., Isola, P., and Efros, A. A. Unpaired image-to-image translation using cycle-consistent adversarial networks. In