FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning
Chia-Wen Kuo†, Chih-Yao Ma†, Jia-Bin Huang‡, Zsolt Kira†
†Georgia Tech, ‡Virginia Tech

Abstract.
Recent state-of-the-art semi-supervised learning (SSL) methods use a combination of image-based transformations and consistency regularization as core components. Such methods, however, are limited to simple transformations such as traditional data augmentation or convex combinations of two images. In this paper, we propose a novel learned feature-based refinement and augmentation method that produces a varied set of complex transformations. Importantly, these transformations also use information from both within-class and across-class prototypical representations that we extract through clustering. We use features already computed across iterations by storing them in a memory bank, obviating the need for significant extra computation. These transformations, combined with traditional image-based augmentation, are then used as part of the consistency-based regularization loss. We demonstrate that our method is comparable to the current state of the art on smaller datasets (CIFAR-10 and SVHN) while being able to scale up to larger datasets such as CIFAR-100 and mini-ImageNet, where we achieve significant gains over the state of the art (e.g., an absolute 17.44% gain on mini-ImageNet). We further test our method on DomainNet, demonstrating better robustness to out-of-domain unlabeled data, and perform rigorous ablations and analysis to validate the method.
Keywords: semi-supervised learning, feature-based augmentation, consistency regularization
Driven by large-scale datasets such as ImageNet as well as computing resources, deep neural networks have achieved strong performance on a wide variety of tasks. Training these deep neural networks, however, requires millions of labeled examples that are expensive to acquire and annotate. Consequently, numerous methods have been developed for semi-supervised learning (SSL), where a large number of unlabeled examples are available alongside a smaller set of labeled data. One branch of the most successful SSL methods [16,19,22,25,27,4,3] uses image-based augmentation [35,8,13,6] to generate different transformations of an input image, and consistency regularization to enforce invariant representations across these transformations. While these methods have achieved great success, the data augmentation methods they use for generating different transformations are limited to transformations in the image space and fail to leverage the knowledge of other instances in the dataset for diverse transformations.

[Figure 1 panels: (a) Image-Based Augmentation and Consistency, with a consistency loss between $p(y \mid f_{\hat{x}})$ and $p(y \mid f_x)$; (b) Feature-Based Augmentation and Consistency, with a consistency loss between $p(y \mid g_x)$ and $p(y \mid f_x)$, where $g_x$ is obtained from $f_x$ using prototypes per class.]
Fig. 1: Consistency regularization methods are the most successful methods for semi-supervised learning. The main idea of these methods is to enforce consistency between the predictions of different transformations of an input image. (a) Image-based augmentation methods generate different views of an input image via data augmentation, which is limited to operations in the image space, as well as to operations within a single instance or a simple convex combination of two instances. (b) We propose an additional learned feature-based augmentation that operates in the abstract feature space. The learned feature refinement and augmentation module is capable of leveraging information from other instances, within or outside of the same class.

In this paper, we propose a novel feature-based refinement and augmentation method that addresses the limitations of conventional image-based augmentation described above. Specifically, we propose a module that learns to refine and augment input image features via soft-attention toward a small set of representative prototypes extracted from the image features of other images in the dataset. The comparison between image-based augmentation and our proposed feature-based refinement and augmentation is shown in Fig. 1. Since the proposed module is learned and carried out in the feature space, diverse and abstract transformations of input images can be applied, which we validate in Sec. 4.4. Our approach requires only minimal extra computation, since it maintains a memory bank and uses k-means clustering to extract prototypes.

We demonstrate that adding our proposed feature-based augmentation along with conventional image-based augmentations, when used for consistency regularization, achieves significant gains.
We test our method on standard SSL datasets such as SVHN and CIFAR-10, and show that our method, despite its simplicity, compares favorably against state-of-the-art methods in all cases. Further, through testing on CIFAR-100 and mini-ImageNet, we show that our method is scalable to larger datasets and outperforms the current best methods by significant margins. For example, we outperform the closest state of the art by an absolute 17.44% on mini-ImageNet with 4k labels. We also propose another realistic setting on DomainNet [21] to test the robustness of our proposed method when the unlabeled samples partially come from shifted domains; in this setting, we improve over both the supervised baseline and the semi-supervised baseline when 50% of the unlabeled samples come from shifted domains. Finally, we conduct thorough ablations and analysis to highlight that the method does, in fact, perform varied complex transformations in feature space (as evidenced by t-SNE and nearest-neighbor image samples). To summarize, our key contributions include:
– We develop a learned feature-based refinement and augmentation module to transform input image features in the abstract feature space by leveraging a small set of representative prototypes of all classes in the dataset.
– We propose a memory bank mechanism to efficiently extract prototypes from images of the entire dataset with minimal extra computation.
– We demonstrate thorough results across four standard SSL datasets and also propose a realistic setting where the unlabeled data partially come from domains shifted from the target labeled set.
– We perform in-depth analysis of the prototype representations extracted and used for each instance, as well as of the transformations the proposed feature-based refinement and augmentation module learns.
Consistency Regularization Methods.
Current state-of-the-art SSL methods mostly fall into this category. The key insight of this branch of methods is that the prediction of a deep model should be consistent across different semantics-preserving transformations of the same data. Consistency regularization methods regularize the model to be invariant to textural or geometric changes of an image. Specifically, given an input image $x$ and a network composed of a feature encoder $f_x = \mathrm{Enc}(x)$ and a classifier $p_x = \mathrm{Clf}(f_x)$, we can generate the pseudo-label of the input image by $p_x = \mathrm{Clf}(\mathrm{Enc}(x))$. Furthermore, given a data augmentation module $\mathrm{AugD}(\cdot)$, we can generate an augmented copy of $x$ by $\hat{x} = \mathrm{AugD}(x)$. A consistency loss $H$, typically the KL divergence, is then applied on the model prediction for $\hat{x}$ to enforce consistent predictions: $\mathcal{L}_{con} = H(p_x, \mathrm{Clf}(\mathrm{Enc}(\hat{x})))$.

Image-Based Augmentation.
The core of consistency-based methods is how to generate diverse but reasonable transformations of the same data. A straightforward answer is to incorporate data augmentation, which has been widely used in the training of deep models to increase data diversity and prevent overfitting. For example, [4,16,25,27] use traditional data augmentation to generate different transformations of semantically identical images. Data augmentation methods randomly perturb an image in terms of its texture, e.g., brightness, hue, or sharpness, or its geometry, e.g., rotation, translation, or affine transform. In addition to data augmentation, Miyato et al. [19] and Yu et al. [33] perturb images along the adversarial direction, and Qiao et al. [22] use multiple networks to generate different views (predictions) of the same data. Recently, several works proposed data augmentation modules for supervised or semi-supervised learning, where the augmentation parameters can either be easily tuned [8], found by RL training [7], or decided by the confidence of the network prediction [3].

Table 1: Comparison to other SSL methods with consistency regularization (✓ = uses the component, – = does not).

| | ReMixMatch [3] | MixMatch [4] | Mean Teacher [27] | ICT [31] | PLCB [1] | FeatMatch (Ours) |
| Feature-Based Augmentation | – | – | – | – | – | ✓ |
| Image-Based Augmentation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
| Temporal Ensembling | three ✓ and one – survive in the extracted text; their column alignment is not recoverable |
| Self-Supervised Loss | ✓ | – | – | – | – | – |
| Alignment of Class Distribution | ✓ | – | – | – | ✓ | – |

Mixup [35,34,13], similar to data augmentation, is another effective way of increasing data diversity. It generates new training samples by a convex combination of two images and their corresponding labels. It has been shown that models trained with Mixup are robust toward out-of-distribution data [10] and that Mixup is beneficial for the uncertainty calibration of a network [28]. Given two images $x_1$ and $x_2$ and their labels (or pseudo-labels) $y_1$ and $y_2$, they are mixed by a randomly sampled ratio $r$ as $\hat{x} = r \cdot x_1 + (1 - r) \cdot x_2$ and $\hat{y} = r \cdot y_1 + (1 - r) \cdot y_2$. This has been done in feature space as well [30]. A standard classification loss $H(\cdot)$ is then applied on the prediction of the mixed sample $\hat{x}$ and the mixed label $\hat{y}$ as $\mathcal{L}_{mix} = H(\hat{y}, \mathrm{Clf}(\mathrm{Enc}(\hat{x})))$. Originally, Mixup methods were developed for supervised learning. ICT [31] and MixMatch [4] introduce Mixup into semi-supervised learning by using the pseudo-labels of the unlabeled data. Furthermore, by constraining the mixing ratio $r$ to be greater than 0.5, as proposed by [4], we can make sure that the mixed sample is closer to $x_1$. Therefore, we can separate the mixed data into a labeled mixed batch $\hat{\mathcal{X}}$ if $x_1$ is labeled, and an unlabeled mixed batch $\hat{\mathcal{U}}$ if $x_1$ is unlabeled.
Different loss weights can then be applied to modulate the strength of regularization from the unlabeled data.

Image-based augmentation has been shown to be an effective approach to generate different views of an image for consistency-based SSL methods. However, conventional image-based augmentation has the following two limitations: (1) it operates in image space, which limits the possible transformations to textural or geometric changes within images, and (2) it operates within a single instance, which fails to transform data with the knowledge of other instances, either within or outside of the same class. Some recent works that utilize Mixup only partially address the second limitation, since Mixup operates only between two instances. On the other hand, Manifold Mixup [30] approaches the first limitation by performing Mixup in the feature space, but is limited to a simple convex combination of two samples.

We instead propose to address these two limitations simultaneously. We propose a novel method that refines and augments image features in the abstract feature space rather than the image space. To efficiently leverage the knowledge of other classes, we condense the information of each class into a small set of prototypes by performing clustering in the feature space. The image features are then refined and augmented through information propagated from prototypes of all classes. We hypothesize that this feature refinement/augmentation can further improve the feature representations, and that these refined features can produce better pseudo-labels than features without the refinement (see Sec. 4.4 for our analysis of this hypothesis). The feature refinement and augmentation are learned via a lightweight attention network over the representative prototypes and optimized end-to-end with other objectives such as the classification loss.
A consistency loss can naturally be applied between the predictions from the original features and the refined features to regularize the network, as shown in Fig. 1b.

The final model seamlessly combines our novel feature-based augmentation with conventional image-based augmentation for consistency regularization, which is applied to data augmented from both sources. Despite the simplicity of the method, we find that this achieves significant performance improvements. In summary, we compare our method with other highly relevant SSL works in Table 1.
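As an illustration of the Mixup formulation discussed in this section, the sketch below mixes two samples and their (pseudo-)labels and applies MixMatch's constraint that the ratio be at least 0.5. Drawing $r$ from Beta(α, α) and the default α = 0.75 are assumptions made for illustration; the function name is hypothetical, and plain Python lists stand in for images or features.

```python
import random


def mixup(x1, y1, x2, y2, alpha=0.75, rng=random):
    """Mix two samples and their (pseudo-)labels with a ratio r >= 0.5.

    x1, x2: flat lists of floats (e.g. flattened images or feature vectors).
    y1, y2: label distributions (one-hot labels or soft pseudo-labels).
    alpha:  Beta(alpha, alpha) concentration; an assumed, illustrative value.
    """
    r = rng.betavariate(alpha, alpha)
    # MixMatch-style constraint: r >= 0.5 keeps the mixed sample closer to x1,
    # so it can be routed to the labeled or unlabeled mixed batch based on x1.
    r = max(r, 1.0 - r)
    x_hat = [r * a + (1.0 - r) * b for a, b in zip(x1, x2)]
    y_hat = [r * a + (1.0 - r) * b for a, b in zip(y1, y2)]
    return x_hat, y_hat, r


rng = random.Random(0)
x_hat, y_hat, r = mixup([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], rng=rng)
```

Because the mixed label is the same convex combination as the mixed input, it remains a valid distribution whenever $y_1$ and $y_2$ are.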
In order to efficiently leverage the knowledge of other classes for feature refinement and augmentation, we propose to compactly represent the information of each class by clustering in the feature space. To select representative prototypes from the dataset, we use K-Means clustering in the feature space to extract $p_k$ cluster means as prototypes for each class. However, there are two technical challenges: (1) in an SSL setting, most images are unlabeled; (2) even if all the labels were available, it would still be computationally expensive to extract features of all the images in the entire dataset before running K-Means.

To tackle these issues, as shown in Fig. 2, we collect the features $f_{x_i}$ and pseudo-labels $\hat{y}_i$ already generated by the network at every iteration of the training loop, i.e., no extra computation is needed. In the recording loop, the pairs of pseudo-labels and features are detached from the computation graph and pushed into a memory bank for later use. The prototypes are extracted by K-Means at every epoch, once we have gone over the whole dataset. Finally, the feature refinement and augmentation module updates the prototypes with the newly extracted ones in the training loop. Even though the prototypes are extracted from features computed by the model a few iterations earlier, as training progresses and the model gradually converges, the extracted prototypes fall on the correct clusters and are diverse enough to compactly represent the feature distribution of each class. More analyses can be found in Sec. 4.4. A similar idea is concurrently explored in self-supervised learning by He et al. [32,11].
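The recording loop and per-class prototype extraction described above can be sketched as follows. This is a minimal, framework-free illustration under our own assumptions: the class name `PrototypeMemoryBank`, plain Lloyd's K-Means with random initialization and a fixed iteration count, and Python lists standing in for detached feature tensors are all hypothetical choices, not the authors' implementation.

```python
import random
from collections import defaultdict


def _sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


class PrototypeMemoryBank:
    """Hypothetical helper for the recording loop.

    Stores (feature, pseudo-label) pairs produced during the training loop and,
    once per epoch, runs K-Means per class to obtain p_k prototypes per class.
    """

    def __init__(self, num_prototypes_per_class, num_iters=10, seed=0):
        self.k = num_prototypes_per_class
        self.num_iters = num_iters
        self.rng = random.Random(seed)
        self.bank = defaultdict(list)  # pseudo-label -> recorded feature vectors

    def record(self, features, pseudo_labels):
        # Copy the vectors, mirroring "detached from the computation graph".
        for f, y in zip(features, pseudo_labels):
            self.bank[y].append(list(f))

    def extract_prototypes(self):
        prototypes = {c: self._kmeans(feats) for c, feats in self.bank.items()}
        self.bank.clear()  # start a fresh memory bank for the next epoch
        return prototypes

    def _kmeans(self, points):
        k = min(self.k, len(points))
        centers = [list(p) for p in self.rng.sample(points, k)]
        for _ in range(self.num_iters):
            clusters = [[] for _ in range(k)]
            for p in points:
                nearest = min(range(k), key=lambda i: _sq_dist(p, centers[i]))
                clusters[nearest].append(p)
            for i, members in enumerate(clusters):
                if members:  # keep the previous center if a cluster is empty
                    centers[i] = [sum(col) / len(members) for col in zip(*members)]
        return centers
```

In use, `record` would be called on every training iteration with the features and pseudo-labels the network already produced, and `extract_prototypes` once per epoch; capping k at the number of recorded features of a class simply avoids failing on rare classes.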
[Figure 2 panels: Training Loop and Recording Loop. Recorded (feature $f_{x_i}$, pseudo-label $\hat{y}_i$) pairs are clustered by K-Means every epoch, and the prototypes at epoch t are updated to the prototypes at epoch t+1 used for feature-based augmentation.]
Fig. 2: A prototype recording loop that runs alongside the model training loop. The image features $f_{x_i}$ as well as their pseudo-labels $\hat{y}_i$ already generated at each iteration of the training loop are collected and recorded in a memory bank as $(f_{x_i}, \hat{y}_i)$ pairs. Once the training loop has gone over the whole dataset, the recording loop runs K-Means to extract prototypes for each class, updates the prototypes for feature-based augmentation, and clears the memory bank.

With a set of prototypes selected by the process described above, we propose a learned feature refinement and augmentation module via soft-attention [29] toward the set of selected prototypes. The proposed module refines and augments input image features in the feature space by leveraging the knowledge of prototypes, either within or outside of the same class, as shown in Fig. 3. The lightweight feature refinement and augmentation module, composed of three fully connected layers, is jointly optimized with other objectives and hence learns a reasonable feature-based augmentation to aid classification. We provide further analysis and discussion in Sec. 4.4.

Inspired by the attention mechanism [29], each input image feature attends to prototype features via attention weights computed by dot-product similarity. The prototype features are then summed, weighted by the attention weights, and fed back to the input image feature via a residual connection for feature augmentation and refinement. Specifically, for an input image with extracted features $f_x$ and the $i$-th prototype features $f_{p,i}$, we first project them into an embedding space by a learned function $\phi_e$ as $e_x = \phi_e(f_x)$ and $e_{p,i} = \phi_e(f_{p,i})$, respectively. We compute an attention weight $w_i$ between $e_x$ and $e_{p,i}$ as:

$$w_i = \mathrm{softmax}\left(e_x^{\top} e_{p,i}\right), \tag{1}$$

where $\mathrm{softmax}(\cdot)$ normalizes the dot-product similarity scores across all prototypes.
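To make Eq. 1 concrete, here is a small sketch of the attention weights over prototypes. It assumes, purely for illustration, that the learned embedding $\phi_e$ is a bias-free linear layer stored as a list of rows (`W_e`); the helper names are hypothetical.

```python
import math


def linear(W, v):
    """Bias-free fully connected layer: W is a list of rows, v a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def softmax(scores):
    m = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(f_x, prototypes, W_e):
    """Eq. 1: w_i = softmax_i(e_x^T e_{p,i}), with phi_e assumed linear (W_e)."""
    e_x = linear(W_e, f_x)
    e_p = [linear(W_e, f_p) for f_p in prototypes]
    scores = [sum(a * b for a, b in zip(e_x, e)) for e in e_p]
    return softmax(scores)


identity = [[1.0, 0.0], [0.0, 1.0]]
w = attention_weights([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]], identity)
```

The softmax runs over the prototype axis, so the weights for one input feature always sum to one, and prototypes more aligned with the embedded input receive larger weights.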
The information aggregated from the prototypes and passed to the image features for feature refinement and augmentation can then be expressed as a sum of prototype features weighted by the attention weights:

$$f_a = \mathrm{relu}\left(\phi_a\left(\left[e_x, \textstyle\sum_i w_i e_{p,i}\right]\right)\right), \tag{2}$$

where $\phi_a$ is a learnable function and $[\cdot, \cdot]$ is a concatenation operation along the feature dimension. Finally, the input image feature $f_x$ is refined via a residual connection as:

$$g_x = \mathrm{relu}\left(f_x + \phi_r(f_a)\right), \tag{3}$$

where $g_x$ are the refined features of $f_x$ and $\phi_r$ is a learnable function. The attention mechanism described above can be trivially generalized to multi-head attention as in [29]. In practice, we use multi-head attention instead of a single head for slightly better results. For simplicity, we denote the feature refinement and augmentation process described above as $g_x = \mathrm{AugF}(f_x)$.

[Figure 3: Feature Augmentation. The encoder feature $f_x$ and the prototypes per class pass through the fc layer $\phi_e$, softmax attention, the fc layer $\phi_a$, and the fc layer $\phi_r$ to produce $g_x$.]
Fig. 3: Feature-Based Augmentation: The input image features are augmented by attention using extracted prototype features (Eq. 1), where the colors of ⋆ represent the classes of prototypes. The prototype features are aggregated via a weighted sum using the attention weights, concatenated with the image features, and then passed through an fc layer $\phi_a$ (Eq. 2) to produce attention features $f_a$. Finally, we use the attention features to refine and augment the input image features with a residual connection (Eq. 3).

The learned AugF module along with the selected prototypes provides an effective method for feature-based augmentation, which addresses the limitations of conventional data augmentation methods discussed previously. With the learned feature-based augmentation, we can naturally apply a consistency loss between the predictions of the unaugmented features $f_x$ and the augmented features $g_x$.

However, given a classifier $p = \mathrm{Clf}(f)$, which prediction should we use as the pseudo-label, $p_g = \mathrm{Clf}(g_x)$ or $p_f = \mathrm{Clf}(f_x)$? We investigate this question in Sec. 4.4 and find that AugF is able to refine the input features for better representation, thus generating better pseudo-labels. Therefore, we compute the pseudo-label $p_g$ on the refined feature $g_x$ by $p_g = \mathrm{Clf}(g_x)$. The feature-based consistency loss can then be computed as $\mathcal{L}_{con} = H(p_g, \mathrm{Clf}(f_x))$. We can easily extend $\mathcal{L}_{con}$ to work seamlessly with traditional augmentation methods, i.e., traditional data augmentation and Mixup. For simplicity, we illustrate with only data augmentation, but Mixup can be easily adapted. Inspired by Berthelot et al. [3], we generate a weakly augmented image $x$ and its strongly augmented copy $\hat{x}$. The pseudo-label is computed on the weakly augmented image $x$, which undergoes feature-based augmentation and refinement for better pseudo-labels, as $p_g = \mathrm{Clf}(\mathrm{AugF}(\mathrm{Enc}(x)))$. We can then compute two consistency losses on the strongly augmented data $\hat{x}$, one with AugF applied and the other without:

$$\mathcal{L}_{con\text{-}g} = H\left(p_g, \mathrm{Clf}(\mathrm{AugF}(\mathrm{Enc}(\hat{x})))\right) \tag{4}$$

$$\mathcal{L}_{con\text{-}f} = H\left(p_g, \mathrm{Clf}(\mathrm{Enc}(\hat{x}))\right) \tag{5}$$

Since the pseudo-label $p_g$ is computed on the image undergoing weak data augmentation and feature-based augmentation, the regularization signal of $\mathcal{L}_{con\text{-}g}$ and $\mathcal{L}_{con\text{-}f}$ comes from both image-based and feature-based augmentation.

The consistency regularization losses $\mathcal{L}_{con\text{-}g}$ and $\mathcal{L}_{con\text{-}f}$ in Eqs. 4 and 5 are applied on unlabeled data.
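Eqs. 1 to 5 can be assembled into a single-head sketch of AugF and the two unlabeled consistency losses. The assumptions are ours: $\phi_e$, $\phi_a$, and $\phi_r$ are bias-free linear maps stored in a hypothetical `params` dict, $H$ is instantiated as KL(p || q) (the typical choice mentioned for consistency losses), and `enc`/`clf` are caller-supplied stand-ins for the encoder and classifier. Multi-head attention and gradient handling are omitted.

```python
import math


def linear(W, v):
    """Bias-free fc layer stand-in (W is a list of rows)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, x) for x in v]


def softmax(z):
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]


def kl(p, q, eps=1e-12):
    """H(p, q) instantiated as KL(p || q): target p, prediction q."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def aug_f(f_x, prototypes, params):
    """Single-head AugF following Eqs. 1-3 (phi_e, phi_a, phi_r assumed linear)."""
    e_x = linear(params["W_e"], f_x)
    e_p = [linear(params["W_e"], p) for p in prototypes]
    w = softmax([sum(a * b for a, b in zip(e_x, e)) for e in e_p])  # Eq. 1
    pooled = [sum(wi * e[d] for wi, e in zip(w, e_p)) for d in range(len(e_x))]
    f_a = relu(linear(params["W_a"], e_x + pooled))  # Eq. 2 (list + list = concat)
    return relu([a + b for a, b in zip(f_x, linear(params["W_r"], f_a))])  # Eq. 3


def consistency_losses(enc, clf, x_weak, x_strong, prototypes, params):
    """Eqs. 4-5. In an autodiff framework, p_g would typically be detached."""
    p_g = clf(aug_f(enc(x_weak), prototypes, params))  # pseudo-label, weak view
    f_strong = enc(x_strong)
    loss_con_g = kl(p_g, clf(aug_f(f_strong, prototypes, params)))  # Eq. 4
    loss_con_f = kl(p_g, clf(f_strong))  # Eq. 5
    return loss_con_g, loss_con_f


params = {
    "W_e": [[1.0, 0.0], [0.0, 1.0]],
    "W_a": [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]],
    "W_r": [[1.0, 0.0], [0.0, 1.0]],
}
```

With identical weak and strong views, the Eq. 4 loss is zero by construction while the Eq. 5 loss generally is not, because AugF changes the feature; this illustrates how $\mathcal{L}_{con\text{-}f}$ pulls the unrefined prediction toward the refined pseudo-label.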
For a labeled image $x$ with label $y$, a regular classification loss can be applied:

$$\mathcal{L}_{clf} = H\left(y, \mathrm{Clf}(\mathrm{AugF}(\mathrm{Enc}(x)))\right) \tag{6}$$

Therefore, the total loss can be written as $\mathcal{L}_{clf} + \lambda_g \mathcal{L}_{con\text{-}g} + \lambda_f \mathcal{L}_{con\text{-}f}$, where $\lambda_g$ and $\lambda_f$ are the weights for the $\mathcal{L}_{con\text{-}g}$ and $\mathcal{L}_{con\text{-}f}$ losses, respectively.

We conduct experiments on commonly used SSL datasets: SVHN [20], CIFAR-10 [15], CIFAR-100 [15], and mini-ImageNet [23]. Following the standard approach in SSL, we randomly choose a certain number of labeled samples as a small labeled set and discard the labels for the remaining data to form a large unlabeled set. Our proposed method is tested under various amounts of labeled samples. SVHN is a dataset of 10 digits with about 70k training samples. CIFAR-10 and CIFAR-100 are natural image datasets with 10 and 100 classes, respectively; both contain 50k training samples. For mini-ImageNet, we follow [14,1] to construct the mini-ImageNet training set. Specifically, given a predefined list of 100 classes [23] from ILSVRC [24], 500 samples are selected randomly for each class, forming a training set of 50k samples. The samples are center-cropped and resized to 84×84 resolution. We then follow the same standard procedure to construct a small labeled set and a large unlabeled set from the 50k training samples.
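The standard protocol above (a random labeled subset, with labels discarded for the rest) and the mini-ImageNet per-class sampling can be sketched as below. The function names are hypothetical, the split is a plain random split as the text describes (some SSL codebases instead balance the labeled samples per class), and the labels are assumed to be already restricted to the 100 listed classes.

```python
import random
from collections import defaultdict


def sample_per_class(labels, per_class, rng):
    """Pick `per_class` random indices from every class (e.g. 500 for mini-ImageNet)."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    chosen = []
    for y in sorted(by_class):
        chosen.extend(rng.sample(by_class[y], per_class))
    return sorted(chosen)


def split_labeled_unlabeled(indices, num_labeled, rng):
    """Keep labels for a random subset; the remaining indices become unlabeled."""
    shuffled = list(indices)
    rng.shuffle(shuffled)
    return sorted(shuffled[:num_labeled]), sorted(shuffled[num_labeled:])
```

For mini-ImageNet with 4k labels, this would correspond to `sample_per_class(labels, 500, rng)` followed by `split_labeled_unlabeled(subset, 4000, rng)`.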
Table 2: Comparison on CIFAR-100 and mini-ImageNet. Numbers represent error rates across three runs. For a fair comparison, we use the same model as other methods: CNN-13 for CIFAR-100 and ResNet-18 for mini-ImageNet.
[Table 2 body not recoverable from the extracted text: columns CIFAR-100 and mini-ImageNet; the first row is Π-model [25].]

SSL under domain shift.
In another realistic setting, we argue that the unlabeled data may come from a domain different from that of the target labeled data. For instance, given a small set of labeled natural images of animals, the large unlabeled set may also contain paintings of animals. To investigate the effect of domain shift in the unlabeled set, we propose a new SSL task based on the DomainNet dataset [21], which contains 345 classes of images coming from six domains: Clipart, Infograph, Painting, Quickdraw, Real, and Sketch.

We use the Real domain as our target. Five percent of the data from the Real domain are kept as the target labeled set, and the rest form the target unlabeled set. We select Clipart, Painting, Sketch, and Quickdraw as shifted domains. To modulate the level of domain shift in the unlabeled data, we introduce a parameter $r_u$ that controls the ratio of unlabeled data coming from the target Real domain versus the shifted domains. Specifically, $r_u$ percent of the target Real unlabeled set is replaced with data uniformly drawn from the shifted domains. By formulating the problem this way, the amount of unlabeled data remains constant, and the only factor that varies is the ratio between in-domain and shifted-domain data in the unlabeled set.

We randomly reserve 1% of the data from the Real domain as the validation set. The final result is reported on the test set of the Real domain, with the model selected on the reserved validation set. The images are center-cropped and resized to 128×128 resolution, and the model we use is the standard ResNet-18 [12]. There are around 120k training samples, more than twice as many as in standard SSL datasets such as CIFAR-10 and CIFAR-100. For a fair comparison, we fix all hyper-parameters across experiments with different $r_u$ to truly assess the robustness of the proposed method toward domain shift in the unlabeled data.

Hyper-parameters.
We tune the hyper-parameters on CIFAR-10 with 250 labels, using a validation set held out from the training set. Our method is not sensitive to the hyper-parameters, which are kept fixed across all datasets and settings. Please see the supplementary material for more implementation details and the values of the hyper-parameters.

Table 3: Comparison between the image-based baseline and our proposed feature-based augmentation method on DomainNet with 1) unlabeled data coming from the same domain as the labeled target ($r_u$ = 0%), and 2) half of the unlabeled data coming from the same domain as the labeled target and the other half from shifted domains ($r_u$ = 50%). Numbers are error rates across 3 runs.
[Table 3 body not recoverable from the extracted text: columns Method (5% labeled samples), $r_u$ = 0%, and $r_u$ = 50%; rows include the (semi-supervised) baseline.]

We first show our results on CIFAR-100 and mini-ImageNet with 4k and 10k labels in Table 2. Our method consistently improves over the state of the art by large margins, by about 5% absolute on CIFAR-100 with 4k labels and 17% on mini-ImageNet with 4k labels.

In Table 3, we show our results on the larger DomainNet setting, which contains unlabeled data coming from other shifted domains. It can be clearly seen that in the setting of $r_u$ = 50%, where 50% of the unlabeled data come from other shifted domains, the performance drops by a large margin compared with the setting of $r_u$ = 0%, where all the unlabeled data come from the same domain as the target labeled set. Nevertheless, our proposed feature-based augmentation method improves over the supervised baseline by an absolute 36% error rate when $r_u$ = 0% and 23% when $r_u$ = 50%. When compared to the conventional image-based augmentation baseline, we improve by 12% when $r_u$ = 50% and 16% when $r_u$ = 0%.

In Table 4, we compare our method with other SSL methods on the standard CIFAR-10 and SVHN datasets. Our method achieves results comparable with the current state of the art, ReMixMatch, even though 1) we start from a lower baseline and 2) our method is much simpler (e.g., no class distribution alignment and no self-supervised loss), as compared in Table 1. Our proposed feature-based augmentation method is complementary to image-based methods and can be easily integrated to further improve performance.

In the ablation study, we are interested in answering the following questions: 1) how effective are the two proposed consistency losses, $\mathcal{L}_{con\text{-}f}$ (Eq. 5) and $\mathcal{L}_{con\text{-}g}$ (Eq. 4)? 2) how much of the improvement comes from the proposed feature-based augmentation method over the image-based augmentation baseline? For the image-based augmentation baseline, the AugF module is completely removed, and thus the consistency regularization comes only from image-based augmentation. This is also the same image-based augmentation baseline that our final model with feature-based augmentation builds upon. The ablation study is conducted on CIFAR-10 with various amounts of labeled samples (Table 5).

Table 4: Comparison on CIFAR-10 and SVHN. Numbers represent error rates across three runs. The results reported in the first block with the CNN-13 model [16,19] are from the original papers. The results reported in the second block with wide ResNet (WRN) are reproduced by [4,3].
[Table 4 body largely not recoverable from the extracted text: columns CIFAR-10 and SVHN; surviving entries include Π-model [25] with WRN (1.5M parameters) at 53.02 and FeatMatch (Ours) at 7.50, both in the first reported column.]

Table 5: Ablation study on CIFAR-10 with various amounts of labeled samples.
[Table 5 body only partially recoverable: component columns include $\mathcal{L}_{con\text{-}f}$, $\mathcal{L}_{con\text{-}g}$, and $\mathcal{L}_{con}$; result columns are 250, 1,000, and 4,000 labels; the surviving 250-label error rates are 18.57 (row labeled $\mathcal{L}_{con\text{-}f}$), 8.19 (row labeled $\mathcal{L}_{con\text{-}g}$), and 7.90 (last row).]

We can see from Table 5 that our image-based augmentation baseline achieves good results, but only when more labeled samples are available. We conjecture this is because the aggressive data augmentation applied to training images makes training unstable. Nevertheless, our baseline performance is still competitive with other image-based augmentation methods in Table 4 (though slightly worse than MixMatch). By adding our proposed AugF module ($\mathcal{L}_{con\text{-}f}$ and $\mathcal{L}_{con\text{-}g}$) for feature refinement and augmentation on top of the image-based augmentation baseline, the performance improves over the baseline consistently, especially for 250 labels.

We can also see that $\mathcal{L}_{con\text{-}f}$ plays a more important role than $\mathcal{L}_{con\text{-}g}$, though our final model with both loss terms achieves the best result. In both $\mathcal{L}_{con\text{-}f}$ and $\mathcal{L}_{con\text{-}g}$, the pseudo-labels are computed from the features that have undergone feature-based augmentation. The only difference is which prediction we drive to match the pseudo-label: 1) the prediction from the feature that has undergone both AugD and AugF (by the $\mathcal{L}_{con\text{-}g}$ loss), or 2) the prediction from the feature that has undergone only AugD (by the $\mathcal{L}_{con\text{-}f}$ loss). As claimed in Sec. 3.3 and analyzed in Sec. 4.4, AugF is able to refine input image features for better representation and pseudo-labels of higher quality. Therefore, matching the slightly worse prediction from the feature that has undergone only AugD (by the $\mathcal{L}_{con\text{-}f}$ loss) induces a stronger consistency regularization. This explains why $\mathcal{L}_{con\text{-}f}$ contributes more to the performance improvement.

[Figure 4: (a) and (b) t-SNE plots; (c) Augmented images of AugD vs. AugF for Horse, Airplane, Ship, and Car.]
Fig. 4: (a) We jointly compute and plot t-SNE of input unaugmented image features (dimmer color) and image-based augmented features (brighter color). (b) We also jointly compute and plot t-SNE of input unaugmented image features (dimmer color) and feature-based augmented features (brighter color), with exactly the same t-SNE parameters as in (a). (c) To concretely visualize the augmented features, we find their nearest image neighbors in the feature space and compare them side by side against the image-based augmentation method.

What does AugF learn?
We compare the feature distributionvia t-SNE 1) between input unaugmented image features and image-based aug-mented features in Fig. 4a, and 2) between input unaugmented image featuresand feature-based augmented features in Fig. 4b. In Fig. 4a, some local smallclusters are captured by t-SNE and can be found in the zoomed sub-figure.This indicates that
AugD can only perturb data locally, and fail to producestronger augmentation for more effective consistency regularization in the fea- eatMatch: Feature-Based Augmentation for Semi-Supervised Learning 13(a) t-SNE of selected prototypes. edges
Prototypes I m a g e s (b) Leaned attention weights. Prototypes
AutomobileAirplane BirdCat (c) Nearest image neighbors of prototypes
Fig. 5: ( a ) In the t-SNE plot, the extracted prototypes ( (cid:5) ) fall on the correctclusters and are able to compactly represent the cluster. ( b ) We visualize thelearned attention weights from a batch of images toward prototypes. The imagesand prototypes are sorted by their classes for ease of illustration. As can be seen,images have higher attention weights to the prototypes with the same class.( c ) We find the prototypes’ nearest image neighbors in the feature space. Theprototypes compactly represent a diverse sets of images in each class.ture space. In Fig. 4b, we can see AugF indeed learns to augment and refinefeatures. Furthermore, the learned augmentation preserves semantic meaning asthe augmented features still fall in the correct cluster. In the zoomed figure,we can see that the perturbed features distribute more uniformly and no localsmall clusters could be found. This indicates that
AugF can produce strongeraugmentation for more effective consistency regularization in the feature space.To have a more concrete sense of the learned feature-based augmentation(
AugF), we show the nearest image neighbor of each augmented feature in the feature space. Some sample results are shown in Fig. 4c, side by side with image-based augmentation (AugD). As shown in the figure, AugF is capable of transforming features in an abstract way that goes beyond the simple textural and geometric transformations performed by AugD. For instance, it is able to augment data to different poses and backgrounds, which could be challenging for conventional image-based augmentation methods.
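To make the retrieval behind Fig. 4c concrete, the following is a minimal Python sketch of a nearest-neighbor lookup in feature space. It assumes cosine similarity as the metric and a plain list of stored image features; the text does not specify either, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two feature vectors (epsilon avoids division by zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-12)


def nearest_image_neighbor(query: Sequence[float], bank: List[Sequence[float]]) -> int:
    """Index of the stored image feature most similar to `query`.

    `bank` stands in for features of real (unaugmented) images, and `query`
    for an augmented feature, e.g. one produced by AugF or by AugD.
    """
    return max(range(len(bank)), key=lambda i: cosine(query, bank[i]))
```

For example, with `bank = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]`, the query `[0.5, 0.6]` maps to index 2. Showing the image behind the returned index is what lets an abstract feature-space perturbation be inspected visually.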
Why else does AugF improve model performance?
We hypothesize that another reason our method improves performance is that the AugF module refines input image features into a better representation using the extracted prototypes, and thus provides better pseudo-labels. The consistency regularization losses then drive the network's predictions to match these higher-quality pseudo-labels, leading to an overall improvement. Under this hypothesis, we expect classification accuracy to be higher for features after refinement. To verify this, we remove the L_con-f loss and retrain. The accuracy of pseudo-labeling from the features refined by AugF is on average 0.…–…% higher … L_con-f drives the feature encoder to learn a better feature representation refined by AugF.

The reader may wonder: why doesn't AugF learn the shortcut solution of an identity mapping to minimize L_con-f and L_con-g? As can be seen from Fig. 4, AugF does not learn an identity mapping. Although an identity mapping may be a shortcut solution for minimizing L_con-f and L_con-g, it is not one for the classification loss L_clf (Eq. 6). This finding implicitly confirms our hypothesis that the prototypes carry extra information that AugF can leverage to refine the feature representation for higher (pseudo-label) classification accuracy.
What does AugF do internally?
In Fig. 5a and 5c, we can see that even though our prototype extraction method only uses simple K-Means to extract prototypes of each class, based on potentially noisy pseudo-labels and on features recorded several iterations ago, it still successfully extracts a diverse set of prototypes per class. Moreover, in Fig. 5b, the attention mechanism inside AugF learns to attend to prototypes that belong to the same class as the input image feature. Note that there is no loss term specific to AugF; it is simply jointly optimized with the standard classification and consistency regularization losses of semi-supervised learning.
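The two ingredients discussed above, per-class prototype extraction and attention toward prototypes, can be sketched in plain Python. This is an illustration under stated assumptions, not the authors' implementation: the memory bank is a hypothetical list of `(feature, pseudo_label)` pairs, K-Means uses a deterministic farthest-point initialization chosen here for reproducibility, and the attention is a single head without the learned query/key projections a real AugF module would use. All function names are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vec = List[float]


def _sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _mean(points: List[Vec]) -> Vec:
    """Coordinate-wise mean of a non-empty list of vectors."""
    return [sum(p[d] for p in points) / len(points) for d in range(len(points[0]))]


def kmeans(points: List[Vec], k: int, iters: int = 20) -> List[Vec]:
    """Lloyd's K-Means with deterministic farthest-point initialization."""
    k = min(k, len(points))
    centroids = [list(points[0])]
    while len(centroids) < k:
        # Next seed: the point farthest from all current centroids.
        far = max(points, key=lambda p: min(_sq_dist(p, c) for c in centroids))
        centroids.append(list(far))
    for _ in range(iters):
        clusters: List[List[Vec]] = [[] for _ in centroids]
        for p in points:
            nearest = min(range(k), key=lambda j: _sq_dist(p, centroids[j]))
            clusters[nearest].append(p)
        # An empty cluster keeps its previous centroid.
        centroids = [_mean(c) if c else centroids[j] for j, c in enumerate(clusters)]
    return centroids


def extract_prototypes(bank: List[Tuple[Vec, int]], p_k: int) -> Dict[int, List[Vec]]:
    """Group stored features by pseudo-label and run K-Means separately per class."""
    by_class: Dict[int, List[Vec]] = defaultdict(list)
    for feat, pseudo_label in bank:
        by_class[pseudo_label].append(feat)
    return {c: kmeans(feats, p_k) for c, feats in by_class.items()}


def attention_weights(query: Vec, prototypes: List[Vec]) -> List[float]:
    """Single-head scaled dot-product attention weights (softmax over prototypes)."""
    scale = math.sqrt(len(query))
    logits = [sum(q * k for q, k in zip(query, p)) / scale for p in prototypes]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

With p_k prototypes per class, the prototypes of all classes would be concatenated before computing attention, so that a feature can attend both within and across classes, which is the behavior visualized in Fig. 5b.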
We introduce a method to jointly learn a classifier and feature-based refinement and augmentation, which can be used within existing consistency-based SSL methods. Unlike traditional image-based transformations, our method can learn complex, feature-based transformations and incorporate information from class-specific prototypical representations extracted in an efficient manner (specifically, using a memory bank). Using this method, we show results comparable to the current state of the art on smaller datasets such as CIFAR-10 and SVHN, and significant improvements on datasets with a large number of categories (e.g., a 17.44% absolute improvement on mini-ImageNet). We also demonstrate increased robustness to out-of-domain unlabeled data, which is an important real-world problem, and perform ablations and analysis to examine the learned feature transformations and extracted prototypical representations.
This work was funded by DARPA's Learning with Less Labels (LwLL) program under agreement HR0011-18-S-0044 and DARPA's Lifelong Learning Machines (L2M) program under Cooperative Agreement HR0011-18-2-0019.
Appendix

A State-of-the-art Results with Other SSL Techniques
As we build upon a weaker baseline and our method is much simpler, our method performs slightly worse than the state of the art on the CIFAR-10 dataset in some settings. However, as we claimed in Section 4.2 of the main paper, our proposed feature-based augmentation method is complementary to conventional image-based augmentation methods and can easily be integrated with them to further improve performance. In Table 6, we demonstrate that by incorporating (1) distribution alignment, which aligns the marginal class distribution as described in [1,3], and (2) Cutout [9], an image-based augmentation method, our method indeed compares favorably against current state-of-the-art algorithms. Note that our method is still simpler than state-of-the-art image-based methods, e.g., ReMixMatch [3], which additionally incorporates a self-supervised loss, temporal ensembling of model weights, and a tailored data augmentation method (CTAugment [3]).

Table 6: Comparison to other state-of-the-art methods after incorporating other modern SSL techniques (distribution alignment and Cutout). We show results on the CIFAR-10 dataset with varying amounts of labeled samples. Numbers represent error rates across three runs.

B Pseudo-Labeling Accuracy Before and After AugF
In Section 4.4 of the main paper, we analyze other reasons why AugF improves model performance. We conclude that our proposed AugF module also learns to refine input features into a better representation by attending to the prototypes. This feature refinement by AugF provides the training objectives L_con-g (Eq. 5) and L_con-f (Eq. 6) with better pseudo-labels, which may be one of the reasons why our method improves over the image-based baseline by a large margin. In Fig. 6, we can see that the accuracy of pseudo-labels from the features refined by AugF is higher than that of pseudo-labels from features without refinement.

Fig. 6: We monitor the accuracy of pseudo-labeling with feature-based refinement (red curve) and without feature-based refinement (blue curve) during training. The pseudo-labels from the refined features (red) have on average 0.…–….0% higher accuracy.
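The comparison in Fig. 6 boils down to computing pseudo-label accuracy twice on the same samples, once with and once without refinement. Below is a minimal sketch, assuming that a pseudo-label is the argmax of the classifier's logits and that ground-truth labels of the unlabeled samples are available for monitoring only. `classifier` and `refine` are hypothetical callables; `refine` merely stands in for AugF.

```python
from typing import Callable, List, Optional, Sequence

Vec = List[float]


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (first one on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def pseudo_label_accuracy(
    features: List[Vec],
    labels: List[int],
    classifier: Callable[[Vec], Sequence[float]],
    refine: Optional[Callable[[Vec], Vec]] = None,
) -> float:
    """Fraction of pseudo-labels (argmax of classifier logits) that match ground truth.

    If `refine` is given, it is applied to each feature before classification,
    standing in for feature refinement by AugF.
    """
    correct = 0
    for feat, label in zip(features, labels):
        x = refine(feat) if refine is not None else feat
        correct += int(argmax(classifier(x)) == label)
    return correct / len(labels)
```

Logging `pseudo_label_accuracy(..., refine=augf) - pseudo_label_accuracy(...)` at regular intervals during training would produce the gap between the red and blue curves.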
C More Analysis on Prototypes
We test the sensitivity of our method to the hyper-parameters p_k (number of prototypes per class) and I_p (the interval at which a new set of prototypes is extracted). The analysis is conducted on a held-out validation set of the CIFAR-10 dataset with 250 labels. As shown in Table 7, the final results are stable across different values of p_k. We choose p_k = 20 in our method, as it performs slightly better than the other values and has slightly lower variance. In Table 8, we can see that the final results are also stable across different values of I_p. Therefore, for simplicity, we extract prototypes every epoch, which is approximately the same as I_p = 400.

Table 7: Sensitivity analysis for p_k. Numbers represent error rates over three runs.
p_k = 1      p_k = 5      p_k = 10     p_k = 20
8.14 ± …     … ± …        … ± …        … ± …

Table 8: Sensitivity analysis for I_p. Numbers represent error rates over three runs.
I_p = 200    I_p = 400    I_p = 600    I_p = 800
8.00 ± …     … ± …        … ± …        … ± …

D More Results on the DomainNet Setting
In Section 4.1, we propose a practical setting where the unlabeled data may come from other domains. In Section 4.2, we show results with different values of r_u, the ratio of unlabeled data coming from the shifted domains rather than the target Real domain. In this section, we show additional results for r_u = 0.25 and r_u = 0.75 for both our method and the image-based baseline in Table 9. The results show a trend similar to Table 3 in the main paper: accuracy goes down as r_u goes up. Our method consistently improves over the image-based semi-supervised baseline, and even in the severe case of r_u = 75% it achieves results comparable to the image-based baseline with clean unlabeled data (r_u = 0%).

Table 9: Comparison between the image-based baseline and our proposed feature-based augmentation method on DomainNet with various r_u, the ratio of unlabeled data coming from the shifted domains. For instance, r_u = 25% means that 25% of the unlabeled data come from the shifted domains and 75% come from the same domain as the labeled set. Numbers are error rates across 3 runs (lower is better).
Method (5% labeled samples)                              r_u = 0%    r_u = 25%   r_u = 50%   r_u = 75%
(Semi-supervised) Baseline                               56.63 ± …   … ± …       … ± …       … ± …
Supervised baseline (5% labeled samples, lower bound)    77.25 ± …

E Implementation Details
E.1 Training
We train our model with Stochastic Gradient Descent and Nesterov momentum. As the AugF module relies heavily on the feature representation to compute attention weights, we pre-train the model without AugF for 4 epochs.

We adapt the super-convergence learning rate scheduler [26] to reduce the total number of training iterations. Specifically, in the pre-training stage, the learning rate starts from 4e-4 and linearly increases to 4e-3 in I_p iterations. After the pre-training stage, we add the AugF module, ramp the learning rate up linearly from 4e-3 to 4e-2 in I_c iterations, and then ramp it back down to 4e-3 in another I_c iterations. In the meantime, the momentum ramps down from 0.95 to 0.85 and then ramps back up to 0.95. Finally, in the convergence stage, the learning rate ramps further down from 4e-3 to 4e-6 in I_e iterations.

We follow the guidelines in [26] to set these parameters without aggressive tuning, and set I_p = 3k, I_c = 75k, and I_e = 30k. As the DomainNet setting has more training samples, we increase these values to I_p = 4k, I_c = 100k, and I_e = 40k without tuning. We only tune the peak learning rate to be 4e-4 on a held-out validation set of CIFAR-10 with 250 labels.

E.2 Hyper-parameters
All hyper-parameters are tuned on a held-out validation set of CIFAR-10 with 250 labels. These hyper-parameters are shared across all settings and experiments without further tuning. Since our method is built upon the image-based baseline, we fix the baseline's hyper-parameters to the values in the original papers or select a reasonable value without tuning.

Table 10: Hyper-parameters and their meanings.
Hyper-parameter   Description                                                Value
p_k               Number of prototypes per class                             20
I_p               Interval at which a new set of prototypes is extracted     1 epoch
a_h               Number of attention heads in AugF                          …
λ_g               Loss weight for L_con-g                                    …
λ_f               Loss weight for L_con-f                                    …
b_l               Batch size for labeled data                                64
b_u               Batch size for unlabeled data                              128
wd                Weight decay                                               2e-4
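The learning-rate and momentum schedule described in Section E.1 can be summarized as a piecewise-linear function of the global iteration. The sketch below rests on two assumptions: momentum is held at 0.95 during pre-training and the convergence stage (Section E.1 only specifies it for the two ramp phases), and the pre-training length, which Section E.1 calls I_p, is named `pretrain_iters` here to avoid confusion with the prototype interval I_p in Table 10. Defaults follow the CIFAR settings (3k, 75k, 30k); the function name is hypothetical.

```python
from typing import Tuple


def _lerp(start: float, end: float, t: float) -> float:
    """Linear interpolation from `start` (t = 0) to `end` (t = 1)."""
    return start + (end - start) * t


def lr_momentum(step: int, pretrain_iters: int = 3_000, cycle_iters: int = 75_000,
                end_iters: int = 30_000) -> Tuple[float, float]:
    """(learning rate, momentum) at a global training step.

    Phases: pre-training 4e-4 -> 4e-3; ramp up 4e-3 -> 4e-2 (momentum 0.95 -> 0.85);
    ramp down 4e-2 -> 4e-3 (momentum 0.85 -> 0.95); convergence 4e-3 -> 4e-6.
    Momentum outside the two ramp phases is assumed to stay at 0.95.
    """
    if step < pretrain_iters:
        return _lerp(4e-4, 4e-3, step / pretrain_iters), 0.95
    step -= pretrain_iters
    if step < cycle_iters:
        t = step / cycle_iters
        return _lerp(4e-3, 4e-2, t), _lerp(0.95, 0.85, t)
    step -= cycle_iters
    if step < cycle_iters:
        t = step / cycle_iters
        return _lerp(4e-2, 4e-3, t), _lerp(0.85, 0.95, t)
    step -= cycle_iters
    # Convergence stage; clamp so steps past the end keep the final learning rate.
    t = min(step / end_iters, 1.0)
    return _lerp(4e-3, 4e-6, t), 0.95
```

For DomainNet, the call would be `lr_momentum(step, 4_000, 100_000, 40_000)`. The returned values would be written into the optimizer's learning rate and momentum at every iteration.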
E.3 Data Augmentation Operations
We use the same set of image transformations as RandAugment [8]. There are two parameters in RandAugment: (1) N, the number of operations applied, and (2) M, the maximal magnitude of the applied augmentation. We use N = 2 as in RandAugment, and set M to its maximum value without tuning. Note that the magnitude is randomly sampled from [−M, M].

References
1. Arazo, E., Ortego, D., Albert, P., O'Connor, N.E., McGuinness, K.: Pseudo-labeling and confirmation bias in deep semi-supervised learning. arXiv preprint arXiv:1908.02983 (2019)
2. Athiwaratkun, B., Finzi, M., Izmailov, P., Wilson, A.G.: Improving consistency-based semi-supervised learning with weight averaging. arXiv preprint arXiv:1806.05594 (2018)
3. Berthelot, D., Carlini, N., Cubuk, E.D., Kurakin, A., Sohn, K., Zhang, H., Raffel, C.: ReMixMatch: Semi-supervised learning with distribution alignment and augmentation anchoring. In: Proc. International Conference on Learning Representations (ICLR) (2020)
4. Berthelot, D., Carlini, N., Goodfellow, I., Papernot, N., Oliver, A., Raffel, C.A.: MixMatch: A holistic approach to semi-supervised learning. In: Advances in Neural Information Processing Systems. pp. 5050–5060 (2019)
5. Chen, Y., Zhu, X., Gong, S.: Semi-supervised deep learning with memory. In: Proceedings of the European Conference on Computer Vision (ECCV). pp. 268–283 (2018)
6. Cubuk, E.D., Zoph, B., Mane, D., Vasudevan, V., Le, Q.V.: AutoAugment: Learning augmentation policies from data. arXiv preprint arXiv:1805.09501 (2018)
7. Cubuk, E.D., Zoph, B., Mane, D., Vasudevan, V., Le, Q.V.: AutoAugment: Learning augmentation strategies from data. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 113–123 (2019)
8. Cubuk, E.D., Zoph, B., Shlens, J., Le, Q.V.: RandAugment: Practical automated data augmentation with a reduced search space. arXiv preprint arXiv:1909.13719 (2019)
9. DeVries, T., Taylor, G.W.: Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552 (2017)
10. Guo, H., Mao, Y., Zhang, R.: Mixup as locally linear out-of-manifold regularization. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 33, pp. 3714–3722 (2019)
11. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722 (2019)
12. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 770–778 (2016)
13. Hendrycks, D., Mu, N., Cubuk, E.D., Zoph, B., Gilmer, J., Lakshminarayanan, B.: AugMix: A simple data processing method to improve robustness and uncertainty. Proceedings of the International Conference on Learning Representations (ICLR) (2020)
14. Iscen, A., Tolias, G., Avrithis, Y., Chum, O.: Label propagation for deep semi-supervised learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 5070–5079 (2019)
15. Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images. Tech. rep., Citeseer (2009)
16. Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. In: Proc. International Conference on Learning Representations (ICLR) (2017)
17. Lee, D.H.: Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In: Workshop on Challenges in Representation Learning, ICML. vol. 3, p. 2 (2013)
18. Luo, Y., Zhu, J., Li, M., Ren, Y., Zhang, B.: Smooth neighbors on teacher graphs for semi-supervised learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 8896–8905 (2018)
19. Miyato, T., Maeda, S.i., Koyama, M., Ishii, S.: Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence (8), 1979–1993 (2018)
20. Netzer, Y., Wang, T., Coates, A., Bissacco, A., Wu, B., Ng, A.Y.: Reading digits in natural images with unsupervised feature learning (2011)
21. Peng, X., Bai, Q., Xia, X., Huang, Z., Saenko, K., Wang, B.: Moment matching for multi-source domain adaptation. In: Proceedings of the IEEE International Conference on Computer Vision. pp. 1406–1415 (2019)
22. Qiao, S., Shen, W., Zhang, Z., Wang, B., Yuille, A.: Deep co-training for semi-supervised image recognition. In: Proceedings of the European Conference on Computer Vision (ECCV). pp. 135–152 (2018)
23. Ravi, S., Larochelle, H.: Optimization as a model for few-shot learning. In: International Conference on Learning Representations (ICLR) (2017)
24. Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., et al.: ImageNet large scale visual recognition challenge. International Journal of Computer Vision (3), 211–252 (2015)
25. Sajjadi, M., Javanmardi, M., Tasdizen, T.: Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In: Advances in Neural Information Processing Systems. pp. 1163–1171 (2016)
26. Smith, L.N., Topin, N.: Super-convergence: Very fast training of neural networks using large learning rates. In: Artificial Intelligence and Machine Learning for Multi-Domain Operations Applications. vol. 11006, p. 1100612. International Society for Optics and Photonics (2019)
27. Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In: Advances in Neural Information Processing Systems. pp. 1195–1204 (2017)
28. Thulasidasan, S., Chennupati, G., Bilmes, J.A., Bhattacharya, T., Michalak, S.: On mixup training: Improved calibration and predictive uncertainty for deep neural networks. In: Advances in Neural Information Processing Systems. pp. 13888–13899 (2019)
29. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. In: Advances in Neural Information Processing Systems. pp. 5998–6008 (2017)
30. Verma, V., Lamb, A., Beckham, C., Courville, A., Mitliagkis, I., Bengio, Y.: Manifold mixup: Encouraging meaningful on-manifold interpolation as a regularizer. stat 1050