Geometric Dataset Distances via Optimal Transport
David Alvarez-Melis, Nicolò Fusi

Abstract
The notion of task similarity is at the core of various machine learning paradigms, such as domain adaptation and meta-learning. Current methods to quantify it are often heuristic, make strong assumptions on the label sets across the tasks, and many are architecture-dependent, relying on task-specific optimal parameters (e.g., they require training a model on each dataset). In this work we propose an alternative notion of distance between datasets that (i) is model-agnostic, (ii) does not involve training, (iii) can compare datasets even if their label sets are completely disjoint, and (iv) has solid theoretical footing. This distance relies on optimal transport, which provides it with rich geometry awareness, interpretable correspondences and well-understood properties. Our results show that this novel distance provides meaningful comparison of datasets, and correlates well with transfer learning hardness across various experimental settings and datasets.
1. Introduction
A key hallmark of machine learning practice is that labeled data from the application of interest is usually scarce. For this reason, there is vast interest in methods that can combine, adapt and transfer knowledge across datasets and domains. Entire research areas are devoted to these goals, such as domain adaptation, transfer learning and meta-learning. A fundamental concept underlying all these paradigms is the notion of distance (or, more generally, similarity) between datasets. For instance, transferring knowledge across similar domains should intuitively be easier than across distant ones. Likewise, given a choice of various datasets to pretrain a model on, it would seem natural to choose the one that is closest to the task of interest.

Despite its evident usefulness and apparent simplicity, the notion of distance between datasets is an elusive one, and quantifying it efficiently and in a principled manner re-

Microsoft Research. Correspondence to: David Alvarez-Melis
2. Related Work
Discrepancy Distance
Various notions of (dis)similarity between data distributions have been proposed in the context of domain adaptation, such as the $d_\mathcal{A}$ distance (Ben-David et al., 2007) and the discrepancy distance (Mansour et al., 2009). These discrepancies depend on a loss function and a hypothesis (i.e., predictor) class, and quantify dissimilarity through a supremum over this function class. The latter discrepancy in particular has proven remarkably useful for proving generalization bounds for adaptation (Cortes & Mohri, 2011), and while it can be estimated from samples, bounding the approximation quality relies on quantities like the VC-dimension of the hypothesis class, which might not always be known or easy to compute.

Dataset Distance via Parameter Sensitivity
The Fisher information metric is a classic notion from information geometry (Amari, 1985; Amari & Nagaoka, 2000) that characterizes a parametrized probability distribution locally through the sensitivity of its density to changes in the parameters. In machine learning, it has been used to analyze and improve optimization approaches (Amari, 1998) and to measure the capacity of neural networks (Liang et al., 2019). In recent work, Achille et al. (2019) use this notion to construct vector representations of tasks, which they then use to define a notion of similarity between them. They show that this notion recovers taxonomic similarities and is useful in meta-learning to predict whether a certain feature extractor will perform well in a new task. While this notion shares with ours its agnosticism to the number of classes and their semantics, it differs in that it relies on a probe network trained on a specific dataset, so its geometry is heavily influenced by the characteristics of this network. Besides the Fisher information, a related information-theoretic notion of complexity that can be used to characterize tasks is the Kolmogorov Structure Function (Li, 2006), which Achille et al. (2018) use to define a notion of reachability between tasks.
Optimal Transport-based distributional distances
The general idea of representing complex objects via distributions, which are then compared through optimal transport distances, is an active area of research. Also driven by the appeal of their closed-form Wasserstein distance, Muzellec & Cuturi (2018) propose to embed objects as elliptical distributions, which requires differentiating through these distances, and discuss various approximations to scale up these computations. Frogner et al. (2019) extend this idea but represent the embeddings as discrete measures (i.e., point clouds) rather than Gaussian/elliptical distributions. Both of these works focus on embedding and consider only within-dataset comparisons. Also within this line of work, Delon & Desolneux (2019) introduce a Wasserstein-type distance between Gaussian mixture models. Their approach restricts the admissible transportation couplings themselves to be Gaussian mixture models, and does not directly model label-to-label similarity. More generally, the Gromov-Wasserstein distance (Mémoli, 2011) has been proposed to compare collections across different domains (Mémoli, 2017; Alvarez-Melis & Jaakkola, 2018), albeit leveraging only features, not labels.

Footnote: Despite its name, the discrepancy distance of Mansour et al. (2009) is not a distance in general.
Hierarchical OT distances
The distance we propose can be understood as a hierarchical OT distance, i.e., one where the ground metric itself is defined through an OT problem. This principle has been explored in other contexts before. For example, Yurochkin et al. (2019) use a hierarchical OT distance for document similarity, defining an inner-level distance between topics and an outer-level distance between documents using OT. Dukler et al. (2019), on the other hand, use a nested Wasserstein distance as a loss for generative model training, motivated by the observation that the Wasserstein distance is better suited to comparing images than the usual pixel-wise $L^2$ metric used as ground metric. Both the goal and the actual metric used by these approaches differ from ours.

Optimal Transport for Domain Adaptation
Using label information to guide the optimal transport problem towards class-coherent matches has been explored before, e.g., by enforcing group-norm penalties (Courty et al., 2017) or through submodular cost functions (Alvarez-Melis et al., 2018). These works are focused on the unsupervised domain adaptation setting, so their proposed modifications to the OT objective use only label information from one of the two domains, and even then, do so without explicitly defining a metric between these. Furthermore, they do not lead to proper distances, and these works deal with a single static pair of tasks, so they lack analysis of the distance across multiple source and target datasets.
3. Background on Optimal Transport
Optimal transport (OT) is a powerful and principled approach to compare probability distributions, with deep roots in statistics, computer science and applied mathematics (Villani, 2003; 2008). Among many desirable properties, these distances leverage the geometry of the underlying space, making them ideal for comparing distributions, shapes and point clouds (Peyré & Cuturi, 2019).

The OT problem considers a complete and separable metric space $\mathcal{X}$, along with probability measures $\alpha \in \mathcal{P}(\mathcal{X})$ and $\beta \in \mathcal{P}(\mathcal{X})$. These can be continuous or discrete measures, the latter often used in practice as empirical approximations of the former whenever working in the finite-sample regime. The Kantorovich formulation (Kantorovitch, 1942) of the transportation problem reads:

$$\mathrm{OT}(\alpha, \beta) \triangleq \min_{\pi \in \Pi(\alpha, \beta)} \int_{\mathcal{X} \times \mathcal{X}} c(x, y) \, \mathrm{d}\pi(x, y), \tag{1}$$

where $c(\cdot, \cdot) : \mathcal{X} \times \mathcal{X} \to \mathbb{R}_+$ is a cost function (the "ground" cost), and the set of couplings $\Pi(\alpha, \beta)$ consists of joint probability distributions over the product space $\mathcal{X} \times \mathcal{X}$ with marginals $\alpha$ and $\beta$, that is,

$$\Pi(\alpha, \beta) \triangleq \{ \pi \in \mathcal{P}(\mathcal{X} \times \mathcal{X}) \mid P_{1\#}\pi = \alpha, \; P_{2\#}\pi = \beta \}, \tag{2}$$

where $P_{1\#}\pi$ and $P_{2\#}\pi$ denote the marginals of $\pi$ on its first and second coordinates. Whenever $\mathcal{X}$ is equipped with a metric $d_\mathcal{X}$, it is natural to use it as ground cost, e.g., $c(x, y) = d_\mathcal{X}(x, y)^p$ for some $p \geq 1$. In such case, $W_p(\alpha, \beta) \triangleq \mathrm{OT}(\alpha, \beta)^{1/p}$ is called the $p$-Wasserstein distance. The case $p = 1$ is also known as the Earth Mover's Distance (Rubner et al., 2000).

The measures $\alpha$ and $\beta$ are rarely known in practice. Instead, one has access to finite samples $\{x^{(i)}\}_{i=1}^n \subset \mathcal{X}$ and $\{y^{(j)}\}_{j=1}^m \subset \mathcal{X}$. In that case, one can construct discrete measures $\alpha = \sum_{i=1}^n a_i \delta_{x^{(i)}}$ and $\beta = \sum_{j=1}^m b_j \delta_{y^{(j)}}$, where $\mathbf{a}, \mathbf{b}$ are vectors in the probability simplex, and the pairwise costs can be compactly represented as an $n \times m$ matrix $\mathbf{C}$, i.e., $C_{ij} = c(x^{(i)}, y^{(j)})$. In this case, Equation (1) becomes a linear program.
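To make the discrete problem concrete, the following minimal sketch (our own illustration, not the paper's implementation) evaluates Eq. (1) exactly for tiny point clouds. It assumes uniform weights $a_i = b_j = 1/n$, $n = m$, and a Euclidean ground metric; under these assumptions the linear program admits an optimal coupling that is a permutation matrix scaled by $1/n$ (Birkhoff-von Neumann theorem), so enumerating permutations suffices. The function name `wasserstein_uniform_exact` is hypothetical, and the $O(n!)$ enumeration is only meant for checking small cases.

```python
import itertools

import numpy as np


def wasserstein_uniform_exact(X, Y, p=2):
    """Exact W_p between two uniform empirical measures with the same number of points.

    X, Y: arrays of shape (n, d). With uniform weights and n == m, some optimal
    coupling of the linear program in Eq. (1) is a scaled permutation matrix, so
    for tiny n we can enumerate all permutations. This is O(n!): illustration only.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError("this sketch assumes n == m")
    # Ground cost C_ij = d(x_i, y_j)^p with the Euclidean metric.
    C = np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1) ** p
    rows = np.arange(n)
    best = min(C[rows, list(perm)].sum() for perm in itertools.permutations(range(n)))
    # Each matched pair carries mass 1/n; then apply the 1/p power from the W_p definition.
    return (best / n) ** (1.0 / p)


w = wasserstein_uniform_exact([[0.0], [1.0]], [[1.0], [2.0]])
print(w)  # 1.0
```

Here the monotone matching (0 to 1, 1 to 2) costs 1 per point, while the crossed matching costs 4 + 0, so the optimal average cost is 1 and $W_2 = 1$.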
Solving this linear program scales cubically in the sample sizes, which is often prohibitive in practice. Adding an entropy regularization, namely

$$\mathrm{OT}_\varepsilon(\alpha, \beta) \triangleq \min_{\pi \in \Pi(\alpha, \beta)} \int_{\mathcal{X} \times \mathcal{X}} c(x, y) \, \mathrm{d}\pi(x, y) + \varepsilon H(\pi \mid \alpha \otimes \beta), \tag{3}$$

where $H(\pi \mid \alpha \otimes \beta) = \int \log\big(\mathrm{d}\pi / \mathrm{d}\alpha \, \mathrm{d}\beta\big) \, \mathrm{d}\pi$ is the relative entropy, leads to a problem that can be solved much more efficiently (Cuturi, 2013; Altschuler et al., 2017) and with better sample complexity (Genevay et al., 2019) than the original one. The Sinkhorn divergence (Genevay et al., 2018), defined as

$$\mathrm{SD}_\varepsilon(\alpha, \beta) = \mathrm{OT}_\varepsilon(\alpha, \beta) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\alpha, \alpha) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\beta, \beta),$$

has various desirable properties, e.g., it is positive, convex and metrizes the weak$^*$ convergence of distributions (Feydy et al., 2019).

In the discrete case, problem (3) can be solved with the Sinkhorn algorithm (Cuturi, 2013; Peyré & Cuturi, 2019), a matrix-scaling procedure which iteratively updates $\mathbf{u} \leftarrow \mathbf{a} \oslash \mathbf{K}\mathbf{v}$ and $\mathbf{v} \leftarrow \mathbf{b} \oslash \mathbf{K}^\top \mathbf{u}$, where $\mathbf{K} \triangleq \exp(-\mathbf{C}/\varepsilon)$ and the division $\oslash$ and exponential are entry-wise. The regularized coupling is then recovered as $\pi = \mathrm{diag}(\mathbf{u})\,\mathbf{K}\,\mathrm{diag}(\mathbf{v})$.
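The Sinkhorn updates translate almost line by line into NumPy. This is a minimal sketch under our own assumptions (a dense cost matrix, a fixed iteration count instead of a convergence check, and the hypothetical function name `sinkhorn`); it returns the coupling $\pi = \mathrm{diag}(\mathbf{u})\,\mathbf{K}\,\mathrm{diag}(\mathbf{v})$.

```python
import numpy as np


def sinkhorn(a, b, C, eps=0.5, n_iter=500):
    """Entropy-regularized OT, Eq. (3), via Sinkhorn matrix scaling.

    a: (n,) and b: (m,) probability vectors; C: (n, m) cost matrix.
    Returns the coupling pi = diag(u) K diag(v).
    """
    K = np.exp(-C / eps)               # Gibbs kernel K = exp(-C / eps), entry-wise
    u = np.ones_like(a, dtype=float)
    v = np.ones_like(b, dtype=float)
    for _ in range(n_iter):
        u = a / (K @ v)                # u <- a ⊘ K v
        v = b / (K.T @ u)              # v <- b ⊘ K^T u
    return u[:, None] * K * v[None, :]


a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
C = np.array([[0.0, 1.0], [1.0, 0.0]])
pi = sinkhorn(a, b, C)
print(pi.sum(axis=1), pi.sum(axis=0))  # approximately a and b
print((pi * C).sum())                  # transport cost <pi, C>
```

For small $\varepsilon$ the kernel entries $\exp(-C_{ij}/\varepsilon)$ can underflow to zero; practical implementations usually run the same updates in the log domain.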
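[Figure 1 panels: two pairs of datasets $\mathcal{D}_A$, $\mathcal{D}_B$ shown in feature space, each with a bar chart of $d(\mathcal{D}_A, \mathcal{D}_B)$ under OT (features only) and OTDD.]
Figure 1. The importance of labels: the second pair of datasets is much closer than the first under the usual (label-agnostic) OT distance, while the opposite is true for our (label-aware) distance.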
4. Optimal Transport between Datasets
The definition of dataset is notoriously inconsistent across the machine learning literature, sometimes referring only to features and sometimes to both features and labels. In the context of supervised learning, where the ultimate goal is to estimate predictors $f : \mathcal{X} \to \mathcal{Y}$ (or conditional distributions $P(y \mid x)$), we define a dataset $\mathcal{D}$ as a set of feature-label pairs $(x, y) \in \mathcal{X} \times \mathcal{Y}$ over a certain feature space $\mathcal{X}$ and label set $\mathcal{Y}$. For simplicity, we will use $z \triangleq (x, y)$ to denote these pairs, and $\mathcal{Z} \triangleq \mathcal{X} \times \mathcal{Y}$ for their underlying space. Henceforth, we focus on the case of classification, so $\mathcal{Y}$ shall be a finite set. We consider two datasets $\mathcal{D}_A$ and $\mathcal{D}_B$, and assume, for simplicity, that their feature spaces have the same dimensionality, but will discuss how to relax this assumption later on. On the other hand, we make no assumptions on the label sets $\mathcal{Y}_A$ and $\mathcal{Y}_B$ whatsoever. In particular, the classes these encode could be partially overlapping or related (e.g., ImageNet and CIFAR-10) or completely disjoint (e.g., CIFAR-10 and MNIST). Although not a formal assumption of our approach, it will be useful to think of the samples in these two datasets as being drawn from joint distributions $P_A(x, y)$ and $P_B(x, y)$.

Given $\mathcal{D}_A = \{(x_A^{(i)}, y_A^{(i)})\}_{i=1}^n \sim P_A(x, y)$ and $\mathcal{D}_B = \{(x_B^{(j)}, y_B^{(j)})\}_{j=1}^m \sim P_B(x, y)$, our goal is to define a distance $d(\mathcal{D}_A, \mathcal{D}_B)$ that depends exclusively on the information contained in these datasets. The probabilistic interpretation of these collections suggests a simple yet proven approach: comparing these datasets by means of a statistical divergence on their joint distributions. Among many such notions, optimal transport stands out because of various characteristics described in Section 3: its direct use of the geometry of the underlying space, its characterization of distance as correspondence (which will prove to have various useful applications in this context) and the vast theory, spanning three centuries, upon which it is built.

[Figure 2 panels: Label-to-label distance (classes A to E in one dataset, 0 to 4 in the other); Samples and Optimal Coupling $\pi^*$; Optimal Coupling $\pi^*$.]
Figure 2. Our approach represents labels as distributions over features and computes Wasserstein distances between them (left). Combined with the usual metric between features, this yields a transportation cost between datasets. The optimal transport problem then characterizes the distance between them as the minimal possible cost of coupling them (optimal coupling $\pi^*$ shown on the right).

Note, however, that direct application of OT to this setting is challenging. Indeed, problem (1) requires us to define a metric on the ground space, i.e., on $\mathcal{Z} = \mathcal{X} \times \mathcal{Y}$. A straightforward way to do so would be via the individual metrics in $\mathcal{X}$ and $\mathcal{Y}$. Indeed, if $d_\mathcal{X}, d_\mathcal{Y}$ are metrics on $\mathcal{X}$ and $\mathcal{Y}$ respectively, then $d_\mathcal{Z} : \mathcal{Z} \times \mathcal{Z} \to \mathbb{R}_+$, given as

$$d_\mathcal{Z}(z, z') = \big( d_\mathcal{X}(x, x')^p + d_\mathcal{Y}(y, y')^p \big)^{1/p}, \quad p \geq 1,$$

is a metric on $\mathcal{Z}$.

In most applications, $d_\mathcal{X}$ is readily available, e.g., as the Euclidean distance in the feature space. On the other hand, $d_\mathcal{Y}$ will rarely be so, particularly between labels from unrelated label sets (e.g., between cars in one image domain and dogs in the other). If we had some prior knowledge of the label spaces, we could use it to define a notion of distance between pairs of labels. However, in the challenging (but common) case where no such knowledge is available, the only information we have about the labels is their occurrence in relation to the feature vectors $x$. Thus, we can take advantage of the fact that we have a meaningful metric in $\mathcal{X}$ and use it to compare labels. Arguably, the simplest such approach is as follows. Let us define $\mathcal{N}_\mathcal{D}(y) \triangleq \{x \in \mathcal{X} \mid (x, y) \in \mathcal{D}\}$, i.e., $\mathcal{N}_\mathcal{D}(y)$ is the set of feature vectors with label $y$ in dataset $\mathcal{D}$, and let $n_y$ be its cardinality.
With this, a distance between two labels $y$ and $y'$ can be defined as the distance between the centroids of their associated feature vector collections:

$$d(y, y') = d_\mathcal{X}\Big( \frac{1}{n_y} \sum_{x \in \mathcal{N}_\mathcal{D}(y)} x, \; \frac{1}{n_{y'}} \sum_{x \in \mathcal{N}_\mathcal{D}(y')} x \Big). \tag{4}$$

Although appealing for its simplicity, representing the collections $\mathcal{N}_\mathcal{D}(y)$ only through their mean is too simplistic for real datasets. Ideally, we would like to represent labels through the actual distribution over the feature space that they define, namely, by means of the map $y \mapsto \alpha_y(X) \triangleq P(X \mid Y = y)$, of which $\mathcal{N}_\mathcal{D}(y)$ can be understood as a finite sample. If we use this representation, defining a distance between labels boils down to choosing a statistical divergence between their associated distributions. Once more, there are many possible choices for this distance, but, yet again, we argue that OT is an ideal choice, since the notion of divergence we seek should: (i) provide a valid metric, (ii) be computable from finite samples, which is crucial since the distributions $\alpha_y$ are not available in analytic form, and (iii) be able to deal with sparsely-supported distributions, all of which OT satisfies.

The approach described so far grounds the comparison of the $\alpha_y$ distributions in the feature space $\mathcal{X}$, so we can simply use $d_\mathcal{X}^p$ as the optimal transport cost, leading to a $p$-Wasserstein distance between labels, $W_p^p(\alpha_y, \alpha_{y'})$, and in turn, to the following distance between feature-label pairs:

$$d_\mathcal{Z}\big((x, y), (x', y')\big) \triangleq \big( d_\mathcal{X}(x, x')^p + W_p^p(\alpha_y, \alpha_{y'}) \big)^{1/p}. \tag{5}$$

This gives us a point-wise notion of distance in $\mathcal{Z}$, but we ultimately seek a distance between distributions over this space, i.e., between joint distributions $P(x, y)$. Optimal transport allows us to lift the ground (i.e., point-wise) metric defined above into a distance between measures:

$$d_\mathrm{OT}(\mathcal{D}_A, \mathcal{D}_B) = \min_{\pi \in \Pi(\alpha, \beta)} \int_{\mathcal{Z} \times \mathcal{Z}} d_\mathcal{Z}(z, z') \, \mathrm{d}\pi(z, z'), \tag{6}$$

where, here, $\alpha$ and $\beta$ denote the joint distributions of $\mathcal{D}_A$ and $\mathcal{D}_B$ over $\mathcal{Z}$. The following result, an immediate consequence of the discussion above, states that Eq. (6) is a proper distance: the Optimal Transport Dataset Distance (OTDD).

Proposition 4.1. $d_\mathrm{OT}(\mathcal{D}_A, \mathcal{D}_B)$ defines a valid metric on $\mathcal{P}(\mathcal{X} \times \mathcal{P}(\mathcal{X}))$, the space of measures over feature and label-distribution pairs.

It remains to describe how the distributions $\alpha_y$ are to be represented. A flexible non-parametric approach would be to treat the samples in $\mathcal{N}_\mathcal{D}(y)$ as support points of a uniform empirical measure, i.e., $\alpha_y = \sum_{x^{(i)} \in \mathcal{N}_\mathcal{D}(y)} \frac{1}{n_y} \delta_{x^{(i)}}$, as described in Section 3. The main downside of this approach is that each evaluation of (5) involves solving an optimization problem, which could be prohibitive. Indeed, in Section C.1 we show that for datasets of size $n$, this approach has worst-case $O(n^5 \log n)$ complexity.

Instead, we propose an alternative representation of the $\alpha_y$ as Gaussian distributions, which leads to a simple yet tractable realization of the general dataset distance (6). Formally, let us denote by $\hat{\mu}_y \in \mathbb{R}^d$ and $\hat{\Sigma}_y \in \mathbb{R}^{d \times d}$ the sample mean and (positive semidefinite) covariance matrix associated with label $y$ (through its feature neighborhood $\mathcal{N}_\mathcal{D}(y)$), that is:

$$\hat{\mu}_y \triangleq \frac{1}{n_y} \sum_{x \in \mathcal{N}_\mathcal{D}(y)} x, \qquad \hat{\Sigma}_y \triangleq \frac{1}{n_y} \sum_{x \in \mathcal{N}_\mathcal{D}(y)} (x - \hat{\mu}_y)(x - \hat{\mu}_y)^\top.$$

With this, we model each label-feature distribution $\alpha_y$ as a Gaussian distribution $\mathcal{N}(\hat{\mu}_y, \hat{\Sigma}_y)$ whose parameters are the sample mean and covariance of $\mathcal{N}_\mathcal{D}(y)$.

The main motivation behind this choice is that the 2-Wasserstein distance between Gaussian distributions $\mathcal{N}(\mu_\alpha, \Sigma_\alpha)$ and $\mathcal{N}(\mu_\beta, \Sigma_\beta)$ has an analytic form:

$$W_2^2(\alpha, \beta) = \|\mu_\alpha - \mu_\beta\|_2^2 + \mathrm{tr}\Big(\Sigma_\alpha + \Sigma_\beta - 2\big(\Sigma_\alpha^{1/2} \Sigma_\beta \Sigma_\alpha^{1/2}\big)^{1/2}\Big), \tag{7}$$

where $\Sigma^{1/2}$ denotes the matrix square root.
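Eq. (7) can be evaluated with two symmetric matrix square roots. The sketch below is illustrative only: it computes the square roots through an eigendecomposition (`numpy.linalg.eigh`), clips tiny negative eigenvalues caused by round-off, and the helper names `sqrtm_psd` and `gaussian_w2_sq` are our own.

```python
import numpy as np


def sqrtm_psd(S):
    """Square root of a symmetric positive semidefinite matrix via eigh."""
    w, V = np.linalg.eigh(S)
    w = np.clip(w, 0.0, None)               # guard against round-off negatives
    return (V * np.sqrt(w)) @ V.T           # V diag(sqrt(w)) V^T


def gaussian_w2_sq(mu_a, cov_a, mu_b, cov_b):
    """Squared 2-Wasserstein distance between N(mu_a, cov_a) and N(mu_b, cov_b), Eq. (7)."""
    sa = sqrtm_psd(cov_a)
    M = sa @ cov_b @ sa
    cross = sqrtm_psd((M + M.T) / 2.0)      # symmetrize before the second root
    mean_term = np.sum((mu_a - mu_b) ** 2)
    cov_term = np.trace(cov_a + cov_b - 2.0 * cross)
    return float(mean_term + max(cov_term, 0.0))


# 1D check: N(0, 1) vs N(3, 4) gives (0 - 3)^2 + (1 - 2)^2 = 10.
w_1d = gaussian_w2_sq(np.array([0.0]), np.array([[1.0]]), np.array([3.0]), np.array([[4.0]]))
# Commuting (diagonal) covariances: Eq. (8) predicts (1 - 3)^2 + (2 - 4)^2 = 8.
w_diag = gaussian_w2_sq(np.zeros(2), np.diag([1.0, 4.0]), np.zeros(2), np.diag([9.0, 16.0]))
print(w_1d, w_diag)
```

When the two covariances commute, as in the diagonal example, the result coincides with the simplified form in Eq. (8), which only needs the individual square roots.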
Furthermore, whenever $\Sigma_\alpha$ and $\Sigma_\beta$ commute, Eq. (7) further simplifies to

$$W_2^2(\alpha, \beta) = \|\mu_\alpha - \mu_\beta\|_2^2 + \big\|\Sigma_\alpha^{1/2} - \Sigma_\beta^{1/2}\big\|_F^2. \tag{8}$$

When using Eq. (7) in the point-wise distance (5), we denote the resulting distance (6) by $d_{\mathrm{OT}\text{-}\mathcal{N}}$.

Representing label-defined distributions as Gaussians might seem like a heuristic choice driven only by algebraic convenience. However, the following result, a consequence of a bound by Gelbrich (1990), shows that this approximation lower-bounds the distance that would be obtained had it been computed using the label distances on the true distributions (regardless of their form):

Proposition 4.2.
For any two datasets $\mathcal{D}_A, \mathcal{D}_B$, we have:

$$d_{\mathrm{OT}\text{-}\mathcal{N}}(\mathcal{D}_A, \mathcal{D}_B) \leq d_\mathrm{OT}(\mathcal{D}_A, \mathcal{D}_B). \tag{9}$$

Furthermore, if the label distributions $\alpha_y$ are all Gaussian or elliptical, these quantities are equal, i.e., $d_{\mathrm{OT}\text{-}\mathcal{N}}$ is exact.

An illustration of the OTDD on a synthetic dataset summarizing its main characteristics is shown in Figure 2.
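Putting Eqs. (5), (6) and (7) together, a compact sketch of $d_{\mathrm{OT}\text{-}\mathcal{N}}$ with $p = 2$ follows. It is a simplified illustration under our own assumptions, not the authors' implementation: features are compared with the Euclidean metric, every class has at least two samples, both datasets get uniform weights, and the global problem (6) is solved approximately with entropic Sinkhorn iterations whose regularization is set relative to the largest cost (a heuristic we chose). All function names are hypothetical.

```python
import numpy as np


def sqrtm_psd(S):
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_w2_sq(mu_a, cov_a, mu_b, cov_b):
    """Squared 2-Wasserstein distance between two Gaussians, Eq. (7)."""
    sa = sqrtm_psd(cov_a)
    M = sa @ cov_b @ sa
    cross = sqrtm_psd((M + M.T) / 2.0)
    val = np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2.0 * cross)
    return max(float(val), 0.0)


def class_gaussians(X, y):
    """Per-label sample mean and (1/n_y-normalized) covariance, plus label indices."""
    labels, idx = np.unique(y, return_inverse=True)
    means = np.stack([X[idx == k].mean(axis=0) for k in range(len(labels))])
    covs = np.stack([np.atleast_2d(np.cov(X[idx == k], rowvar=False, bias=True))
                     for k in range(len(labels))])
    return means, covs, idx


def sinkhorn(a, b, C, eps, n_iter=500):
    """Entropic OT coupling via Sinkhorn scaling (see Section 3)."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def otdd_gaussian(XA, yA, XB, yB, reg=0.1):
    """Approximate d_OT-N with p = 2: Eq. (6) using the point-wise cost of Eq. (5)."""
    muA, covA, iA = class_gaussians(XA, yA)
    muB, covB, iB = class_gaussians(XB, yB)
    # Precompute all label-to-label squared W2 distances once.
    W = np.array([[gaussian_w2_sq(muA[i], covA[i], muB[j], covB[j])
                   for j in range(len(muB))] for i in range(len(muA))])
    # d_Z((x, y), (x', y')) = sqrt(||x - x'||^2 + W2^2(alpha_y, alpha_y')).
    feat_sq = np.sum((XA[:, None, :] - XB[None, :, :]) ** 2, axis=-1)
    D = np.sqrt(feat_sq + W[iA][:, iB])
    a = np.full(len(XA), 1.0 / len(XA))
    b = np.full(len(XB), 1.0 / len(XB))
    pi = sinkhorn(a, b, D, eps=reg * D.max())
    return float(np.sum(pi * D))


rng = np.random.default_rng(0)
XA = rng.normal(size=(40, 2))
yA = np.repeat([0, 1], 20)
XA[yA == 1] += 3.0                                            # two separated classes

d_self = otdd_gaussian(XA, yA, XA, yA)
d_shift = otdd_gaussian(XA, yA, XA + 5.0, yA)                 # translated features
d_renamed = otdd_gaussian(XA, yA, XA, np.array(["dog", "cat"])[yA])  # disjoint label names
print(d_self, d_shift, d_renamed)
```

Because labels enter only through the feature distributions they induce, renaming the classes of one dataset (here, integers to unrelated strings) leaves the value unchanged, which illustrates why disjoint label sets pose no problem. Note that, because of the entropic blur, `d_self` is small but not exactly zero.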
5. Computational Considerations
Since our goal in this work is to use the proposed dataset distance as a tool for tasks like transfer learning in realistic (i.e., large) machine learning datasets, scalability is crucial. Indeed, most compelling use cases of any notion of distance between datasets will involve computing it repeatedly on very large samples.

While estimation of Wasserstein (and, more generally, optimal transport) distances is known to be computationally expensive in general, in Section 3 we briefly discussed how entropy regularization can be used to trade off accuracy for runtime. Recall that both the general and Gaussian versions of the dataset distance proposed in Section 4 involve solving optimal transport problems (though the latter, owing to the closed-form solution of subproblem (7), only requires optimization for the global problem). Therefore, both of these distances benefit from approximate OT solvers.

But further speed-ups are possible. For $d_{\mathrm{OT}\text{-}\mathcal{N}}$, a simple and fast implementation can be obtained if (i) the metric in $\mathcal{X}$ coincides with the ground metric in the transport problem on $\mathcal{Y}$, and (ii) all covariance matrices commute. While (ii) will rarely occur in practice, one could use a diagonal approximation to the covariance or, under milder assumptions, simultaneous matrix diagonalization (De Lathauwer, 2003). In either case, using the simplification in (8), the pointwise distance $d_\mathcal{Z}(z, z')$ can be computed by creating augmented representations of each dataset, whereby each pair $(x, y)$ is represented as a stacked vector $\tilde{x} \triangleq [x; \mu_y; \mathrm{vec}(\Sigma_y^{1/2})]$ for the corresponding label mean and covariance. Then, $\|\tilde{x} - \tilde{x}'\|_2 = d_\mathcal{Z}((x, y), (x', y'))$ for $d_\mathcal{Z}$ as defined in Eq. (5) with $p = 2$. Therefore, in this case the OTDD can be immediately computed using an off-the-shelf OT solver on these augmented datasets.
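The augmented-representation shortcut can be checked numerically. The sketch below is our own illustration, assuming $p = 2$, Euclidean features, and diagonal (hence commuting) class covariances supplied through their square roots; the names `augment` and `d_z_commuting` are hypothetical. It confirms that the plain Euclidean distance between stacked vectors reproduces Eq. (5) combined with Eq. (8).

```python
import numpy as np


def augment(X, labels, class_means, class_cov_sqrts):
    """Map each pair (x, y) to x_tilde = [x; mu_y; vec(Sigma_y^{1/2})].

    X: (n, d) features; labels: (n,) integer class indices;
    class_means: (k, d); class_cov_sqrts: (k, d, d) matrix square roots.
    """
    n = X.shape[0]
    return np.hstack([X, class_means[labels], class_cov_sqrts[labels].reshape(n, -1)])


def d_z_commuting(x, mu, sqrt_cov, x2, mu2, sqrt_cov2):
    """Eq. (5) with p = 2, using the commuting-covariance form of W_2 in Eq. (8)."""
    w2_sq = np.sum((mu - mu2) ** 2) + np.sum((sqrt_cov - sqrt_cov2) ** 2)
    return np.sqrt(np.sum((x - x2) ** 2) + w2_sq)


# Toy example: two classes with diagonal (hence commuting) covariances.
means = np.array([[0.0, 0.0], [3.0, 1.0]])
cov_sqrts = np.stack([np.diag([1.0, 2.0]), np.diag([0.5, 1.5])])
X = np.array([[0.1, -0.2], [2.9, 1.3]])
labels = np.array([0, 1])

X_tilde = augment(X, labels, means, cov_sqrts)
lhs = np.linalg.norm(X_tilde[0] - X_tilde[1])
rhs = d_z_commuting(X[0], means[0], cov_sqrts[0], X[1], means[1], cov_sqrts[1])
print(lhs, rhs)  # the two values agree up to floating-point error
```

With non-commuting covariances the identity no longer holds, which is why this shortcut relies on the diagonal or jointly diagonalized approximation.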
While this augmented-representation approach is appealing computationally, here we instead focus on an exact version that does not require diagonal or commuting covariance approximations, and leave empirical evaluation of this approximate approach for future work.

The steps we propose next are motivated by the observation that, unlike traditional OT distances, for which the cost of computing pairwise distances is negligible compared to the complexity of the optimization routine, in our case computing the pairwise distances dominates, since it involves computing multiple OT distances itself. In order to speed up computation, we first precompute and store in memory all label-to-label pairwise distances $d(\alpha_y, \alpha_{y'})$, and retrieve them on demand during the optimization of the global OT problem.

For $d_{\mathrm{OT}\text{-}\mathcal{N}}$, computing the label-to-label distances $d\big(\mathcal{N}(\hat{\mu}_y, \hat{\Sigma}_y), \mathcal{N}(\hat{\mu}_{y'}, \hat{\Sigma}_{y'})\big)$ is dominated by the cost of computing matrix square roots, which, if done exactly, involves a full eigendecomposition. Instead, they can be computed approximately using the Newton-Schulz iterative method (Higham, 2008; Muzellec & Cuturi, 2018).
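As an illustration of the Newton-Schulz approach, the sketch below implements the standard coupled iteration for a symmetric positive-definite matrix. The Frobenius-norm pre-scaling, the fixed iteration count and the function name are our assumptions; dividing by $\|A\|_F$ places all eigenvalues in $(0, 1]$, which is enough for the iteration to converge. The loop uses only matrix products, which is what makes it attractive compared with an eigendecomposition, particularly on GPUs.

```python
import numpy as np


def sqrtm_newton_schulz(A, n_iter=50):
    """Approximate square root of a symmetric positive-definite matrix A.

    Coupled Newton-Schulz iteration: Y_k -> (A / c)^{1/2} and Z_k -> (A / c)^{-1/2},
    where c = ||A||_F rescales the spectrum into (0, 1].
    """
    d = A.shape[0]
    I = np.eye(d)
    c = np.linalg.norm(A)                 # Frobenius norm
    Y = A / c
    Z = I.copy()
    for _ in range(n_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    return Y * np.sqrt(c)                 # undo scaling: A^{1/2} = sqrt(c) (A / c)^{1/2}


rng = np.random.default_rng(0)
B = rng.normal(size=(4, 4))
A = B @ B.T + 4.0 * np.eye(4)             # a well-conditioned SPD test matrix
S = sqrtm_newton_schulz(A)
print(np.max(np.abs(S @ S - A)))          # close to zero
```

Ill-conditioned covariances need more iterations, since eigenvalues near zero are approached slowly; adding a small ridge to each $\hat{\Sigma}_y$ is a common remedy.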
[Figure 3 panels, with axes Dataset: usps and Dataset: mnist over classes 0 to 9: Label-to-Label Distance (left); Optimal Coupling for two values of ε (center, right).]
Figure 3. Dataset Distance between MNIST and USPS. Left: The label Wasserstein distances, computed without knowledge of the relation between labels across domains, recover expected relations between classes in the two domains. Center/Right: The optimal coupling $\pi^*$ for different regularization levels exhibits a block-diagonal structure, indicating class-coherent matches across domains.

Besides runtime, loading all examples of a given class to memory (to compute means and covariances) might be infeasible for large datasets (especially if running on GPU), so we instead use a two-pass stable online batch algorithm to compute these statistics (Chan et al., 1983).

The following result summarizes the time complexity of our two distances and sheds light on the trade-off between precision and efficiency they provide.
Theorem 5.1.
For datasets of size $n$ and $m$, with $p$ and $q$ classes, dimension $d$, and maximum class size $\hat{n}$, both $d_\mathrm{OT}$ and $d_{\mathrm{OT}\text{-}\mathcal{N}}$ incur a cost of $O\big(nm \log(\max\{n, m\}) \tau^{-3}\big)$ for solving the global OT problem $\tau$-approximately, while the worst-case complexity for computing the label-to-label pairwise distances (5) is $O\big(nm(d + \hat{n}^3 \log \hat{n} + d\hat{n}^2)\big)$ for $d_\mathrm{OT}$ and $O\big(nmd + pqd^3 + d^2\hat{n}(p + q)\big)$ for $d_{\mathrm{OT}\text{-}\mathcal{N}}$.

In most practical applications, the cost of computing pairwise distances will dominate, making $d_{\mathrm{OT}\text{-}\mathcal{N}}$ superior. For example, if $n = m$, the largest class size is $O(n)$ and the numbers of classes $p, q$ are treated as constants, this step becomes $O(n^5 \log n)$ for $d_\mathrm{OT}$, prohibitive for all but toy datasets, but only $O(n^2 d^2 + d^3)$ for $d_{\mathrm{OT}\text{-}\mathcal{N}}$.
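The two-pass batched computation of class statistics mentioned earlier in this section can be sketched as follows. This is our own minimal version, assuming integer class indices $0, \dots, k-1$, every class present, and a callable that can replay the data stream twice; it keeps only per-class sums, counts and scatter matrices in memory. The first pass yields the means; the second accumulates centered outer products, avoiding the numerically fragile one-pass formula $\mathbb{E}[xx^\top] - \mu\mu^\top$. It is not necessarily the exact variant of Chan et al. (1983) used by the authors, and the names `class_stats_two_pass` and `batches` are hypothetical.

```python
import numpy as np


def class_stats_two_pass(batches, num_classes, dim):
    """Per-class means and (1/n_y-normalized) covariances from a replayable stream.

    batches: zero-argument callable returning an iterator of (X_batch, y_batch).
    """
    counts = np.zeros(num_classes)
    sums = np.zeros((num_classes, dim))
    for X, y in batches():                      # pass 1: sums and counts
        np.add.at(sums, y, X)                   # unbuffered scatter-add per class
        counts += np.bincount(y, minlength=num_classes)
    means = sums / counts[:, None]

    scatter = np.zeros((num_classes, dim, dim))
    for X, y in batches():                      # pass 2: centered outer products
        centered = X - means[y]
        np.add.at(scatter, y, centered[:, :, None] * centered[:, None, :])
    covs = scatter / counts[:, None, None]
    return means, covs


rng = np.random.default_rng(1)
X_all = rng.normal(size=(50, 3))
y_all = np.arange(50) % 3                       # three classes, all present


def batches(batch_size=7):
    """Replayable stream of mini-batches (stands in for a data loader)."""
    for start in range(0, len(X_all), batch_size):
        yield X_all[start:start + batch_size], y_all[start:start + batch_size]


means, covs = class_stats_two_pass(batches, num_classes=3, dim=3)
```

`np.add.at` is used instead of `sums[y] += X` because fancy-index assignment would silently drop repeated labels within a batch.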
6. Experiments
A driving motivation for proposing a dataset distance was to provide a learning-free criterion on which to select a source dataset for transfer learning. In this section, we put this hypothesis to the test on a simple domain adaptation setting on MNIST (LeCun et al., 2010) and three of its extensions: Fashion-MNIST (Xiao et al., 2017), KMNIST (Clanuwat et al., 2018) and the letters split of EMNIST (Cohen et al., 2017), in addition to USPS. All datasets consist of 10 classes, except EMNIST, for which the selected split has 26 classes. Throughout this section, we use a simple LeNet-5 neural network (two convolutional layers, three fully connected ones) with ReLU activations. When carrying out adaptation, we freeze the convolutional layers and fine-tune only the top three layers.

[Figure 4: heatmap "Pairwise Distances: *NIST/USPS Datasets" over MNIST, EMNIST, FashionMNIST, KMNIST and USPS.]
Figure 4. Pairwise OT Distances for *NIST + USPS datasets.

We first compute all pairwise OTDD distances (Fig. 4). For the example of $d_{\mathrm{OT}\text{-}\mathcal{N}}(\text{MNIST}, \text{USPS})$, Figure 3 illustrates two key components of the computation of the distance: the label-to-label distances (left) and the optimal coupling $\pi^*$ obtained for two choices of entropy regularization parameter $\varepsilon$ (center, right). The diagonal elements of the first plot (i.e., distances between corresponding digit classes) are overall relatively smaller than off-diagonal elements. Interestingly, one of the USPS classes appears remarkably far from all MNIST digits under this metric. On the other hand, most correspondences lie along the (block) diagonal of $\pi^*$, which shows the dataset distance is able to infer class-coherent correspondences across them.

We test the robustness of the distance by computing it repeatedly for varying sample sizes. The results (Fig. 9, Appendix F) show that the distance converges towards a fixed value as sample sizes grow, but, interestingly, small sample sizes for USPS lead to wider variability, suggesting that this dataset itself is more heterogeneous than MNIST.

[Figure 5: scatter plot "Distance vs Adaptation: *NIST Datasets", OT Dataset Distance (x-axis) vs. Relative Drop in Test Error (%) (y-axis), one point per source→target pair (e.g., E→F, M→U), with the correlation coefficient ρ shown in the panel.]
Figure 5. Dataset distance vs. adaptation for *NIST datasets (M: MNIST, E: EMNIST, K: KMNIST, F: Fashion-MNIST, U: USPS). The error bars correspond to ± …

Despite both consisting of digits, MNIST and USPS are not the closest among these datasets according to the OTDD, as Figure 4 shows. The closest pair is instead (MNIST, EMNIST), while Fashion-MNIST appears comparatively far from all others, particularly MNIST.

Next, we compare these distances against the transferability between datasets, i.e., the gain in performance from using a model pretrained on the source domain and fine-tuning it on the target domain. To make these numbers comparable across adaptation pairs that involve datasets of very different hardness, we define the transferability $\mathcal{T}$ of a source domain $\mathcal{D}_S$ to a target domain $\mathcal{D}_T$ as the relative decrease in classification error when doing adaptation compared to training only on the target domain, i.e.,

$$\mathcal{T}(\mathcal{D}_S \to \mathcal{D}_T) = 100 \times \frac{\mathrm{error}(\mathcal{D}_T) - \mathrm{error}(\mathcal{D}_S \to \mathcal{D}_T)}{\mathrm{error}(\mathcal{D}_T)}.$$

We run the adaptation task 10 times with different random seeds for each pair of datasets, and compare $\mathcal{T}$ against their distance. The strong and statistically significant correlation between these (Fig. 5) shows that the OTDD is highly predictive of transferability across these datasets. In particular, EMNIST led to the best adaptation to MNIST, justifying the (initially counter-intuitive) value of the OTDD.
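For concreteness, the transferability score can be computed as below. This small helper (the name `transferability` is hypothetical) follows the verbal definition above, the relative decrease in target test error obtained by adaptation, so positive values mean the source helped and negative values indicate negative transfer. The error rates in the example are made up for illustration, not taken from the experiments.

```python
def transferability(err_adapted: float, err_target_only: float) -> float:
    """Relative decrease (in %) of target test error when adapting D_S -> D_T,
    compared with training on D_T alone."""
    if err_target_only <= 0:
        raise ValueError("baseline error must be positive")
    return 100.0 * (err_target_only - err_adapted) / err_target_only


print(round(transferability(0.08, 0.10), 6))  # 20.0: adaptation cut the error by a fifth
print(round(transferability(0.12, 0.10), 6))  # -20.0: negative transfer
```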
Data augmentation, i.e., applying carefully chosen transformations on a dataset to enhance its quality and diversity, is another key aspect of transfer learning that has substantial empirical effect on the quality of the transferred model yet lacks principled guidelines. Here, we investigate whether the OTDD could be used to compare and select among possible augmentations.

For a fixed source-target dataset pair, we generate replicas of the source data with various transformations applied to it, compute their distance to the target dataset, and compare against the transferability as before. We present results for a small-scale (MNIST → USPS) and a larger-scale (Tiny-ImageNet → CIFAR-10) setting. The transformations we use on MNIST consist of rotations by a fixed degree [30°, ..., 180°], random rotations (−180°, 180°), random affine transformations, center-crops and random crops. For Tiny-ImageNet we randomly vary brightness, contrast, hue and saturation. The models used are, respectively, LeNet-5 and a ResNet-50 (training details provided in Appendix E).

The results in both of these settings (Figures 6 and 7) show, again, a strong and statistically significant correlation between these two quantities. A reader familiar with the MNIST and USPS datasets will not be surprised by the fact that cropping images from the former leads to substantially better performance on the latter, while most rotations degrade transferability.

[Figure 6: scatter plot "Distance vs Adaptation: mnist +aug. → usps", OT Dataset Distance (x-axis) vs. Relative Drop in Test Error (%) (y-axis), one point per transformation (No Transform; Crop: center; Crop: rand; Affine: rand; Rot: rand(−180°, 180°); Rot: 30°, 60°, 90°, 120°, 150°, 180°), with the correlation coefficient ρ shown in the panel.]
Figure 6. Dataset distance vs. adaptation between MNIST with various transformations applied to it and USPS. While cropping the MNIST digits leads to better adaptation, rotating them degrades it, both in agreement with the dataset distance.
[Figure 7: scatter plot "Distance vs Adaptation: ImageNet → CIFAR-10", OT Dataset Distance (x-axis) vs. CIFAR-10 Validation Accuracy (y-axis), with the correlation coefficient ρ shown in the panel.]
Figure 7. Dataset distance vs. adaptation between Tiny-ImageNet with various transformations (source) and CIFAR-10 (target).

Natural Language Processing (NLP) is one of the areas where large-scale transfer learning has had the most profound impact over the past few years, in part driven by the availability of off-the-shelf large language models pretrained on massive amounts of data (Peters et al., 2018; Devlin et al., 2019; Radford et al., 2019).

While natural language inherently lacks the fixed-size continuous vector representation required by our framework to compute pointwise distances, we can take advantage of precisely these pretrained models to embed sentences in vector space, furnishing them with a rich geometry. In our experiments, we first embed every sentence of every dataset using the (base) BERT model (Devlin et al., 2019), and then compute the OTDD on these embedded datasets.

Here, we focus on the problem of sentence classification, and consider the following datasets by Zhang et al. (2015): AG News (ag), DBpedia (db), Yelp Reviews with 5-way classification (yl5) and binary polarity (yl+) label encodings, Amazon Reviews with 5-way classification (am5) and binary polarity (am+) label encodings, and Yahoo Answers (yh). We provide details for all these datasets in the Appendix.

As before, we simulate a challenging adaptation setting by keeping only 100 examples per target class. For every pair of datasets, we first fine-tune the BERT model using the entirety of the source domain data, after which we fine-tune and evaluate on the target domain. Figure 8 shows that the OT dataset distance is highly correlated with transferability in this setting too. Interestingly, adaptation often leads to drastic degradation of performance in this case, which suggests that off-the-shelf BERT is on its own powerful and flexible enough to initialize many of these tasks, and therefore choosing the wrong domain for initial training might destroy some of that information.

Footnote: Sentence embeddings were computed using the sentence_transformers library.
Footnote: The datasets are available via the torchtext library.
Figure 8. Distance vs. adaptation for text classification datasets (see main text for key), with sentence embedding via BERT. [Plot "Distance vs Adaptation: Text Classification"; x-axis: OT dataset distance; y-axis: relative drop in test error (%); one point per source → target pair among ag, db, yl+, yl5, am+, am5 and yh.]
7. Discussion
We have shown that the notion of distance between datasets proposed in this work is scalable and flexible enough to be used in realistic transfer learning scenarios, all the while offering appealing theoretical properties and interpretable comparisons, and requiring minimal assumptions on the underlying datasets.
There are many natural extensions of this framework. Here we assumed that the datasets were defined on feature spaces of the same dimension, but one could instead leverage a relational notion such as the Gromov-Wasserstein distance (Mémoli, 2011) to compute the distance between datasets whose features are not directly comparable. Furthermore, our efficient implementation relies on modeling groups of points with the same label as Gaussian distributions. This could naturally be extended to more general distributions for which the Wasserstein distance either has an analytic solution or can at least be computed efficiently, such as elliptical distributions (Muzellec & Cuturi, 2018), Gaussian mixture models (Delon & Desolneux, 2019), certain Gaussian processes (Mallasto & Feragen, 2017), or tree metrics (Le et al., 2019).
In this work, we purposely excluded two key aspects of any learning task from our notion of distance: the loss function and the predictor function class. While we posit that it is crucial to have a notion of distance that is independent of these choices, it is nevertheless appealing to ask whether our distance could be extended to take them into account, ideally with minimal training. Exploring different avenues to inject such information into this framework will be the focus of our future work.
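To make the Gaussian modeling of label groups concrete, here is a minimal NumPy sketch of how the label-to-label distances and the resulting pointwise cost could be computed. It assumes $p = 2$, Euclidean features, 1/n-normalized class covariances, and that Eq. (7) is the standard closed form $W_2^2 = \|\mu_\alpha - \mu_\beta\|^2 + \operatorname{tr}\big(\Sigma_\alpha + \Sigma_\beta - 2(\Sigma_\alpha^{1/2}\Sigma_\beta\Sigma_\alpha^{1/2})^{1/2}\big)$. All function names are hypothetical; this is an illustration, not the authors' implementation.

```python
import numpy as np

def sqrtm_psd(S):
    # Square root of a symmetric PSD matrix via eigendecomposition (O(d^3))
    S = 0.5 * (S + S.T)
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def gaussian_w2_sq(mu_a, cov_a, mu_b, cov_b):
    # Squared 2-Wasserstein distance between N(mu_a, cov_a) and N(mu_b, cov_b)
    ra = sqrtm_psd(cov_a)
    cross = sqrtm_psd(ra @ cov_b @ ra)
    val = np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2.0 * cross)
    return max(float(val), 0.0)  # clip tiny negative round-off

def class_moments(X, y):
    # Per-class mean and (1/n-normalized) covariance: O(d^2 n_i) per class
    labels = np.unique(y)
    stats = []
    for c in labels:
        Xc = X[y == c]
        mu = Xc.mean(axis=0)
        D = Xc - mu
        stats.append((mu, D.T @ D / len(Xc)))
    return labels, stats

def label_distances(Xs, ys, Xt, yt):
    # k_s x k_t matrix of squared Gaussian W2 distances between label groups
    ls, ss = class_moments(Xs, ys)
    lt, st = class_moments(Xt, yt)
    W = np.array([[gaussian_w2_sq(*a, *b) for b in st] for a in ss])
    return ls, lt, W

def augmented_cost(Xs, ys, Xt, yt):
    # Squared pointwise cost ||x - x'||^2 + W2^2(label groups), shape (n, m)
    ls, lt, W = label_distances(Xs, ys, Xt, yt)
    feat = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
    i = np.searchsorted(ls, ys)  # row of W for each source label
    j = np.searchsorted(lt, yt)  # column of W for each target label
    return feat + W[i][:, j]

rng = np.random.default_rng(0)
Xs, ys = rng.normal(size=(40, 3)), np.arange(40) % 2
Xt, yt = rng.normal(size=(30, 3)) + 1.0, np.arange(30) % 3
C = augmented_cost(Xs, ys, Xt, yt)
print(C.shape)  # (40, 30)
```

Only $k_s k_t$ label pairs require matrix square roots, which mirrors the $O(k_s k_t d^3)$ term in the complexity analysis of Appendix C.2; the resulting $n \times m$ cost matrix would then be handed to a global OT solver.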
References
Achille, A., Mbeng, G., and Soatto, S. Dynamics and reachability of learning tasks. October 2018.
Achille, A., Lam, M., Tewari, R., Ravichandran, A., Maji, S., Fowlkes, C., Soatto, S., and Perona, P. Task2Vec: Task embedding for meta-learning. In Proceedings of the IEEE International Conference on Computer Vision, pp. 6430–6439, 2019.
Altschuler, J., Niles-Weed, J., and Rigollet, P. Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 30, pp. 1964–1974. Curran Associates, Inc., 2017.
Alvarez-Melis, D. and Jaakkola, T. Gromov-Wasserstein alignment of word embedding spaces. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pp. 1881–1890, 2018. doi: 10.18653/v1/D18-1214.
Alvarez-Melis, D., Jaakkola, T. S., and Jegelka, S. Structured optimal transport. In Storkey, A. and Perez-Cruz, F. (eds.), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, volume 84 of Proceedings of Machine Learning Research, pp. 1771–1780. PMLR, 2018.
Amari, S.-I. Differential-Geometrical Methods in Statistics, volume 28 of Lecture Notes in Statistics. Springer New York, New York, NY, 1985. ISBN 9780387960562, 9781461250562. doi: 10.1007/978-1-4612-5056-2.
Amari, S.-I. Natural gradient works efficiently in learning. Neural Comput., 10(2):251–276, February 1998. ISSN 0899-7667. doi: 10.1162/089976698300017746.
Amari, S.-I. and Nagaoka, H. Methods of Information Geometry. Translations of Mathematical Monographs. American Mathematical Society, 2000. ISBN 9780821843024.
Ben-David, S., Blitzer, J., Crammer, K., and Pereira, F. Analysis of representations for domain adaptation. In Schölkopf, B., Platt, J. C., and Hoffman, T. (eds.), Advances in Neural Information Processing Systems 19, pp. 137–144. MIT Press, 2007.
Chan, T. F., Golub, G. H., and Leveque, R. J. Algorithms for computing the sample variance: Analysis and recommendations. Am. Stat., 37(3):242–247, August 1983. ISSN 0003-1305. doi: 10.1080/00031305.1983.10483115.
Clanuwat, T., Bober-Irizar, M., Kitamoto, A., Lamb, A., Yamamoto, K., and Ha, D. Deep learning for classical Japanese literature. December 2018.
Cohen, G., Afshar, S., Tapson, J., and van Schaik, A. EMNIST: Extending MNIST to handwritten letters. In 2017 International Joint Conference on Neural Networks (IJCNN), pp. 2921–2926. IEEE, May 2017. doi: 10.1109/IJCNN.2017.7966217.
Cortes, C. and Mohri, M. Domain adaptation in regression. In Algorithmic Learning Theory, pp. 308–323. Springer Berlin Heidelberg, 2011. doi: 10.1007/978-3-642-24412-4_25.
Courty, N., Flamary, R., Tuia, D., and Rakotomamonjy, A. Optimal transport for domain adaptation. IEEE Trans. Pattern Anal. Mach. Intell., 39(9):1853–1865, September 2017. ISSN 0162-8828. doi: 10.1109/TPAMI.2016.2615921.
Cuturi, M. Sinkhorn distances: Lightspeed computation of optimal transport. In Burges, C. J. C., Bottou, L., Welling, M., Ghahramani, Z., and Weinberger, K. Q. (eds.), Advances in Neural Information Processing Systems 26, pp. 2292–2300. Curran Associates, Inc., 2013.
De Lathauwer, L. Simultaneous matrix diagonalization: the overcomplete case. In Proc. of the 4th International Symposium on ICA and Blind Signal Separation, Nara, Japan, volume 8122, pp. 825. kecl.ntt.co.jp, 2003.
Delon, J. and Desolneux, A. A Wasserstein-type distance in the space of Gaussian mixture models. July 2019.
Deng, J., Dong, W., Socher, R., Li, L., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255. IEEE, June 2009. doi: 10.1109/CVPR.2009.5206848.
Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186, 2019.
Dukler, Y., Li, W., Lin, A., and Montufar, G. Wasserstein of Wasserstein loss for learning generative models. In Chaudhuri, K. and Salakhutdinov, R. (eds.), Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pp. 1716–1725, Long Beach, California, USA, 2019. PMLR.
Feydy, J., Séjourné, T., Vialard, F.-X., Amari, S.-I., Trouve, A., and Peyré, G. Interpolating between optimal transport and MMD using Sinkhorn divergences. In Chaudhuri, K. and Sugiyama, M. (eds.), Proceedings of the 22nd International Conference on Artificial Intelligence and Statistics, volume 89 of Proceedings of Machine Learning Research, pp. 2681–2690. PMLR, 2019.
Frogner, C., Mirzazadeh, F., and Solomon, J. Learning embeddings into entropic Wasserstein spaces. In International Conference on Learning Representations, May 2019.
Gelbrich, M. On a formula for the L2 Wasserstein metric between measures on Euclidean and Hilbert spaces. Math. Nachr., 147(1):185–203, November 1990. ISSN 0025-584X. doi: 10.1002/mana.19901470121.
Genevay, A., Peyré, G., and Cuturi, M. Learning generative models with Sinkhorn divergences. In Storkey, A. and Perez-Cruz, F. (eds.), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, volume 84 of Proceedings of Machine Learning Research, pp. 1608–1617, Playa Blanca, Lanzarote, Canary Islands, 2018. PMLR.
Genevay, A., Chizat, L., Bach, F., Cuturi, M., and Peyré, G. Sample complexity of Sinkhorn divergences. In Chaudhuri, K. and Sugiyama, M. (eds.), Proceedings of the 22nd International Conference on Artificial Intelligence and Statistics, volume 89 of Proceedings of Machine Learning Research, pp. 1574–1583. PMLR, 2019.
Higham, N. J. Functions of Matrices: Theory and Computation. SIAM, January 2008. ISBN 9780898717778.
Hull, J. J. A database for handwritten text recognition research. IEEE Trans. Pattern Anal. Mach. Intell., 16(5):550–554, May 1994. ISSN 0162-8828, 1939-3539. doi: 10.1109/34.291440.
Kantorovitch, L. On the translocation of masses. Dokl. Akad. Nauk SSSR, 37(7-8):227–229, 1942. ISSN 0002-3264.
Khodak, M., Balcan, M.-F. F., and Talwalkar, A. S. Adaptive gradient-based meta-learning methods. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 5915–5926. Curran Associates, Inc., 2019.
Krizhevsky, A. and Hinton, G. Learning multiple layers of features from tiny images. 2009.
Kuhn, D., Esfahani, P. M., Nguyen, V. A., and Shafieezadeh-Abadeh, S. Wasserstein distributionally robust optimization: Theory and applications in machine learning. August 2019.
Le, T., Yamada, M., Fukumizu, K., and Cuturi, M. Tree-sliced variants of Wasserstein distances. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 12283–12294. Curran Associates, Inc., 2019.
LeCun, Y., Cortes, C., and Burges, C. J. MNIST handwritten digit database. 2010.
Leite, R. and Brazdil, P. Predicting relative performance of classifiers from samples. In Proceedings of the 22nd International Conference on Machine Learning, pp. 497–503. dl.acm.org, 2005.
Li, L. Data Complexity in Machine Learning and Novel Classification Algorithms. PhD thesis, California Institute of Technology, 2006.
Liang, T., Poggio, T., Rakhlin, A., and Stokes, J. Fisher-Rao metric, geometry, and complexity of neural networks. In Proceedings of the 22nd International Conference on Artificial Intelligence and Statistics. PMLR, 2019.
Mallasto, A. and Feragen, A. Learning from uncertain curves: The 2-Wasserstein metric for Gaussian processes. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 30, pp. 5660–5670. Curran Associates, Inc., 2017.
Mansour, Y., Mohri, M., and Rostamizadeh, A. Domain adaptation: Learning bounds and algorithms. In The 22nd Conference on Learning Theory. arxiv.org, 2009.
Mémoli, F. Gromov–Wasserstein distances and the metric approach to object matching. Found. Comput. Math., 11(4):417–487, August 2011. ISSN 1615-3375, 1615-3383. doi: 10.1007/s10208-011-9093-5.
Mémoli, F. Distances between datasets. In Najman, L. and Romon, P. (eds.), Modern Approaches to Discrete Curvature, pp. 115–132. Springer International Publishing, Cham, 2017. ISBN 9783319580029. doi: 10.1007/978-3-319-58002-9_3.
Muzellec, B. and Cuturi, M. Generalizing point embeddings using the Wasserstein space of elliptical distributions. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 31, pp. 10237–10248. Curran Associates, Inc., 2018.
Peters, M., Neumann, M., Iyyer, M., Gardner, M., Clark, C., Lee, K., and Zettlemoyer, L. Deep contextualized word representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pp. 2227–2237. aclweb.org, 2018.
Peyré, G. and Cuturi, M. Computational optimal transport. Foundations and Trends® in Machine Learning, 11(5-6):355–607, 2019. ISSN 1935-8237. doi: 10.1561/2200000073.
Radford, A., Wu, J., Amodei, D., Amodei, D., Clark, J., Brundage, M., and Sutskever, I. Better language models and their implications. OpenAI Blog https://openai.com/blog/better-language-models, 2019.
Rubner, Y., Tomasi, C., and Guibas, L. J. The earth mover's distance as a metric for image retrieval. Int. J. Comput. Vis., 40(2):99–121, November 2000. ISSN 0920-5691, 1573-1405. doi: 10.1023/A:1026543900054.
Tran, A. T., Nguyen, C. V., and Hassner, T. Transferability and hardness of supervised classification tasks. August 2019.
Villani, C. Topics in Optimal Transportation. American Mathematical Soc., 2003. ISBN 9780821833124.
Villani, C. Optimal Transport: Old and New, volume 338. Springer Science & Business Media, 2008. ISBN 9783540710493.
Xiao, H., Rasul, K., and Vollgraf, R. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. August 2017.
Yurochkin, M., Claici, S., Chien, E., Mirzazadeh, F., and Solomon, J. M. Hierarchical optimal transport for document representation. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 1599–1609. Curran Associates, Inc., 2019.
Zhang, X., Zhao, J., and LeCun, Y. Character-level convolutional networks for text classification. In Cortes, C., Lawrence, N. D., Lee, D. D., Sugiyama, M., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 28, pp. 649–657. Curran Associates, Inc., 2015.
A. Proof of Proposition 4.1
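Whenever the cost function used in the optimal transport problem is a metric on a given space $\mathcal{X}$, the optimal transport problem defines a distance (the Wasserstein distance) on $\mathcal{P}(\mathcal{X})$ (Villani, 2008, Chapter 6). Therefore, it suffices to show that the cost function $d_{\mathcal{Z}}$ defined in Eq. (5) is indeed a distance. Clearly, it is symmetric because both $d_{\mathcal{X}}$ and $W_p$ are. In addition, since both of these are distances,
$$
d_{\mathcal{Z}}(z, z') = 0 \iff d_{\mathcal{X}}(x, x') = 0 \,\wedge\, W_p(\alpha_y, \alpha_{y'}) = 0 \iff x = x' \,\wedge\, \alpha_y = \alpha_{y'} \iff z = z'.
$$
Finally, for any $z_1, z_2, z_3 \in \mathcal{Z}$ we have
$$
\begin{aligned}
d_{\mathcal{Z}}(z_1, z_3) &= \big( d_{\mathcal{X}}(x_1, x_3)^p + W_p(\alpha_{y_1}, \alpha_{y_3})^p \big)^{1/p} \\
&\le \big( [d_{\mathcal{X}}(x_1, x_2) + d_{\mathcal{X}}(x_2, x_3)]^p + [W_p(\alpha_{y_1}, \alpha_{y_2}) + W_p(\alpha_{y_2}, \alpha_{y_3})]^p \big)^{1/p} \\
&\le \big( d_{\mathcal{X}}(x_1, x_2)^p + W_p(\alpha_{y_1}, \alpha_{y_2})^p \big)^{1/p} + \big( d_{\mathcal{X}}(x_2, x_3)^p + W_p(\alpha_{y_2}, \alpha_{y_3})^p \big)^{1/p} \\
&= d_{\mathcal{Z}}(z_1, z_2) + d_{\mathcal{Z}}(z_2, z_3),
\end{aligned}
$$
where the first inequality uses the triangle inequality for $d_{\mathcal{X}}$ and $W_p$ together with the monotonicity of $t \mapsto t^p$ on $[0, \infty)$, and the second is an application of Minkowski's inequality. Hence, $d_{\mathcal{Z}}$ satisfies the triangle inequality, and therefore it is a metric on $\mathcal{Z} = \mathcal{X} \times \mathcal{P}(\mathcal{X})$. We conclude that the value of the optimal transport problem (6) that uses this metric as a cost function is itself a distance.
B. Proof of Proposition 4.2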
Our proof relies directly on a well-known bound for the 2-Wasserstein distance between distributions due to Gelbrich (1990):
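Lemma B.1 (Gelbrich bound). Suppose $\alpha, \beta \in \mathcal{P}(\mathbb{R}^d)$ are any two measures with mean vectors $\mu_\alpha, \mu_\beta \in \mathbb{R}^d$ and covariance matrices $\Sigma_\alpha, \Sigma_\beta \in \mathbb{S}^d_+$, respectively. Then
$$
W_2\big(\mathcal{N}(\mu_\alpha, \Sigma_\alpha), \mathcal{N}(\mu_\beta, \Sigma_\beta)\big) \le W_2(\alpha, \beta), \tag{10}
$$
where $W_2\big(\mathcal{N}(\mu_\alpha, \Sigma_\alpha), \mathcal{N}(\mu_\beta, \Sigma_\beta)\big)$ is as in Eq. (7).
In the notation of Section 3, Lemma B.1 implies that for every two feature-label pairs $z = (x, y)$ and $z' = (x', y')$, we have
$$
d_{\mathcal{X}}(x, x')^2 + W_2^2\big(\mathcal{N}(\mu_y, \Sigma_y), \mathcal{N}(\mu_{y'}, \Sigma_{y'})\big) \le d_{\mathcal{X}}(x, x')^2 + W_2^2(\alpha_y, \alpha_{y'}). \tag{11}
$$
The left-hand side is the squared ground cost of the Gaussian approximation, which we denote $d_{\mathcal{Z}}^{\mathcal{N}}(z, z')^2$, and the right-hand side is $d_{\mathcal{Z}}(z, z')^2$. Therefore,
$$
\int d_{\mathcal{Z}}^{\mathcal{N}}(z, z')^2 \, \mathrm{d}\pi(z, z') \le \int d_{\mathcal{Z}}(z, z')^2 \, \mathrm{d}\pi(z, z') \tag{12}
$$
for every coupling $\pi \in \Pi(\alpha, \beta)$. In particular, evaluating (12) at the coupling $\pi^*$ that minimizes its right-hand side, and noting that the left-hand side at $\pi^*$ is at least its minimum over $\Pi(\alpha, \beta)$, we obtain
$$
d_{\mathrm{OT}}(\mathcal{D}_A, \mathcal{D}_B; \mathcal{N}) \le d_{\mathrm{OT}}(\mathcal{D}_A, \mathcal{D}_B). \tag{13}
$$
Clearly, Gelbrich's bound holds with equality when $\alpha$ and $\beta$ are indeed Gaussian. More generally, equality is attained for elliptical distributions with the same density generator (Kuhn et al., 2019). This immediately implies equality of the two quantities in Eq. (13) in that case.
C. Time Complexity Analysis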
For the analyses in this section, assume that $\mathcal{D}_S$ and $\mathcal{D}_T$ respectively have $n$ and $m$ labeled examples in $\mathbb{R}^d$, and $k_s$ and $k_t$ classes. In addition, let $N_S(i) := \{x \in \mathcal{X} \mid (x, y = i) \in \mathcal{D}_S\}$ be the subset of examples in $\mathcal{D}_S$ with label $i$, and define $N_T(j)$ analogously. We denote the cardinalities of these subsets as $n_i^s := |N_S(i)|$ and, analogously, $n_j^t := |N_T(j)|$.
Direct computation of the distance (5) involves two main steps:
(i) computing pairwise pointwise distances (each requiring the solution of a label-to-label OT subproblem), and
(ii) a global OT problem between the two samples.
Step (ii) is identical for both the general distance $d_{\mathrm{OT}}$ and its Gaussian approximation counterpart $d_{\mathrm{OT}\text{-}\mathcal{N}}$, so we analyze it first. This is an OT problem between two discrete distributions of size $n$ and $m$, which can be solved exactly in $O\big((n + m)\, nm \log(nm)\big)$ using interior point methods or Orlin's algorithm for the uncapacitated min-cost flow problem (Peyré & Cuturi, 2019). Alternatively, it can be solved $\tau$-approximately in $O\big(nm \log(\max\{n, m\})\, \tau^{-3}\big)$ time using the Sinkhorn algorithm (Altschuler et al., 2017).
We next analyze step (i) individually for the two OTDD versions. Combined, they provide a proof of Theorem 5.1.
C.1. Pointwise distance computation for $d_{\mathrm{OT}}$
Consider a single pair of points, $(x, y = i) \in \mathcal{D}_S$ and $(x', y' = j) \in \mathcal{D}_T$. Evaluating $\|x - x'\|$ has $O(d)$ complexity, while $W(\alpha_y, \beta_{y'})$ is an $n_i^s \times n_j^t$ OT problem, which itself requires computing a distance matrix (at cost $O(n_i^s n_j^t d)$) and then solving the OT problem, which, as discussed before, can be done exactly in $O\big((n_i^s + n_j^t)\, n_i^s n_j^t \log(n_i^s + n_j^t)\big)$ or $\tau$-approximately in $O\big(n_i^s n_j^t \log(\max\{n_i^s, n_j^t\})\, \tau^{-3}\big)$.
For simplicity, let us denote by $n_s = \max_i n_i^s$ and $n_t = \max_j n_j^t$ the size of the largest label cluster in each dataset, and by $\hat{n} = \max\{n_s, n_t\}$ the overall largest one.
Using these, with $\hat{n} = \max\{n_s, n_t\}$ the largest class size overall, and combining all of the above, the overall worst-case complexity for the computation of the $n \times m$ pairwise distances can be expressed as
$$
O\big(nm\,(d + \hat{n}^3 \log \hat{n} + d\,\hat{n}^2)\big), \tag{14}
$$
which is what we wanted to show.
C.2. Pointwise distance computation for $d_{\mathrm{OT}\text{-}\mathcal{N}}$
As before, consider a pair of points $(x, y = i) \in \mathcal{D}_S$ and $(x', y' = j) \in \mathcal{D}_T$ whose cluster sizes are $n_i^s$ and $n_j^t$, respectively. As mentioned in Section 5, for $d_{\mathrm{OT}\text{-}\mathcal{N}}$ we first compute all the per-class means and covariance matrices. This step is clearly dominated by the latter, which is $O(d^2 n_i^s)$.¹ Considering all labels from both datasets, this amounts to a worst-case complexity of $O\big(d^2 (k_s n_s + k_t n_t)\big)$.
Once the means and covariances have been computed, we precompute all the $k_s \times k_t$ pairwise label-to-label distances $W(\alpha_y, \beta_{y'})$ using Eq. (7). This computation is dominated by the matrix square roots. If done exactly, these involve a full eigendecomposition, at cost $O(d^3)$, so the total cost for this step is $O(k_s k_t d^3)$.
Finally, while computing the pairwise distances, we incur $O(nmd)$ to obtain $\|x - x'\|$. Putting all of these together, and replacing $n_s, n_t$ by $\hat{n}$, we obtain a total cost for precomputing all the pointwise distances of
$$
O\big(nmd + k_s k_t d^3 + d^2 \hat{n} (k_s + k_t)\big),
$$
which concludes the proof.
¹ Technically, this cost depends on the exponent $\omega$ of matrix multiplication, but we take $\omega = 3$ for simplicity.

Table 1. Summary of all the datasets used in this work.

Dataset                     Input Dimension  Classes  Train Examples  Test Examples  Source
USPS                        16 × 16 ∗        10       7291            2007           (Hull, 1994)
MNIST                       28 × 28          10       60K             10K            (LeCun et al., 2010)
EMNIST (letters)            28 × 28          26       124.8K          20.8K          (Cohen et al., 2017)
KMNIST                      28 × 28          10       60K             10K            (Clanuwat et al., 2018)
FASHION-MNIST               28 × 28          10       60K             10K            (Xiao et al., 2017)
TINY-IMAGENET               64 × 64 ‡        200      100K            10K            (Deng et al., 2009)
CIFAR-10                    32 × 32          10       50K             10K            (Krizhevsky & Hinton, 2009)
AG-NEWS                     †                4        120K            7.6K           (Zhang et al., 2015)
DBPEDIA                     †                14       560K            70K            (Zhang et al., 2015)
YELP REVIEW (Polarity)      †                2        560K            38K            (Zhang et al., 2015)
YELP REVIEW (Full Scale)    †                5        650K            50K            (Zhang et al., 2015)
AMAZON REVIEW (Polarity)    †                2        3.6M            400K           (Zhang et al., 2015)
AMAZON REVIEW (Full Scale)  †                5        3M              650K           (Zhang et al., 2015)
YAHOO ANSWERS               †                10       1.4M            60K            (Zhang et al., 2015)

∗: we rescale the USPS digits to 28 × 28 for comparison to the *NIST datasets. ‡: we rescale Tiny-ImageNet to 32 × 32 for comparison to CIFAR-10. †: for text datasets, variable-length sentences are embedded into fixed-dimensional vectors using BERT.

D. Dataset Details
Information about all the datasets used, including references, is provided in Table 1.
E. Optimization and Training Details
For the adaptation experiments on the *NIST datasets, we use a LeNet-5 architecture with ReLU nonlinearities, trained for 20 epochs using ADAM with learning rate … and weight decay …; it was then fine-tuned for 10 epochs on the target domain(s) using the same optimization parameters.
For the Tiny-ImageNet to CIFAR-10 adaptation results, we use a ResNet-50 trained for 300 epochs using SGD with learning rate 0.1, momentum 0.9 and weight decay …; it was then fine-tuned for 30 epochs on the target domain using SGD with the same parameters, except for a 0.01 learning rate. We discard pairs for which the variance of the adaptation accuracy is beyond a certain threshold.
For the text classification experiments, we use a pretrained
BERT architecture (the bert-base-uncased model of the transformers library, huggingface.co/transformers/). We first embed all sentences using this model. Then, for each pair of source/target domains, we first fine-tune using ADAM with learning rate … for 10 epochs on the full source domain data, and then fine-tune on the restricted target domain data with the same optimization parameters for 2 epochs.
Our implementation of the OTDD relies on the pot (pot.readthedocs.io/en/stable/) and geomloss Python packages.
F. Robustness of the Distance
[Grid of panels titled "MNIST:a / USPS:b", with MNIST subset size a ∈ {100, 500, 1000, 10000, 50000} (rows) and USPS subset size b ∈ {100, 500, 1000, 7291} (columns); x-axis: distance; y-axis: entropy regularization ε ∈ {0.01, 10.0}.]
Figure 9. Robustness analysis: distances computed on subsets of varying size (rows: MNIST, columns: USPS), over 10 random repetitions, for two values of the regularization parameter ε.
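A robustness study like the one in Figure 9 varies two knobs: the sizes of the random subsets and the entropic regularization ε. The sketch below shows one way such a study could be organized, under stated assumptions: it uses a plain log-domain Sinkhorn solver written in NumPy (standing in for the pot/geomloss solvers mentioned in Appendix E), synthetic Gaussian point clouds instead of MNIST and USPS, features only (no label term), and a squared Euclidean cost. The function names, subset sizes and ε values are illustrative.

```python
import numpy as np

def logsumexp(M, axis):
    # Numerically stable log-sum-exp along one axis
    m = M.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(M - m).sum(axis=axis))

def sinkhorn_log(a, b, C, eps, max_iter=5000, tol=1e-9):
    # Entropy-regularized OT via Sinkhorn updates of the dual potentials f, g
    f, g = np.zeros_like(a), np.zeros_like(b)
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(max_iter):
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
        P = np.exp((f[:, None] + g[None, :] - C) / eps)
        # after the g-update the column sums equal b; stop once rows match a
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return P, float((P * C).sum())  # plan and its transport cost <P, C>

def subset_distances(X, Y, n_sub, m_sub, eps, reps=10, seed=0):
    # Mean and std of the entropic transport cost over random subsets
    rng = np.random.default_rng(seed)
    out = []
    for _ in range(reps):
        xs = X[rng.choice(len(X), n_sub, replace=False)]
        xt = Y[rng.choice(len(Y), m_sub, replace=False)]
        C = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
        a = np.full(n_sub, 1.0 / n_sub)
        b = np.full(m_sub, 1.0 / m_sub)
        out.append(sinkhorn_log(a, b, C, eps)[1])
    return float(np.mean(out)), float(np.std(out))

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 2))        # stand-in for "MNIST" features
Y = rng.normal(size=(300, 2)) + 0.5  # stand-in for "USPS" features
for eps in (0.1, 10.0):
    for n_sub in (20, 100):
        mean, std = subset_distances(X, Y, n_sub, n_sub, eps, reps=5)
        print(f"eps={eps:5}  n={n_sub:4}  cost={mean:.3f} ± {std:.3f}")
```

Working with the potentials in the log domain keeps small values of ε numerically stable, since the kernel $e^{-C/\varepsilon}$ is never formed explicitly; the spread across repetitions is what a robustness plot would summarize.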