Published as a conference paper at ICLR 2022

GRADIENT MATCHING FOR DOMAIN GENERALIZATION

Yuge Shi, University of Oxford (yshi@robots.ox.ac.uk); Jeffrey Seely, Meta Reality Labs (jseely@fb.com); Philip H.S. Torr, University of Oxford (philip.torr@eng.ox.ac.uk); N. Siddharth, The University of Edinburgh & The Alan Turing Institute (n.siddharth@ed.ac.uk); Awni Hannun, Facebook AI Research (awni@fb.com); Nicolas Usunier, Facebook AI Research (usunier@fb.com); Gabriel Synnaeve, Facebook AI Research (gab@fb.com). (Work done during an internship at Facebook AI Research. Now at Zoom.)

ABSTRACT

Machine learning systems typically assume that the distributions of training and test sets match closely. However, a critical requirement of such systems in the real world is their ability to generalize to unseen domains. Here, we propose an inter-domain gradient matching objective that targets domain generalization by maximizing the inner product between gradients from different domains. Since direct optimization of the gradient inner product can be computationally prohibitive (it requires computation of second-order derivatives), we derive a simpler first-order algorithm named Fish that approximates its optimization. We perform experiments on the WILDS benchmark, which captures distribution shift in the real world, as well as the DOMAINBED benchmark that focuses more on synthetic-to-real transfer. Our method produces competitive results on both benchmarks, demonstrating its effectiveness across a wide range of domain generalization tasks. Code is available at https://github.com/YugeTen/fish.

1 INTRODUCTION

Figure 1: Isometric projection of training with ERM (blue) vs. our objective (dark blue), using data from Figure 2. The figure contrasts the trajectory of parameters during ERM training with that during IDGM training, showing the gradients of the loss for domain 1 and domain 2 and their inner product at the initial point.

The goal of domain generalization is to train models that perform well on unseen, out-of-distribution data, which is crucial in practice for model deployment in the wild. This seemingly difficult task is made possible by the presence of multiple distributions/domains at train time. As we have seen in past work (Arjovsky et al., 2019; Gulrajani and Lopez-Paz, 2020; Ganin et al., 2016), a key aspect of domain generalization is to learn from features that remain invariant across multiple domains, while ignoring those that are spuriously correlated to label information (as defined in Torralba and Efros (2011); Stock and Cisse (2017)). Consider, for example, a model that is built to distinguish between cows and camels using photos collected in nature under different climates. Since CNNs are known to have a bias towards texture (Geirhos et al., 2018; Brendel and Bethge, 2019), if we simply try to minimize the average loss across different domains, the classifier is prone to spuriously correlating cows with grass and camels with desert, and predicting the species using only the background. Such a classifier can be rendered useless when the animals are placed indoors or in a zoo. However, if the model could recognize that while the landscapes change with climate, the biological characteristics of the animals (e.g. humps, neck lengths) remain invariant, and use those features to determine the species, we have a much better chance at generalizing to unseen domains.
Similar intuitions have already motivated many approaches that consider learning invariances across domains as the main challenge of domain generalization. Typically, many of these works focus on learning invariant representations directly by removing the domain information (Ganin et al., 2016; Sun and Saenko, 2016; Li et al., 2018). In this work, we propose an inter-domain gradient matching (IDGM) objective. Instead of learning invariant features by matching the distributions of representations from different domains, our approach does so by encouraging consistent gradient directions across domains. Specifically, our IDGM objective augments the loss with an auxiliary term that maximizes the gradient inner product between domains, which encourages the alignment of the domain-specific gradients. By simultaneously minimizing the loss and matching the gradients, IDGM encourages the optimization paths to be the same for all domains, favouring invariant predictions.

Figure 1 illustrates a motivating example described in Section 3.2: we are given 2 domains, each containing one invariant feature (orange cross) and one spurious feature (yellow and red cross). While empirical risk minimization (ERM) minimizes the average loss between these domains at the cost of learning spurious features only, IDGM aligns the gradient directions and is therefore able to focus on the invariant feature.

While the IDGM objective achieves the desirable learning dynamic in theory, naive optimization of the objective by gradient descent is computationally costly due to the second-order derivatives. Leveraging the theoretical analysis of Reptile, a meta-learning algorithm (Nichol et al., 2018), we propose to approximate the gradients of IDGM using a simple first-order algorithm, which we name Fish. Fish is simple to implement, computationally efficient and, as we show in our experiments, functionally similar to direct optimization of IDGM.

Our contribution is a simple but effective algorithm for domain generalization, which exhibits state-of-the-art performance on 13 datasets from the recent domain generalization benchmarks WILDS (Koh et al., 2020) and DOMAINBED (Gulrajani and Lopez-Paz, 2020). The strong performance of our method on a variety of datasets demonstrates that it is broadly applicable to different applications/subgenres of domain generalization tasks. We also perform a detailed analysis in Section 4.4 to explain the effectiveness of our proposed algorithm.

2 RELATED WORK

Domain Generalization. In domain generalization, the training data is sampled from one or many source domains, while the test data is sampled from a new target domain. We will now discuss the five main families of approaches to domain generalization:

1. Distributional Robustness (DRO): DRO approaches minimize the worst-case loss over a set of data distributions constructed from the training domains. Rojas-Carulla et al. (2015) proposed DRO to address covariate shift (Gretton et al., 2009a;b), where P(Y|X) remains constant across domains but P(X) changes. Later work also studied subpopulation shift, where the train and test distributions are mixtures of the same domains, but the mixture weights change between train and test (Hu et al., 2018; Sagawa et al., 2019);

2. Domain-invariant representation learning: This family of approaches to domain generalization aims at learning high-level features that make domains statistically indistinguishable. Prediction is then based on these features only.
The principle is motivated by a generalization error bound for unsupervised domain adaptation (Ben-David et al., 2010; Ganin et al., 2016), but the approach readily applies to domain generalization (Gulrajani and Lopez-Paz, 2020; Koh et al., 2020). Algorithms include penalising the domain-predictive power of the model (Ganin et al., 2016; Wang et al., 2019; Huang et al., 2020), aligning domains through a contrastive loss (Motiian et al., 2017), matching the mean and variance of feature distributions across domains (Sun and Saenko, 2016), learning useful representations by solving Jigsaw puzzles (Carlucci et al., 2019), using the maximum mean discrepancy to match the feature distributions (Li et al., 2018b), or introducing training constraints across domains using a mixup formulation (Yan et al., 2020).

3. Invariant Risk Minimization (IRM): IRM was proposed by Arjovsky et al. (2019); it learns an intermediate representation such that the optimal classifiers (on top of this representation) of all domains are the same. The motivation is to exploit invariant causal effects between domains while reducing the effect of domain-specific spurious correlations. From an optimization perspective, when IRM reaches its optimum, all the gradients (for the linear classifier) have to be zero. This is why IRM's solution won't deviate from ERM when ERM is optimal for every domain, which is not the case for our proposed IDGM objective due to the gradient inner product term.

4. Data augmentation: More recently, approaches that simulate unseen domains through specific types of data augmentation/normalization have been gaining traction. This includes work such as Zhou et al. (2020); Volpi and Murino (2019); Ilse et al. (2021), as well as Seo et al. (2019), which utilises ensemble learning.

5. Gradient alignment: Two concurrent works, Koyama and Yamaguchi (2021) and Parascandolo et al. (2021), utilise a similar gradient-alignment principle for domain generalization. Koyama and Yamaguchi (2021) propose IGA, which learns invariant features by minimizing the variance of inter-domain gradients. The key difference between IGA and our objective is that IGA is completely identical to ERM when ERM is the optimal solution on every training domain, since the variances of the gradients will be zero. While they achieve the best performance on the training set, both IGA and ERM could, in some cases, completely fail when generalizing to unseen domains (see Section 3.2 for such an example). Our method, on the contrary, biases towards non-ERM solutions as long as the gradients are aligned, and is therefore able to avoid this issue. Parascandolo et al. (2021), on the other hand, propose to mask out the gradients that have opposite signs for different domains. Unlike their work, which prunes gradients that are inconsistent, our approach actively encourages gradients from different domains to be consistent by maximizing the gradient inner product. Additionally, in Lopez-Paz and Ranzato (2017) we also see the application of gradient alignment; however, in this case it is applied in the continual learning setting to determine whether a gradient update will increase the loss of previous tasks.

Apart from these algorithms that are tailored for domain generalization, a well-studied baseline in this area is ERM, which simply minimizes the average loss over training domains.
Using vanilla ERM is theoretically unfounded (Hashimoto et al., 2018; Blodgett et al., 2016; Tatman, 2017), since ERM is guaranteed to work only when train and test distributions match. Nonetheless, recent benchmarks suggest that ERM obtains strong performance in practice, in many cases surpassing domain generalization algorithms (Gulrajani and Lopez-Paz, 2020; Koh et al., 2020). Our goal is to fill this gap, using an algorithm significantly simpler than previous approaches.

Connections to meta-learning. There are close connections between meta-learning (Thrun and Pratt, 1998) and (multi-source) domain adaptation. In fact, there are a few works in domain generalization that are inspired by meta-learning principles, such as Li et al. (2018a); Balaji et al. (2018); Li et al. (2019); Dou et al. (2019). Specifically, Li et al. (2020) also propose to adapt Reptile for domain generalization tasks; however, they study their method under the sequential learning setting, whereas our method can be trained on all domains and therefore learns faster, especially when the number of domains is large. In Ren et al. (2018), we also see the gradient inner product leveraged in meta-learning, where it is used to determine the importance weight of training examples. We discuss the connection between our proposed algorithm and meta-learning in more detail in Appendix A.1. Note that our proposed algorithm Fish is similar to the Mean Teacher method (Tarvainen and Valpola, 2017), where a teacher model (equivalent to $\theta$ in Algorithm 1) is computed using a moving average of the student model (equivalent to $\tilde{\theta}$ in Algorithm 1).

3 METHODOLOGY

Consider a training dataset $\mathcal{D}_{tr}$ consisting of $S$ domains $\mathcal{D}_{tr} = \{\mathcal{D}_1, \ldots, \mathcal{D}_S\}$, where each domain $s$ is characterized by a dataset $\mathcal{D}_s := \{(x_i^s, y_i^s)\}_{i=1}^{n_s}$ containing data drawn i.i.d. from some probability distribution. Also consider a test dataset $\mathcal{D}_{te}$ consisting of $T$ domains $\mathcal{D}_{te} = \{\mathcal{D}_{S+1}, \ldots, \mathcal{D}_{S+T}\}$, where $\mathcal{D}_{tr} \cap \mathcal{D}_{te} = \emptyset$. The goal of domain generalization is to train a model with weights $\theta$ that generalizes well on the test dataset $\mathcal{D}_{te}$, i.e. to solve

$$\arg\min_{\theta} \; \mathbb{E}_{\mathcal{D} \sim \mathcal{D}_{te}} \, \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ \ell((x, y); \theta) \right], \qquad (1)$$

where $\ell((x, y); \theta)$ is the loss of model $\theta$ evaluated on $(x, y)$. A naive approach is to apply ERM, which simply minimizes the average loss on $\mathcal{D}_{tr}$, ignoring the discrepancy between train and test domains:

$$\mathcal{L}_{erm}(\mathcal{D}_{tr}; \theta) = \mathbb{E}_{\mathcal{D} \sim \mathcal{D}_{tr}} \, \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ \ell((x, y); \theta) \right]. \qquad (2)$$

The ERM objective does not exploit the invariance across different domains in $\mathcal{D}_{tr}$ and could perform arbitrarily poorly on test data. We demonstrate this effect with the following simple linear example.

Figure 2: All 3 domains (rows) consist of 3 types of inputs (columns): 1) $x_1$, left: makes up 50% of each domain, the label is always 0, and $x_1$ is always [0, 0, 0, 0]; 2) $x_2$, middle: makes up 40% of each domain, the label is always 1, and $x_2$ changes for each domain; 3) $x_3$, right: makes up 10% of each domain, labels are randomly assigned with 30% of $y = 1$ and 70% of $y = 0$, and $x_3$ is always [1, 0, 0, 0].

3.2 THE PITFALL OF ERM: A LINEAR EXAMPLE

Consider a binary classification setup where data $(x, y) \in \mathbb{B}^4 \times \mathbb{B}$, and a data instance is denoted $x = [f_1, f_2, f_3, f_4]$ with label $y$. The train domains are $\{\mathcal{D}_1, \mathcal{D}_2\}$, and the test domain is $\mathcal{D}_3$. The goal is to learn a linear model $Wx + b = y$, $W \in \mathbb{R}^4$, $b \in \mathbb{R}$ on the train data, such that the error on the test domain is minimized. The setup and dataset of this example are illustrated in Figure 2.
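To make this setup concrete, the following is a minimal NumPy sketch (not the authors' code) that generates data following the Figure 2 recipe. The specific per-domain $x_2$ patterns, sample counts and seeds are illustrative assumptions, chosen so that $f_1$ behaves as the invariant feature while the remaining coordinates are spurious.

```python
# Sketch of the Figure 2 construction (hypothetical helper, illustrative values).
import numpy as np

def make_domain(x2_pattern, n=1000, seed=0):
    rng = np.random.default_rng(seed)
    n1, n2 = n // 2, int(0.4 * n)      # 50% x1 points, 40% x2 points
    n3 = n - n1 - n2                   # 10% x3 points
    x = np.concatenate([
        np.zeros((n1, 4)),                    # x1 = [0,0,0,0], y = 0
        np.tile(x2_pattern, (n2, 1)),         # x2 is domain-specific, y = 1
        np.tile([1, 0, 0, 0], (n3, 1)),       # x3 = [1,0,0,0], y random (30% ones)
    ]).astype(np.float32)
    y = np.concatenate([
        np.zeros(n1), np.ones(n2), (rng.random(n3) < 0.3).astype(float),
    ]).astype(np.float32)
    return x, y

# Assumed domain-specific x2 patterns: f1 (the first coordinate) stays predictive
# of y in every domain, while f2/f3/f4 are only predictive within single domains.
d1 = make_domain([1, 1, 0, 0], seed=1)   # train domain D1
d2 = make_domain([1, 0, 1, 0], seed=2)   # train domain D2
d3 = make_domain([1, 0, 0, 1], seed=3)   # test domain D3
```

With these illustrative patterns, a per-domain classifier that keys on $f_2$ reaches roughly 97% accuracy on $\mathcal{D}_1$, whereas one that keys on $f_1$ reaches roughly 93%, matching the accuracies discussed below.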
As we can see in Figure 2, $f_1$ is the invariant feature in this dataset, since the correlation between $f_1$ and $y$ is stable across different domains. The relationships between $y$ and $f_2$, $f_3$ and $f_4$ change across $\mathcal{D}_1$, $\mathcal{D}_2$ and $\mathcal{D}_3$, making them the spurious features. Importantly, if we consider one domain only, the spurious features $f_2$, $f_3$ and $f_4$ are a more accurate indicator of the label than the invariant feature $f_1$. For instance, using $f_2$ to predict $y$ can give 97% accuracy on $\mathcal{D}_1$, while using $f_1$ only achieves 93% accuracy.

Table 1: Performance comparison on the linear dataset.

| Method | train acc. | test acc. | W | b |
| --- | --- | --- | --- | --- |
| ERM | 97% | 57% | [2.8, 3.3, 3.3, 0.0] | 2.7 |
| IDGM | 93% | 93% | [0.4, 0.2, 0.2, 0.0] | 0.4 |
| Fish | 93% | 93% | [0.4, 0.2, 0.2, 0.0] | 0.4 |

The performance of ERM on this simple example is shown in Table 1 (first row). From the trained parameters $W$ and $b$, we see that the model places most of its weight on the spurious features $f_2$ and $f_3$. While this achieves the highest train accuracy (97%), the model cannot generalize to unseen domains and performs poorly in terms of test accuracy (57%).

3.3 INTER-DOMAIN GRADIENT MATCHING (IDGM)

To mitigate the problem with ERM, we need an objective that learns from features that are invariant across domains. Let us consider the case where the train dataset consists of $S = 2$ domains, $\mathcal{D}_{tr} = \{\mathcal{D}_1, \mathcal{D}_2\}$. Given model $\theta$ and loss function $\ell$, the expected gradients for data in the two domains are expressed as

$$G_1 = \mathbb{E}_{\mathcal{D}_1}\left[\frac{\partial \ell((x, y); \theta)}{\partial \theta}\right], \qquad G_2 = \mathbb{E}_{\mathcal{D}_2}\left[\frac{\partial \ell((x, y); \theta)}{\partial \theta}\right].$$

The direction, and by extension the inner product, of these gradients are of particular importance to our goal of learning invariant features. If $G_1$ and $G_2$ point in a similar direction, i.e. $G_1 \cdot G_2 > 0$, taking a gradient step along $G_1$ or $G_2$ improves the model's performance on both domains, indicating that the features learned by either gradient step are invariant across $\{\mathcal{D}_1, \mathcal{D}_2\}$. This invariance cannot be guaranteed if $G_1$ and $G_2$ are pointing in opposite directions, i.e. $G_1 \cdot G_2 \leq 0$.

To exploit this observation, we propose to maximize the gradient inner product (GIP) to align the gradient direction across domains. The intended effect is to find weights such that the input-output correspondence is as close as possible across domains. We name our objective inter-domain gradient matching (IDGM); it is formed by subtracting the inner product of gradients between domains, $\hat{G}$, from the original ERM objective. For the general case where $S \geq 2$, we can write

$$\mathcal{L}_{idgm} = \mathcal{L}_{erm}(\mathcal{D}_{tr}; \theta) - \gamma \underbrace{\frac{2}{S(S-1)} \sum_{i, j \in S,\, i \neq j} G_i \cdot G_j}_{\text{GIP, denoted as } \hat{G}} \qquad (4)$$

where $\gamma$ is the scaling term for $\hat{G}$. Note that GIP can be computed in linear time as $\hat{G} = \|\sum_i G_i\|^2 - \sum_i \|G_i\|^2$ (ignoring the constant factor). We can also compute stochastic estimates of Equation (4) by replacing the expectations over the entire dataset with minibatches.

We test this objective on our simple linear dataset, and report results in the second row of Table 1. Note that to avoid exploding gradients we use the normalized GIP during training. The model has lower training accuracy compared to ERM (93%); however, its accuracy remains the same on the test set, much higher than ERM. The trained weights $W$ reveal that the model assigns the largest weight to the invariant feature $f_1$, which is desirable. The visualization in Figure 1 also confirms that by maximizing the gradient inner product, IDGM is able to focus on the feature that is common between domains, yielding better generalization performance than ERM.
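For reference, the linear-time identity above is straightforward to implement. The sketch below (assumed helper names, not the released code) computes the pairwise GIP from a list of flattened per-domain gradients.

```python
# Sketch: sum_{i != j} G_i . G_j = ||sum_i G_i||^2 - sum_i ||G_i||^2,
# computed in one pass over the per-domain gradients.
import torch

def gradient_inner_product(per_domain_grads):
    # per_domain_grads: list of S flattened gradient tensors, one per domain
    g = torch.stack(per_domain_grads)                    # shape (S, num_params)
    pair_sum = g.sum(dim=0).pow(2).sum() - g.pow(2).sum()
    return pair_sum        # scale by 2/(S(S-1)) as in Equation (4) if desired
```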
3.4 OPTIMIZING IDGM WITH FISH

The proposed IDGM objective, although effective, requires computing the second-order derivative of the model's parameters due to the gradient inner product term, which can be computationally prohibitive. To mitigate this, we propose a first-order algorithm named Fish[1] that approximates the optimization of IDGM with inner-loop updates. In Algorithm 1 we present Fish. As a comparison, we also present direct optimization of IDGM using SGD in Algorithm 2.

Algorithm 1 Fish.
1: for iterations = 1, 2, ... do
2:   $\tilde{\theta} \leftarrow \theta$
3:   for $\mathcal{D}_i \in$ permute($\{\mathcal{D}_1, \mathcal{D}_2, \ldots, \mathcal{D}_S\}$) do
4:     Sample batch $d_i \sim \mathcal{D}_i$
5:     $\tilde{g}_i = \mathbb{E}_{d_i}\left[\partial \ell((x, y); \tilde{\theta}) / \partial \tilde{\theta}\right]$   // Grad wrt $\tilde{\theta}$
6:     Update $\tilde{\theta} \leftarrow \tilde{\theta} - \alpha \tilde{g}_i$
7:   end for
8:   Update $\theta \leftarrow \theta + \epsilon(\tilde{\theta} - \theta)$
9: end for

Algorithm 2 Direct optimization of IDGM.
1: for iterations = 1, 2, ... do
2:   for $\mathcal{D}_i \in$ permute($\{\mathcal{D}_1, \mathcal{D}_2, \ldots, \mathcal{D}_S\}$) do
3:     Sample batch $d_i \sim \mathcal{D}_i$
4:     $g_i = \mathbb{E}_{d_i}\left[\partial \ell((x, y); \theta) / \partial \theta\right]$   // Grad wrt $\theta$
5:   end for
6:   $\bar{g} = \frac{1}{S}\sum_{s=1}^{S} g_s$,  $\hat{g} = \sum_{i, j \in S,\, i \neq j} g_i \cdot g_j$   // GIP (batch)
7:   Update $\theta \leftarrow \theta - \epsilon\left(\bar{g} - \gamma\, \partial \hat{g} / \partial \theta\right)$
8: end for

Fish performs $S$ inner-loop update steps (lines 3-7) with learning rate $\alpha$ on a clone of the original model, $\tilde{\theta}$, and each update uses a minibatch $d_i$ from the domain selected at step $i$. Subsequently, $\theta$ is updated by a weighted difference between the cloned model and the original model, $\epsilon(\tilde{\theta} - \theta)$. To see why Fish is an approximation to directly optimizing IDGM, we can perform a Taylor-series expansion on its update in line 8 of Algorithm 1. Doing so reveals two leading terms: 1) $\bar{g}$: the averaged gradients over the inner loop's minibatches (effectively the ERM gradient); 2) $\partial \hat{g} / \partial \theta$: the gradient of the minibatch version of GIP. Observing line 7 of Algorithm 2, we see that $\bar{g}$ and $\hat{g}$ are exactly the two gradient components used in direct optimization of IDGM. Therefore, Fish implicitly optimizes IDGM by construction (up to a constant factor), avoiding the computation of the second-order derivative $\partial \hat{g} / \partial \theta$. We present this more formally for the full gradient $\hat{G}$ in Theorem 3.1.

Theorem 3.1 Given a twice-differentiable model with parameters $\theta$ and objective $\ell$, let us define the following:

$$G_f = \mathbb{E}\left[(\theta - \tilde{\theta})\right] - \alpha S \bar{G} \qquad \text{(Fish update minus } \alpha S \times \text{ ERM grad)},$$
$$G_g = -\partial \hat{G} / \partial \theta \qquad \text{(grad of } \max_\theta \hat{G}\text{)},$$

where $\bar{G} = \frac{1}{S}\sum_{s=1}^{S} G_s$ is the full gradient of ERM. Then we have

$$\lim_{\alpha \to 0} \frac{G_f \cdot G_g}{\|G_f\|\,\|G_g\|} = 1.$$

Note that the expectation in $G_f$ is over the sampling of domains and minibatches. Theorem 3.1 indicates that when $\alpha$ is sufficiently small, if we remove the scaled ERM gradient component $\bar{G}$ from Fish's update, we are left with a term $G_f$ that points in a similar direction to the gradient of maximizing the GIP term in IDGM, which was originally second-order. Note that this approximation comes at the cost of losing direct control over the GIP scaling $\gamma$; we therefore also derive a smoothed version of Fish that recovers this scaling term, although we find that changing the value of $\gamma$ does not make much difference empirically. See Appendix B for more details.

The proof of Theorem 3.1 can be found in Appendix A. We follow the analysis from Nichol et al. (2018), which proposes Reptile for model-agnostic meta-learning (MAML), and where the relationship between inner-loop updates and maximization of the gradient inner product was first highlighted. Nichol et al. (2018) found the GIP term in their algorithm to be over minibatches from the same domain, which promoted within-task generalization; in Fish we construct the inner loop using minibatches from different domains, so it instead encourages across-domain generalization.

[1] Following the convention of naming this style of algorithms after classes of vertebrates (animals with backbones).
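As an illustration of Algorithm 1, here is a minimal PyTorch sketch of one Fish meta-step. It is a simplified sketch rather than the released implementation; `model`, `loss_fn`, the per-domain batch iterators in `domain_loaders`, and the learning rates `alpha` and `epsilon` are assumed to be supplied by the user.

```python
# Sketch of one Fish meta-update (Algorithm 1), written against standard PyTorch.
import copy
import random
import torch

def fish_step(model, domain_loaders, loss_fn, alpha=0.01, epsilon=0.05):
    inner_model = copy.deepcopy(model)                                 # line 2: clone theta
    inner_opt = torch.optim.SGD(inner_model.parameters(), lr=alpha)
    for loader in random.sample(domain_loaders, len(domain_loaders)):  # line 3: permuted domains
        x, y = next(loader)                                            # line 4: minibatch from this domain
        inner_opt.zero_grad()
        loss_fn(inner_model(x), y).backward()                          # line 5: grad wrt cloned weights
        inner_opt.step()                                               # line 6: inner SGD step
    with torch.no_grad():                                              # line 8: theta <- theta + eps*(theta~ - theta)
        for p, p_inner in zip(model.parameters(), inner_model.parameters()):
            p.add_(epsilon * (p_inner - p))
```

Note that only first-order gradients are taken here; the GIP term is never formed explicitly, which is what makes Fish cheap relative to Algorithm 2.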
We compare the two algorithms in further detail in Appendix A.1. We also train Fish on our simple linear dataset, with results in Table 1, and see that it performs similarly to IDGM: the model assigns the most weight to the invariant feature $f_1$, and achieves 93% accuracy on both the train and test datasets.

4 EXPERIMENTS

4.1 CDSPRITES-N

Dataset. We propose a simple shape-color dataset CDSPRITES-N based on the DSPRITES dataset (Matthey et al., 2017), which contains a collection of white 2D sprites of different shapes, scales, rotations and positions. CDSPRITES-N contains N domains. The goal is to classify the shape of the sprites, and there is a deterministic shape-color matching that is specific to each domain. This way we have shape as the invariant feature and color as the spurious feature. See Figure 3 for an illustration.

To construct the train split of CDSPRITES-N, we take a subset of DSPRITES that contains only 2 shapes (square and oval). We make N replicas of this subset and assign 2 colors to each, with every color corresponding to one shape (e.g. in the yellow block of Figure 3a: pink squares, purple ovals). For the test split, we create another replica of the DSPRITES subset, and randomly assign one of the 2N colors in the training set to each shape in the test set. We design this dataset with CNNs' texture bias in mind (Geirhos et al., 2018; Brendel and Bethge, 2019). If the value of N is small enough, the model can simply memorize the N colors that correspond to each shape, and make predictions solely based on colors, resulting in poor performance on the test set where color and shape are no longer correlated. Our dataset allows for precise control over the features that remain stable across domains and the features that change as domains change; we can also change the number of domains N easily, making it possible to examine the effect N has on domain generalization performance.

Results. We train the same model using three different objectives (Fish, direct optimization of IDGM, and ERM) on this dataset, with the number of domains N ranging from 5 to 50. Again, for direct optimization of IDGM, we use the normalized gradient inner product to avoid exploding gradients.

Figure 3: CDSPRITES-N visualization. Each 3x3 grid (e.g. the yellow square) is one domain.
Figure 4: Performance of Fish, IDGM and ERM on CDSPRITES-N, with N in [5, 50].

We plot the average train and test accuracy for each objective over 5 runs against the number of domains N in Figure 4. We can see that the train accuracy is always 100% for all methods regardless of N (Figure 4a), while the test performance varies: Figure 4b shows that direct optimization of IDGM (red) and Fish (blue) obtain the best performance, with the test accuracy rising to over 90% when N ≥ 10 and near 100% when N ≥ 20. The predictions of ERM (yellow), on the other hand, remain nearly random on the test set up until N = 20, and reach 95% accuracy only for N ≥ 40. This experiment confirms the following: 1) the proposed IDGM objective has much stronger domain generalization capabilities compared to ERM; 2) Fish is an effective approximation of IDGM, with similar performance to its direct optimization. We also plot the gradient inner product progression of Fish vs.
ERM during training in Figure 9a, showing clearly that Fish does improve the gradient inner product across domains while ERM does not; 3) we also observe during training that Fish is about 10 times faster than directly optimizing IDGM, demonstrating its computational efficiency.

4.2 WILDS

Datasets. We evaluate our model on the WILDS benchmark (Koh et al., 2020), which contains multiple datasets that capture real-world distribution shifts across a diverse range of modalities. We report experimental results on 6 challenging datasets in WILDS, and find Fish to outperform all baselines on most tasks. A summary of the WILDS datasets can be found in Appendix C. For hyperparameters including learning rate, batch size, choice of optimizer and model architecture, we follow the exact configuration reported in the WILDS benchmark. Importantly, we also use the same model selection strategy used in WILDS to ensure a fair comparison. See details in Appendix D.

Table 2: Results on the WILDS benchmark.

| Method | POVERTYMAP Worst-U/R Pearson r | CAMELYON17 Avg. acc. (%) | FMOW Worst acc. (%) | CIVILCOMMENTS Worst acc. (%) | IWILDCAM Macro F1 | AMAZON 10th per. acc. (%) |
| --- | --- | --- | --- | --- | --- | --- |
| Fish | 0.30 (± 1e-2) | 74.7 (± 7.1) | 34.6 (± 0.18) | 75.3 (± 0.6) | 22.0 (± 0.0) | 53.3 (± 0.0) |
| IRM | 0.43 (± 7e-2) | 64.2 (± 8.1) | 30.0 (± 1.37) | 66.3 (± 2.1) | 15.1 (± 4.9) | 52.4 (± 0.8) |
| Coral | 0.44 (± 6e-2) | 59.5 (± 7.7) | 31.7 (± 1.24) | 65.6 (± 1.3) | 32.8 (± 0.1) | 52.9 (± 0.8) |
| Reweighted | - | - | - | 69.2 (± 0.9) | - | 52.0 (± 0.0) |
| Group DRO | 0.39 (± 6e-2) | 68.4 (± 7.3) | 30.8 (± 0.81) | 70.0 (± 2.0) | 23.9 (± 2.1) | 53.3 (± 0.0) |
| ERM | 0.45 (± 6e-2) | 70.3 (± 6.4) | 32.3 (± 1.25) | 56.0 (± 3.6) | 31.0 (± 1.3) | 53.8 (± 0.8) |
| ERM (ours) | 0.29 (± 1e-2) | 70.5 (± 12.1) | 30.9 (± 1.53) | 58.1 (± 1.7) | 25.1 (± 0.2) | 53.3 (± 0.8) |

Results. See a summary of results in Table 2, where we use the metrics recommended in WILDS for each dataset. Again, following practice in WILDS, all results are reported over 3 random seed runs, apart from CAMELYON17, which uses 10 random seeds, and CIVILCOMMENTS, which uses 5. We include additional results as well as an in-depth discussion of each dataset in Appendix C, and ablation studies on Fish's hyperparameters in Appendix F and Appendix E. We make the following observations:

1. Strong performance across datasets: Considering results on all 6 datasets, Fish is the best performing algorithm on WILDS. It significantly outperforms baselines on 3 datasets and achieves a similar level of performance to the best method on the other 3 (AMAZON and IWILDCAM). Fish's strong performance on different types of data and architectures such as RESNET (He et al., 2016), DENSENET (Huang et al., 2017) and DISTILBERT (Sanh et al., 2019) demonstrates its capability to generalize to a diverse variety of tasks;

Table 3: Test accuracy (%) on the DOMAINBED benchmark.
| Dataset | ERM | IRM | Group DRO | Mixup | MLDG | Coral | MMD | DANN | CDANN | Fish (ours) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| CMNIST | 51.5 (± 0.1) | 52.0 (± 0.1) | 52.1 (± 0.0) | 52.1 (± 0.2) | 51.5 (± 0.1) | 51.5 (± 0.1) | 51.5 (± 0.2) | 51.5 (± 0.3) | 51.7 (± 0.1) | 51.6 (± 0.1) |
| RMNIST | 98.0 (± 0.0) | 97.7 (± 0.1) | 98.0 (± 0.0) | 98.0 (± 0.1) | 97.9 (± 0.0) | 98.0 (± 0.1) | 97.9 (± 0.0) | 97.8 (± 0.1) | 97.9 (± 0.1) | 98.0 (± 0.0) |
| VLCS | 77.5 (± 0.4) | 78.5 (± 0.5) | 76.7 (± 0.6) | 77.4 (± 0.6) | 77.2 (± 0.4) | 78.8 (± 0.6) | 77.5 (± 0.9) | 78.6 (± 0.4) | 77.5 (± 0.1) | 77.8 (± 0.3) |
| PACS | 85.5 (± 0.2) | 83.5 (± 0.8) | 84.4 (± 0.8) | 84.6 (± 0.6) | 84.9 (± 1.0) | 86.2 (± 0.3) | 84.6 (± 0.5) | 83.6 (± 0.4) | 82.6 (± 0.9) | 85.5 (± 0.3) |
| Office Home | 66.5 (± 0.3) | 64.3 (± 2.2) | 66.0 (± 0.7) | 68.1 (± 0.3) | 66.8 (± 0.6) | 68.7 (± 0.3) | 66.3 (± 0.1) | 65.9 (± 0.6) | 65.8 (± 1.3) | 68.6 (± 0.4) |
| Terra Inc | 46.1 (± 1.8) | 47.6 (± 0.8) | 43.2 (± 1.1) | 47.9 (± 0.8) | 47.7 (± 0.9) | 47.6 (± 1.0) | 42.2 (± 1.6) | 46.7 (± 0.5) | 45.8 (± 1.6) | 45.1 (± 1.3) |
| Domain Net | 40.9 (± 0.1) | 33.9 (± 2.8) | 33.3 (± 0.2) | 39.2 (± 0.1) | 41.2 (± 0.1) | 41.5 (± 0.1) | 23.4 (± 9.5) | 38.3 (± 0.1) | 38.3 (± 0.3) | 42.7 (± 0.2) |
| Average | 66.6 | 65.4 | 64.8 | 66.7 | 66.7 | 67.5 | 63.3 | 66.1 | 65.6 | 67.1 |

2. Strong performance on different domain generalization tasks: We make special note that the CIVILCOMMENTS dataset captures a subpopulation shift problem, where the domains at test time are a subpopulation of the domains at train time, while all other WILDS datasets depict pure domain generalization problems, where the domains in train and test are disjoint. As a result, the baseline models for CIVILCOMMENTS selected by the WILDS benchmark are different from the methods used on all other datasets, and are tailored to avoiding systematic failure on data from minority subpopulations. We see that Fish works well in this setting too, without any changes or special sampling strategies (used for baselines on CIVILCOMMENTS, see more in Table 10), demonstrating its capability to perform in different domain generalization scenarios;

3. Failure mode of domain generalization algorithms: We notice that on IWILDCAM and AMAZON, ERM is the best algorithm, outperforming all domain generalization algorithms except for Fish on AMAZON. We believe that these domain generalization algorithms fail due to the large number of domains in these two datasets: 324 for IWILDCAM and 7,676 for AMAZON. This is a common drawback of the current domain generalization literature and is a direction worth exploring.

4.3 DOMAINBED

Datasets. While WILDS is a challenging benchmark capturing realistic distribution shift, to test our model under the synthetic-to-real transfer setting and provide more comparisons to SOTA methods, we also performed experiments on the DOMAINBED benchmark (Gulrajani and Lopez-Paz, 2020). See a summary of DOMAINBED in Appendix H.

Results. Following recommendations in DOMAINBED, we report results using the training domains as the validation set for model selection. See results in Table 3, reported over 5 random trials. Averaging the performance over all 7 datasets, Fish ranks second out of 10 domain generalization methods. It performs only marginally worse than Coral (by 0.1%), and is one of the three methods that perform better than ERM. This showcases Fish's effectiveness on domain generalization datasets with a stronger focus on synthetic-to-real transfer, which again demonstrates its versatility and robustness on different domain generalization tasks.

4.4 ANALYSIS

We have shown extensively through empirical evaluation that Fish is very effective for a variety of domain generalization tasks. In this section, we perform analysis to validate that Fish's strong performance is due to inter-domain gradient inner product maximization.
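As a reference point for the analysis below, the normalized inter-domain GIP that we track can be computed with a helper along the following lines. This is a sketch with assumed names rather than the released code, and it takes the normalization to be the average pairwise cosine similarity between per-domain gradients.

```python
# Sketch: average pairwise cosine similarity between per-domain gradients,
# logged at each training step to produce curves like those in Figure 5.
import itertools
import torch
import torch.nn.functional as F

def normalized_gip(model, loss_fn, domain_batches):
    grads = []
    for x, y in domain_batches:                       # one minibatch per domain
        g = torch.autograd.grad(loss_fn(model(x), y), tuple(model.parameters()))
        grads.append(torch.cat([gi.reshape(-1) for gi in g]))
    pairs = itertools.combinations(grads, 2)
    cosines = [F.cosine_similarity(gi, gj, dim=0) for gi, gj in pairs]
    return torch.stack(cosines).mean()
```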
Figure 5: Gradient inner product values during training for CDSPRITES-N (N=15).

Does Fish maximize the gradient inner product (GIP) empirically? In Figure 5, we plot the progression of GIP during training using different objectives. We train both Fish (blue) and ERM (yellow) on CDSPRITES-N until convergence while tracking the normalized GIP between minibatches from different domains used in each inner loop. To ensure a fair comparison, we use the exact same sequence of data for Fish and ERM (see Appendix I for more details). From Figure 5, it is clear that during training, the normalized GIP of Fish increases, while that of ERM stays at the same value. The observations here show that Fish is indeed effective in increasing/maintaining the level of inter-domain GIP.

Table 4: Test accuracy on four datasets, in three partitions: 1) the two baselines, Fish and ERM; 2) Fish with a random grouping strategy (Fish, RG); and 3) ERM with a domain grouping strategy (ERM, DG).

| Method | FMOW | VLCS | PACS | Office Home |
| --- | --- | --- | --- | --- |
| Fish (ours) | 34.3 (± 0.6) | 77.6 (± 0.5) | 85.5 (± 0.3) | 68.6 (± 0.9) |
| ERM | 31.7 (± 1.0) | 77.5 (± 0.4) | 85.5 (± 0.2) | 66.5 (± 0.3) |
| Fish, RG | 33.4 (± 1.7) | 77.7 (± 0.3) | 83.9 (± 0.7) | 66.5 (± 1.0) |
| ERM, DG, prop. lr | 32.1 (± 0.5) | 72.7 (± 0.4) | 83.2 (± 0.7) | 58.5 (± 0.1) |
| ERM, DG, 0.1 prop. lr | 29.9 (± 0.7) | 73.9 (± 0.6) | 84.2 (± 0.4) | 65.1 (± 0.1) |
| ERM, DG, 0.01 prop. lr | 29.8 (± 0.4) | 74.7 (± 0.4) | 84.0 (± 0.6) | 63.7 (± 0.2) |

We conduct the same GIP tracking experiments for the WILDS datasets we studied as well, to shed some light on its effectiveness; see Appendix G for results.

Does maximizing GIP between random batches help domain generalization? In this part of the analysis, we examine the effectiveness of another element of our algorithm: the construction of minibatches. We conduct experiments where data are grouped randomly instead of by domain for the inner-loop update. By doing so, we are still maximizing the inner product between minibatches, however not strictly between domains. We therefore expect the results to be slightly worse than Fish, and the bigger the domain gap is, the more advantage Fish has over the random grouping strategy. We show the results for random grouping (Fish, RG) in Table 4. As expected, the random grouping strategy performs worse than Fish on all datasets. The experiment demonstrates that our algorithm does benefit from the domain grouping strategy, and that maximizing GIP between random batches of data, while still achieving strong results, does not achieve the same domain generalization performance as Fish.

Differences between Fish and ERM. We have shown through empirical evaluation that Fish and ERM are sufficiently different. Most notably, they converge to completely distinct minima in our linear toy example and the CDSPRITES experiment, resulting in large discrepancies in test accuracy. However, some may notice that Fish, as a first-order method, shares algorithmic similarities with ERM. In fact, if we fix the meta learning rate of Fish to $\epsilon = 1$ and use domain sampling for ERM, the two algorithms are equivalent. Although the optimal $\epsilon$s found through hyperparameter search are much smaller than 1, one might wonder if we can offset the increased meta learning rate $\epsilon$ by reducing the learning rate $\alpha$ proportionally, and achieve similar performance to Fish using simply ERM with domain grouping. In Table 4 (ERM, DG), we verify that this is infeasible empirically. We train an ERM model with the domain grouping strategy (equivalent to Fish with $\epsilon = 1$) using 3 different learning rates: 1) prop.
lr, where the learning rate $\alpha$ is lowered proportionally to keep $\alpha\epsilon$ constant; 2) 0.1 prop. lr; and 3) 0.001 prop. lr, which further reduces the learning rate. We see that on all datasets, all three configurations result in worse performance than ERM and Fish. This shows that one cannot improve the results of ERM to match those of Fish by simply adjusting the learning rate/adopting the domain grouping strategy, and that Fish's gain on domain generalization tasks is not due to its proximity to ERM. In Appendix A.2, we also show that setting $\epsilon = 1$ is equivalent to setting $S$ as the total number of iterations during training, which causes the effect of maximizing GIP to diminish. This provides further support to our empirical finding in this analysis.

Does maximizing GIP help domain generalization? In Parascandolo et al. (2021), the authors show that by masking out the gradients that have opposite signs in different domains, they are able to trade off some learning speed for prioritizing learning the invariances. Our principle for proposing the IDGM objective is identical, except in this case we actively encourage these gradients to have consistent signs. This, unfortunately, does require computing the costly second-order derivative, which makes it difficult for us to empirically evaluate the IDGM objective on non-toy datasets.[2] We believe that the large performance gain of IDGM over ERM on the linear and CDSPRITES datasets, along with the detailed discussion on the relationship between gradient sign consistency and invariant representation, is enough to justify IDGM's ability to improve domain generalization.

[2] For context, training an IDGM model on the VLCS dataset with batch size 32 and a standard RESNET18 backbone does not fit on an NVIDIA V100 (32GB) GPU.

5 CONCLUSION

In this paper we presented inter-domain gradient matching (IDGM) for domain generalization. To avoid costly second-order computations, we approximated IDGM with a simple first-order algorithm, Fish. We demonstrated our algorithm's capability to learn from invariant features (as well as ERM's failure to do so) using simple datasets such as CDSPRITES-N and the linear example. We then evaluated the model's performance on WILDS and DOMAINBED, demonstrating that Fish performs well on different subgenres of domain generalization, and surpasses baseline performance on a diverse range of vision and language tasks using different architectures such as DenseNet, ResNet-50 and BERT. Our experiments can be replicated with 1500 GPU hours on NVIDIA V100s. Despite its strong performance, similar to previous work on domain generalization, when the number of domains is large Fish struggles to outperform ERM. We are currently investigating approaches by which Fish can be made to scale to datasets with orders of magnitude more domains and expect to report on this improvement in future work.

ETHICS STATEMENT

We believe there are no ethical concerns within this work. None of the datasets used involve human identities, and our motivation in proposing this work in no way concerns any surveillance/military use. As with many other domain generalization methods, our algorithm aims at improving the performance of machine learning systems deployed in the wild, which can be used in many ways that will benefit society.

REFERENCES

M. Andrychowicz, M. Denil, S. Gomez, M. W. Hoffman, D. Pfau, T. Schaul, B. Shillingford, and N. De Freitas. Learning to learn by gradient descent by gradient descent.
arXiv preprint arXiv:1606.04474, 2016.
M. Arjovsky, L. Bottou, I. Gulrajani, and D. Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
Y. Balaji, S. Sankaranarayanan, and R. Chellappa. MetaReg: towards domain generalization using meta-regularization. In NIPS'18 Proceedings of the 32nd International Conference on Neural Information Processing Systems, volume 31, pages 1006-1016, 2018.
S. Beery, G. V. Horn, and P. Perona. Recognition in terra incognita. In Proceedings of the European Conference on Computer Vision (ECCV), pages 472-489, 2018.
S. Ben-David, J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. W. Vaughan. A theory of learning from different domains. Machine Learning, 79(1):151-175, 2010.
S. L. Blodgett, L. Green, and B. T. O'Connor. Demographic dialectal variation in social media: A case study of African-American English. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages 1119-1130, 2016.
W. Brendel and M. Bethge. Approximating CNNs with bag-of-local-features models works surprisingly well on ImageNet. In International Conference on Learning Representations (ICLR), 2019.
F. M. Carlucci, A. D'Innocente, S. Bucci, B. Caputo, and T. Tommasi. Domain generalization by solving jigsaw puzzles. In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 2229-2238, 2019.
Q. Dou, D. C. de Castro, K. Kamnitsas, and B. Glocker. Domain generalization via model-agnostic learning of semantic features. In Advances in Neural Information Processing Systems, volume 32, pages 6450-6461, 2019.
C. Fang, Y. Xu, and D. N. Rockmore. Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias. In 2013 IEEE International Conference on Computer Vision, pages 1657-1664, 2013.
C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In ICML'17 Proceedings of the 34th International Conference on Machine Learning - Volume 70, pages 1126-1135, 2017.
Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky. Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17(1):2096-2030, 2016.
R. Geirhos, P. Rubisch, C. Michaelis, M. Bethge, F. A. Wichmann, and W. Brendel. ImageNet-trained CNNs are biased towards texture; increasing shape bias improves accuracy and robustness. In International Conference on Learning Representations (ICLR), 2018.
M. Ghifary, W. B. Kleijn, M. Zhang, and D. Balduzzi. Domain generalization for object recognition with multi-task autoencoders. In 2015 IEEE International Conference on Computer Vision (ICCV), pages 2551-2559, 2015.
A. Gretton, A. Smola, J. Huang, M. Schmittfull, K. Borgwardt, and B. Schölkopf. Covariate shift and local learning by distribution matching, pages 131-160. MIT Press, Cambridge, MA, USA, 2009a.
A. Gretton, A. Smola, J. Huang, M. Schmittfull, K. Borgwardt, and B. Schölkopf. Covariate shift by kernel mean matching. Dataset Shift in Machine Learning, 3(4):5, 2009b.
I. Gulrajani and D. Lopez-Paz. In search of lost domain generalization. arXiv preprint arXiv:2007.01434, 2020.
T. B. Hashimoto, M. Srivastava, H. Namkoong, and P. Liang. Fairness without demographics in repeated loss minimization. In International Conference on Machine Learning, pages 1929-1938, 2018.
K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition.
In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770-778, 2016.
W. Hu, G. Niu, I. Sato, and M. Sugiyama. Does distributionally robust supervised learning give robust classifiers? In International Conference on Machine Learning, pages 2029-2037. PMLR, 2018.
G. Huang, Z. Liu, L. van der Maaten, and K. Q. Weinberger. Densely connected convolutional networks. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2261-2269, 2017.
Z. Huang, H. Wang, E. P. Xing, and D. Huang. Self-challenging improves cross-domain generalization. In European Conference on Computer Vision, pages 124-140, 2020.
M. Ilse, J. Tomczak, and P. Forré. Selecting data augmentation for simulating interventions. In ICML 2021: 38th International Conference on Machine Learning, pages 4555-4562, 2021.
P. W. Koh, S. Sagawa, H. Marklund, S. M. Xie, M. Zhang, A. Balsubramani, W. Hu, M. Yasunaga, R. L. Phillips, S. Beery, J. Leskovec, A. Kundaje, E. Pierson, S. Levine, C. Finn, and P. Liang. WILDS: A benchmark of in-the-wild distribution shifts. arXiv preprint arXiv:2012.07421, 2020.
M. Koyama and S. Yamaguchi. Out-of-distribution generalization with maximal invariant predictor. In arXiv e-prints, 2021.
D. Li, Y. Yang, Y.-Z. Song, and T. M. Hospedales. Deeper, broader and artier domain generalization. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 5543-5551, 2017.
D. Li, Y. Yang, Y.-Z. Song, and T. Hospedales. Learning to generalize: Meta-learning for domain generalization. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 32, 2018a.
D. Li, J. Zhang, Y. Yang, C. Liu, Y.-Z. Song, and T. Hospedales. Episodic training for domain generalization. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 1446-1455, 2019.
D. Li, Y. Yang, Y.-Z. Song, and T. M. Hospedales. Sequential learning for domain generalization. In European Conference on Computer Vision, pages 603-619, 2020.
H. Li, S. J. Pan, S. Wang, and A. C. Kot. Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 5400-5409, 2018b.
Y. Li, X. Tian, M. Gong, Y. Liu, T. Liu, K. Zhang, and D. Tao. Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV), pages 647-663, 2018.
D. Lopez-Paz and M. Ranzato. Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, volume 30, pages 6467-6476, 2017.
L. Matthey, I. Higgins, D. Hassabis, and A. Lerchner. dsprites: Disentanglement testing sprites dataset. https://github.com/deepmind/dsprites-dataset/, 2017.
S. Motiian, M. Piccirilli, D. A. Adjeroh, and G. Doretto. Unified deep supervised domain adaptation and generalization. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 5716-5726, 2017.
A. Nichol, J. Achiam, and J. Schulman. On first-order meta-learning algorithms. arXiv: Learning, 2018.
G. Parascandolo, A. Neitz, A. Orvieto, L. Gresele, and B. Schölkopf. Learning explanations that are hard to vary. In ICLR 2021: The Ninth International Conference on Learning Representations, 2021.
X. Peng, Q. Bai, X. Xia, Z. Huang, K. Saenko, and B. Wang.
Moment matching for multi-source domain adaptation. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 1406-1415, 2019.
M. Ren, W. Zeng, B. Yang, and R. Urtasun. Learning to reweight examples for robust deep learning. In International Conference on Machine Learning, pages 4334-4343, 2018.
M. Rojas-Carulla, B. Schölkopf, R. Turner, and J. Peters. A causal perspective on domain adaptation. stat, 1050:19, 2015.
S. Sagawa, P. W. Koh, T. B. Hashimoto, and P. Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. arXiv preprint arXiv:1911.08731, 2019.
V. Sanh, L. Debut, J. Chaumond, and T. Wolf. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019.
S. Seo, Y. Suh, D. Kim, G. Kim, J. Han, and B. Han. Learning to optimize domain specific normalization for domain generalization. In European Conference on Computer Vision, pages 68-83, 2019.
P. Stock and M. Cisse. ConvNets and ImageNet beyond accuracy: Explanations, bias detection, adversarial examples and model criticism. arXiv preprint arXiv:1711.11443, 2017.
B. Sun and K. Saenko. Deep CORAL: Correlation alignment for deep domain adaptation. In European Conference on Computer Vision, pages 443-450. Springer, 2016.
A. Tarvainen and H. Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In ICLR (Workshop), 2017.
R. Tatman. Gender and dialect bias in YouTube's automatic captions. In Proceedings of the First ACL Workshop on Ethics in Natural Language Processing, pages 53-59, 2017.
S. Thrun and L. Pratt, editors. Learning to Learn. Kluwer Academic Publishers, USA, 1998. ISBN 0792380479.
A. Torralba and A. A. Efros. Unbiased look at dataset bias. In Proceedings of the 2011 IEEE Conference on Computer Vision and Pattern Recognition, CVPR '11, pages 1521-1528, USA, 2011. IEEE Computer Society. ISBN 9781457703942. doi: 10.1109/CVPR.2011.5995347. URL https://doi.org/10.1109/CVPR.2011.5995347.
L. van der Maaten and G. Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9(86):2579-2605, 2008.
H. Venkateswara, J. Eusebio, S. Chakraborty, and S. Panchanathan. Deep hashing network for unsupervised domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 5385-5394, 2017.
R. Volpi and V. Murino. Addressing model vulnerability to distributional shifts over image transformation sets. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 7980-7989, 2019.
H. Wang, S. Ge, Z. C. Lipton, and E. P. Xing. Learning robust global representations by penalizing local predictive power. In Advances in Neural Information Processing Systems, volume 32, pages 10506-10518, 2019.
S. Yan, H. Song, N. Li, L. Zou, and L. Ren. Improve unsupervised domain adaptation with mixup training. arXiv preprint arXiv:2001.00677, 2020.
K. Zhou, Y. Yang, T. M. Hospedales, and T. Xiang. Learning to generate novel domains for domain generalization. In European Conference on Computer Vision, pages 561-578, 2020.
A TAYLOR EXPANSION OF REPTILE AND FISH INNER-LOOP UPDATE

In this section we provide the proof of Theorem 3.1. We reproduce and adapt the proof from Nichol et al. (2018) in the context of Fish, for completeness. We demonstrate that when the inner-loop learning rate $\alpha$ is small, the direction of $G_f$ aligns with that of $G_g$, where

$$G_f = \mathbb{E}\left[(\theta - \tilde{\theta})\right] - \alpha S \bar{G}, \qquad (5)$$
$$G_g = -\partial \hat{G} / \partial \theta. \qquad (6)$$

Expanding $G_g$. $G_g$ is the gradient of maximizing the gradient inner product (GIP):

$$G_g = -\frac{2}{S(S-1)} \sum_{i, j \in S,\, i \neq j} \frac{\partial (G_i \cdot G_j)}{\partial \theta}. \qquad (7)$$

Expanding $G_f$. To write out $G_f$, we need to derive the inner-loop update of Fish, $\tilde{\theta} - \theta$. Let us first define some notation. For each inner loop with $S$ steps of gradient updates, we assume a loss function $\ell$ as well as a sequence of inputs $\{d_i\}_{i=1}^{S}$, where $d_i := \{x_b, y_b\}_{b=1}^{B}$ denotes a minibatch at step $i$, randomly drawn from one of the available domains in $\{\mathcal{D}_1, \ldots, \mathcal{D}_S\}$. For reasons that will become clearer later, take extra note that the subscript $i$ here denotes the index of the step, rather than the index of the domain. We also define the following:

$$\tilde{g}_i = \frac{\partial \ell((x, y); \theta_i)}{\partial \theta_i} \qquad \text{(gradient at step } i \text{, wrt } \theta_i\text{)} \qquad (8)$$
$$\theta_{i+1} = \theta_i - \alpha \tilde{g}_i \qquad \text{(sequence of parameters)} \qquad (9)$$
$$g_i = \frac{\partial \ell((x, y); \theta_1)}{\partial \theta_1} \qquad \text{(gradient at step } i \text{, wrt } \theta_1\text{)} \qquad (10)$$
$$H_i = \frac{\partial^2 \ell((x, y); \theta_1)}{\partial \theta_1^2} \qquad \text{(Hessian at initial point)} \qquad (11)$$

In the following analysis we omit the expectation $\mathbb{E}_{d_i}$ and the input $(x, y)$ to $\ell$, and instead denote the loss at step $i$ as $\ell_i$. Performing a second-order Taylor approximation of $\tilde{g}_i$ yields:

$$\tilde{g}_i = \ell_i'(\theta_i) \qquad (12)$$
$$= \ell_i'(\theta_1) + \ell_i''(\theta_1)(\theta_i - \theta_1) + \underbrace{O(\|\theta_i - \theta_1\|^2)}_{=O(\alpha^2)} \qquad (13)$$
$$= g_i + H_i(\theta_i - \theta_1) + O(\alpha^2) \qquad (14)$$
$$= g_i - \alpha H_i \sum_{j=1}^{i-1} \tilde{g}_j + O(\alpha^2). \qquad (15)$$

Applying a first-order Taylor approximation to $\tilde{g}_j$ gives us

$$\tilde{g}_j = g_j + O(\alpha), \qquad (16)$$

and plugging this back into Equation (15) yields:

$$\tilde{g}_i = g_i - \alpha H_i \sum_{j=1}^{i-1} g_j + O(\alpha^2). \qquad (17)$$

For simplicity, let us consider performing two steps of inner-loop updates, i.e. $S = 2$. We can then write the inner-loop displacement of Fish, $\tilde{\theta} - \theta$, as

$$\tilde{\theta} - \theta = -\alpha(\tilde{g}_1 + \tilde{g}_2) \qquad (18)$$
$$= -\alpha(g_1 + g_2) + \alpha^2 H_2 g_1 + O(\alpha^3). \qquad (19)$$

Furthermore, taking the expectation of $\tilde{\theta} - \theta$ under minibatch sampling gives us (assuming independence between $g_1$ and $g_2$):

$$\text{(i)}\quad \mathbb{E}_{1,2}[g_1 + g_2] = G_1 + G_2,$$
$$\text{(ii)}\quad \mathbb{E}_{1,2}[H_2 g_1] = \mathbb{E}_{1,2}[H_1 g_2] \;\; \text{(interchanging indices)} \;=\; \tfrac{1}{2}\,\mathbb{E}_{1,2}[H_2 g_1 + H_1 g_2] \;\; \text{(averaging the last two)} \;=\; \tfrac{1}{2}\,\frac{\partial (G_1 \cdot G_2)}{\partial \theta_1}.$$

Note that the only reason we can interchange the indices in (ii) is that the subscripts represent steps in the inner loop rather than indices of domains. Plugging (i) and (ii) into Equation (19) yields:

$$\mathbb{E}[\tilde{\theta} - \theta] = -\alpha(G_1 + G_2) + \frac{\alpha^2}{2} \frac{\partial (G_1 \cdot G_2)}{\partial \theta_1} + O(\alpha^3). \qquad (20)$$

We can also expand this to the general case where $S \geq 2$:

$$\mathbb{E}[\tilde{\theta} - \theta] = -\alpha \sum_{s=1}^{S} G_s + \frac{\alpha^2}{2} \sum_{i < j} \frac{\partial (G_i \cdot G_j)}{\partial \theta_1} + O(\alpha^3). \qquad (21)$$

The second term in Equation (5) involves $\bar{G}$, which is the full gradient of ERM, defined as follows:

$$\bar{G} = \frac{1}{S} \sum_{s=1}^{S} G_s. \qquad (22)$$

Plugging Equation (21) and Equation (22) into Equation (5) yields

$$G_f = \mathbb{E}[\theta - \tilde{\theta}] - \alpha S \bar{G} \qquad (23)$$
$$= -\frac{\alpha^2}{2} \sum_{i < j} \frac{\partial (G_i \cdot G_j)}{\partial \theta_1} + O(\alpha^3). \qquad (24)$$

Comparing Equation (7) to Equation (24), we have:

$$\lim_{\alpha \to 0} \frac{G_f \cdot G_g}{\|G_f\|\,\|G_g\|} = 1.$$

A.1 FISH AND REPTILE: DIFFERENCES AND CONNECTIONS

As we introduced, our algorithm Fish is inspired by Reptile, a model-agnostic meta-learning (MAML) algorithm. Meta-learning aims at reducing the sample complexity of new, unseen tasks. A popular school of thinking in meta-learning is MAML, first proposed in Finn et al. (2017); Andrychowicz et al. (2016). The key idea is to backpropagate through gradient descent itself to learn representations that can be easily adapted to unseen tasks. There are close connections between meta-learning (Thrun and Pratt, 1998) and (multi-source) domain adaptation. In fact, there are a few works in domain generalization that are inspired by meta-learning principles, such as Li et al. (2018a); Balaji et al. (2018); Li et al.
(2019); Dou et al. (2019). In Ren et al. (2018), we also see the gradient inner product leveraged in meta-learning, where it is used to determine the importance weight of training examples.

Even though meta-learning and domain generalization both study N-way, K-shot problems, there are some distinct differences that set them apart. The most prominent one is that in meta-learning, some examples in the test dataset will be made available at test time (K > 0), while in domain generalization no example in the test dataset is seen by the model (K = 0); another important difference is that while domain generalization aims to train models that perform well on an unseen distribution of the same task, meta-learning assumes multiple tasks and requires the model to quickly learn an unseen task using only K examples. Due to these differences, it does not make sense in general to use the MAML framework in domain generalization.

Algorithm 3 Comparison of Fish and Reptile. Plain steps are used in both algorithms; steps marked [Fish] or [Reptile] are unique to that algorithm (in the original they are distinguished by font color).
1: for i = 1, 2, ... do
2:   $\tilde{\theta} \leftarrow \theta$
3:   [Reptile] Sample task $\mathcal{D}_t \in \{\mathcal{D}_1, \ldots, \mathcal{D}_T\}$
4:   for [Reptile] $s \in \{1, \ldots, S\}$ or [Fish] $\mathcal{D}_t \in$ permute($\{\mathcal{D}_1, \ldots, \mathcal{D}_T\}$) do
5:     Sample batch $d_t \sim \mathcal{D}_t$
6:     $g_t = \partial L(d_t; \tilde{\theta}) / \partial \tilde{\theta}$
7:     Update $\tilde{\theta} \leftarrow \tilde{\theta} - \alpha g_t$
8:   end for
9:   Update $\theta \leftarrow \theta + \epsilon(\tilde{\theta} - \theta)$
10: end for

Figure 6: Performance on CDSPRITES-N, with N in [5, 50].

As it turns out, however, the idea of aligning gradients to improve generalization is relevant to both methods. The fundamental difference is that MAML algorithms such as Reptile align the gradients between batches from the same task (Nichol et al., 2018), while Fish aligns those between batches from different tasks. To see how this is achieved, let us look at the algorithmic comparison between Fish and Reptile in Algorithm 3. As we can see, the key difference between the algorithms of Fish and Reptile is that Reptile performs its inner loop using minibatches from the same task, while Fish uses minibatches from different tasks (lines 4-8). Based on the analysis in Nichol et al. (2018) (which we reproduce in Appendix A), this is why Reptile maximizes the within-task gradient inner products and Fish maximizes the across-task gradient inner products.

A natural question to ask here is how this affects their empirical performance. In Figure 6, we show the train and test performance of Fish (blue) and Reptile (green) on CDSPRITES-N. We can see that despite the algorithmic similarity between Fish and Reptile, the two methods behave very differently on this domain generalization task: while Fish's test accuracy goes to 100% at N = 10, Reptile's test performance is always 50% regardless of N. Moreover, we observe a dip in Reptile's training performance early on, with the accuracy plateauing at 56% when N > 20. Reptile's poor performance on this dataset is to be expected, since its inner loop is designed to encourage within-domain generalization, which is not helpful for learning what is invariant across domains.

A.2 WHEN $\epsilon = 1$

When we set the meta learning rate of Fish to $\epsilon = 1$, Fish reduces to ERM with a data loader that samples from only one domain in each minibatch. Note that with this change, the inner-loop step count $S$ now represents the total number of training iterations, and as a result it increases from a constant value (about 10 in our experiments) to orders of magnitude larger (up to $10^5$ for some datasets that we use).
We demonstrate here that, for this reason, the effect of maximizing the inter-domain gradient inner product is significantly less prominent when $\epsilon = 1$. To see this, let us revisit Equation (21), where we demonstrate that the expectation of Fish's update $\mathbb{E}[\tilde{\theta} - \theta]$ can be written as the sum of the following three terms via Taylor series expansion:

$$\mathbb{E}[\tilde{\theta} - \theta] = \underbrace{-\alpha \sum_{s=1}^{S} G_s}_{(1)} + \underbrace{\frac{\alpha^2}{2} \sum_{i < j} \frac{\partial (G_i \cdot G_j)}{\partial \theta_1}}_{(2)} + \underbrace{O(\alpha^3)}_{(3)} \qquad (25)$$

where (1) is the sum of gradients over the $S$ steps of SGD updates (equivalent to ERM), (2) is the gradient inner product over the $S$ steps of gradients, and (3) is the higher-order terms of the Taylor series expansion. For both Fish and Reptile, the higher-order terms are ignored under two conditions: 1) a constant/non-infinite inner-loop step count $S$, and 2) a small learning rate $\alpha$. However, when $\epsilon = 1$, the inner-loop step count $S \to \infty$. As a result, the higher-order terms cannot be ignored, and we can no longer conclude that the gradient inner product plays an important role in the model's updates.

B SMOOTHFISH: A MORE GENERAL ALGORITHM

B.1 DERIVATION

We conclude in Appendix A that a component of Fish's update, $G_f = \mathbb{E}[\theta - \tilde{\theta}] - \alpha S \bar{G}$, is in the same direction as the gradient of GIP, $G_g$. It is therefore possible to have explicit control over the scaling of the GIP component in Fish, similar to the original IDGM objective, by writing the following:

$$G_{sm} = -\alpha S \bar{G} + \gamma \left( \mathbb{E}[\tilde{\theta} - \theta] + \alpha S \bar{G} \right). \qquad (26)$$

By introducing the scaling term $\gamma$, we have better control over how much the objective focuses on the inner product vs. the average gradient. Note that $\gamma = 1$ recovers the original Fish update, and when $\gamma = 0$ the update $G_{sm}$ reduces to an ERM gradient step with learning rate $\alpha S$. We name the resulting algorithm SmoothFish. See Algorithm 4.

Algorithm 4 Smoothed version of Fish, which allows us to get approximate gradients for the general form of Equation (4).
1: for iterations = 1, 2, ... do
2:   $\tilde{\theta} \leftarrow \theta$
3:   for $\mathcal{D}_i \in$ permute($\{\mathcal{D}_1, \mathcal{D}_2, \ldots, \mathcal{D}_S\}$) do
4:     Sample batch $d_i \sim \mathcal{D}_i$
5:     $g_i = \mathbb{E}_{d_i}\left[\partial \ell((x, y); \theta) / \partial \theta\right]$   // Grad wrt $\theta$
6:     $\tilde{g}_i = \mathbb{E}_{d_i}\left[\partial \ell((x, y); \tilde{\theta}) / \partial \tilde{\theta}\right]$   // Grad wrt $\tilde{\theta}$
7:     Update $\tilde{\theta} \leftarrow \tilde{\theta} - \alpha \tilde{g}_i$
8:   end for
9:   $\bar{g} = \frac{1}{S}\sum_{s=1}^{S} g_s$,  $g_{sm} = -\alpha S \bar{g} + \gamma\left((\tilde{\theta} - \theta) + \alpha S \bar{g}\right)$
10:  Update $\theta \leftarrow \theta + \epsilon\, g_{sm}$
11: end for

B.2 RESULTS

We run experiments on 4 datasets in WILDS using SmoothFish, with $\gamma$ ranging over [0, 0.1, 0.2, 0.5, 0.8, 1, 2, 10]. Note that when $\gamma = 0$, SmoothFish is equivalent to ERM, and when $\gamma = 1$ it is equivalent to Fish. See the results in Figure 7. The other hyperparameters used here, including $\alpha$, the number of meta steps, and $\epsilon$, are the same as the ones used in our main experiments section.

Figure 7: Results on WILDS using SmoothFish with $\gamma$ ranging from 0 to 10. Panels: (a) CAMELYON17, (b) CIVILCOMMENTS, (d) POVERTY.

C DISCUSSIONS AND RESULTS ON WILDS

We provide a more detailed summary of each dataset in Table 5. Some entries in "# Domains" are omitted because the domains in each split overlap. Note that in this paper we report results on WILDS v1; the benchmark has since been updated with slightly different dataset splits. We are currently working on updating our results to v2 of WILDS.

Table 5: Details of the 6 WILDS datasets we experimented on.
C DISCUSSIONS AND RESULTS ON WILDS

We provide a more detailed summary of each dataset in Table 5. Some entries in "# Domains" are omitted because the domains in each split overlap. Note that in this paper we report results on WILDS v1; the benchmark has since been updated with slightly different dataset splits. We are currently working on updating our results to v2 of WILDS.

Table 5: Details of the 6 WILDS datasets we experimented on.

| Dataset | Domains (# domains) | Data (x) | Target (y) | # Examples (train / val / test) | # Domains (train / val / test) |
|---|---|---|---|---|---|
| FMOW | Time (16), Regions (5) | Satellite images | Land use (62 classes) | 76,863 / 19,915 / 22,108 | 11, - / 3, - / 2, - |
| POVERTY | Countries (23), Urban/rural (2) | Satellite images | Asset (real valued) | 10,000 / 4,000 / 4,000 | 13, - / 5, - / 5, - |
| CAMELYON17 | Hospitals (5) | Tissue slides | Tumor (2 classes) | 302,436 / 34,904 / 85,054 | 3 / 1 / 1 |
| CIVILCOMMENTS | Demographics (8) | Online comments | Toxicity (2 classes) | 269,038 / 45,180 / 133,782 | - / - / - |
| IWILDCAM2020 | Trap locations (324) | Photos | Animal species (186 classes) | 142,202 / 20,784 / 38,943 | 245 / 32 / 47 |
| AMAZON | Reviewers (7,676) | Product reviews | Star rating (5 classes) | 1,000,124 / 100,050 / 100,050 | 5,008 / 1,334 / 1,334 |

Table 6: Results on POVERTYMAP-WILDS.

| Method | Val. Worst-U/R r | Test Worst-U/R r |
|---|---|---|
| Fish | 0.47 (± 0.01) | 0.30 (± 0.01) |
| IRM | 0.53 (± 0.05) | 0.43 (± 0.07) |
| ERM | 0.80 (± 0.04) | 0.78 (± 0.04) |
| ERM (ours) | 0.48 (± 0.11) | 0.29 (± 0.02) |
| Coral | 0.51 (± 0.06) | 0.45 (± 0.06) |

Table 7: Results on CAMELYON17-WILDS.

| Method | Val. Accuracy (%) | Test Accuracy (%) |
|---|---|---|
| Fish | 83.9 (± 1.2) | 74.7 (± 7.1) |
| ERM | 84.9 (± 3.1) | 70.3 (± 6.4) |
| ERM (ours) | 84.1 (± 2.4) | 70.5 (± 12.1) |
| IRM | 86.2 (± 1.4) | 64.2 (± 8.1) |
| Coral | 86.2 (± 1.4) | 59.5 (± 7.7) |

C.1 POVERTYMAP-WILDS

Task: Asset index prediction (real-valued). Domains: 23 countries.

The task is to predict the real-valued asset wealth index of an area given its satellite imagery. Since the number of domains considered here is large (23 countries), instead of looping over all S domains in each inner loop we sample N ≪ S domains in each iteration and perform inner-loop updates using minibatches from these domains only, to speed up computation (see the sketch following this subsection). For this dataset we choose N = 5 by hyperparameter search.

Evaluation: Worst-U/R Pearson correlation (r). Following the practice of the WILDS benchmark, we compare results by computing the worst-region Pearson correlation (r) between the predicted and ground-truth asset index over 3 random seed runs.

Results: We train the model using a ResNet-18 (He et al., 2016) backbone. See Table 6. We see that Fish obtains the highest test performance, with the same validation performance as the best baseline. The performance is more stable between validation and test, and the standard deviation is smaller than for the baselines. We also report the results of ERM models trained in our environment as "ERM (ours)", which show similar performance to the canonical results reported in the WILDS benchmark itself ("ERM").
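A minimal sketch of the domain subsampling described above, reusing the hypothetical `fish_step` from the sketch in Appendix A; the function and argument names are illustrative, not the released implementation.

```python
import random

def fish_train_subsampled(model, all_domains, loss_fn, num_iters,
                          n_domains=5, alpha=1e-3, eps=0.1):
    """Outer training loop with domain subsampling: when the number of training
    domains S is large, each Fish iteration only visits a random subset of
    N << S domains (N = 5 for POVERTYMAP-WILDS)."""
    for _ in range(num_iters):
        sampled = random.sample(all_domains, k=min(n_domains, len(all_domains)))
        fish_step(model, sampled, loss_fn, alpha=alpha, eps=eps)
```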
C.2 CAMELYON17-WILDS

Task: Tumor detection (2 classes). Domains: 5 hospitals.

The CAMELYON17-WILDS dataset contains 450,000 lymph-node scans from 5 hospitals. Due to the size of the dataset, instead of training with Fish from scratch, we pre-train the model with ERM using the recommended hyperparameters in Koh et al. (2020) and fine-tune with Fish. For this dataset, we find that Fish performs best when starting from a pretrained model that has not yet converged, achieving much higher accuracy than the ERM model. We provide an ablation study on this in Appendix E.

Evaluation: Average accuracy. We evaluate the average accuracy on this binary classification task. Following Koh et al. (2020), we report the mean and standard deviation of results over 10 random seed runs. The number of random seeds required here is greater than for the other WILDS datasets due to the large variance observed in the results. Note that these random seeds are applied not only during the fine-tuning stage, but also to the pretrained models, to ensure a fair comparison.

Results: Following the practice in WILDS, we adopt the DenseNet-121 (Huang et al., 2017) architecture for models trained on this dataset. See results in Table 7. The results show that Fish significantly outperforms all baselines: its test accuracy surpasses the best-performing baseline by 6%. Also note that for all other baselines there is a large gap between validation and test accuracy (11%-27%). This is because WILDS chose the hospital that is most difficult to generalize to as the test split, to make the task more challenging. Surprisingly, as we can observe in Table 7, the discrepancy between the test and validation accuracy of Fish is quite small (3%). The fact that it achieves a similar level of accuracy on the worst-performing domain further demonstrates that Fish does not rely on domain-specific information, and instead makes predictions using features that are invariant across domains.

To demonstrate that Fish's strong performance on CAMELYON17's selected test set is not merely coincidental, we randomly chose 3 alternative ways of assigning the 5 domains to train/val/test splits (3/1/1 domains). See results collected over 10 random seeds in Table 8. Notably, on these different splits of CAMELYON17, Fish still outperforms ERM on average (with only 2 exceptions), demonstrating that Fish can generally achieve good performance on the CAMELYON17 dataset.

Table 8: Results on CAMELYON17-WILDS with shuffled splits.

| Train/Val/Test ID | Val, ERM (%) | Val, Fish (%) | Test, ERM (%) | Test, Fish (%) |
|---|---|---|---|---|
| 012/3/4 | 71.3 (± 6.3) | 71.9 (± 5.1) | 80.9 (± 8.6) | 88.9 (± 6.9) |
| 124/0/3 | 72.3 (± 5.3) | 74.1 (± 4.3) | 63.4 (± 6.6) | 65.9 (± 3.9) |
| 234/0/1 | 83.9 (± 3.8) | 84.1 (± 2.3) | 75.6 (± 2.8) | 73.9 (± 3.5) |
| 034/1/2 (original) | 84.3 (± 2.1) | 82.5 (± 1.2) | 73.3 (± 9.0) | 79.5 (± 6.0) |

Table 9: Results on FMOW-WILDS.

| Method | Val. Avg (%) | Val. Worst (%) | Test Avg (%) | Test Worst (%) |
|---|---|---|---|---|
| Fish | 57.8 (± 0.15) | 49.5 (± 2.34) | 51.8 (± 0.32) | 34.6 (± 0.18) |
| ERM | 59.5 (± 0.37) | 48.9 (± 0.62) | 53.0 (± 0.55) | 32.3 (± 1.25) |
| ERM (ours) | 59.9 (± 0.22) | 47.1 (± 1.21) | 52.9 (± 0.18) | 30.9 (± 1.53) |
| IRM | 57.4 (± 0.37) | 47.5 (± 1.57) | 50.8 (± 0.13) | 30.0 (± 1.37) |
| Coral | 56.9 (± 0.25) | 47.1 (± 0.43) | 50.5 (± 0.36) | 31.7 (± 1.24) |

Table 10: Results on CIVILCOMMENTS-WILDS.

| Method | Val. Avg (%) | Val. Worst (%) | Test Avg (%) | Test Worst (%) |
|---|---|---|---|---|
| Fish | 88.8 (± 0.6) | 70.5 (± 1.0) | 89.4 (± 0.2) | 75.3 (± 0.6) |
| Group DRO | 89.6 (± 0.3) | 68.7 (± 1.0) | 89.4 (± 0.3) | 70.4 (± 2.1) |
| Reweighted | 89.1 (± 0.3) | 67.9 (± 1.2) | 88.9 (± 0.3) | 67.3 (± 0.1) |
| ERM | 92.3 (± 0.6) | 53.6 (± 0.7) | 92.2 (± 0.6) | 58.0 (± 1.2) |
| ERM (ours) | 92.1 (± 0.5) | 54.1 (± 0.4) | 92.5 (± 0.3) | 58.1 (± 1.7) |

C.3 FMOW-WILDS

Task: Infrastructure classification (62 classes). Domains: 80 (16 years × 5 regions).

Similar to CAMELYON17-WILDS, since the number of domains is large, we sample N = 5 domains for each inner loop. To speed up computation, we also use a pretrained ERM model and fine-tune with Fish; different from Appendix C.2, we find that the best-performing models are obtained when starting from converged pretrained models (see details in Appendix E).

Evaluation: Average & worst-region accuracies. Following WILDS, the average accuracy evaluates the model's ability to generalize over years, and the worst-region accuracy measures the model's performance across regions under a time shift. We report results using 3 random seeds.

Results: Following Koh et al. (2020), we use a DenseNet-121 pretrained on ImageNet for this dataset. Results in Table 9 show that Fish has the highest worst-region accuracy on both the test and validation sets.
It ranks second in terms of average accuracy, right after ERM. Again, Fish's performance is notably stable, with the smallest standard deviation across all metrics compared to the baselines.

C.4 CIVILCOMMENTS-WILDS

Task: Toxicity detection in online comments (2 classes). Domains: 8 demographic identities.

CIVILCOMMENTS-WILDS contains 450,000 comments collected from online articles, each annotated for toxicity and for the mention of demographic identities. Again, we use an ERM pre-trained model to speed up computation, and sample N = 5 domains for each inner loop.

Evaluation: Worst-group accuracy. To study the bias of annotating comments that mention demographic groups as toxic, the WILDS benchmark proposes to evaluate the model's performance as follows: 1) further separate each of the 8 demographic identities into 2 groups by toxicity, for example, separate "black" into ("black", toxic) and ("black", not toxic); 2) measure the accuracies of these 8 × 2 = 16 groups and use the lowest accuracy as the final evaluation of the model (a small sketch of this metric follows this subsection). This metric is equivalent to computing the sensitivity and specificity of the classifier on each demographic identity and reporting the worse of the two over all domains. Good performance on the group with the worst accuracy implies that the model does not tend to use demographic identity as an indicator of toxic comments. Again, following Koh et al. (2020), we report results over 3 random seed runs.

Results: We compare results to the baselines used in the WILDS benchmark over 3 random seed runs in Table 10. All models are trained using DistilBERT (Sanh et al., 2019). The results show that Fish outperforms the best baseline by 4% and 7% on the test and validation set's worst-group accuracy respectively, and is competitive with ERM in terms of average accuracy (within standard deviation). The strong performance of Fish on worst-group accuracy suggests that, compared to the other baselines, it relies the least on demographic identity as an indicator of toxic comments. ERM, on the other hand, has the highest average accuracy and the lowest worst-group accuracy. This indicates that it achieves good average performance by leveraging the spurious correlation between toxic comments and the mention of certain demographic groups.

Note that, different from all other datasets in WILDS that focus on pure domain generalization (i.e., no overlap between domains in the train and test splits), CIVILCOMMENTS-WILDS is a subpopulation shift problem, where the domains in test are a subpopulation of the domains in train. As a result, the baseline models used in WILDS for this dataset are different from the methods used for all other datasets, and are tailored to avoiding systematic failure on data from minority subpopulations. Fish works well in this setting too, without any changes or special sampling strategies (unlike some of the baselines in Table 10). This further demonstrates the good performance of our algorithm across different domain generalization scenarios.
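A minimal NumPy sketch of the worst-group accuracy described above; the array-based interface is an assumption, and each example is simplified to a single demographic identity, whereas in WILDS a comment can mention several.

```python
import numpy as np

def worst_group_accuracy(y_true, y_pred, identity):
    """Worst-group accuracy (sketch): split each demographic identity into two
    groups by the true toxicity label (8 x 2 = 16 groups for CIVILCOMMENTS) and
    report the lowest per-group accuracy. All arguments are 1-D arrays of equal
    length; `identity` holds one demographic identity per example (a simplification)."""
    accs = []
    for ident in np.unique(identity):
        for label in (0, 1):                                   # not toxic / toxic
            mask = (identity == ident) & (y_true == label)
            if mask.any():
                accs.append(float(np.mean(y_pred[mask] == y_true[mask])))
    return min(accs)

# Example (hypothetical data):
# y_true = np.array([0, 1, 0, 1]); y_pred = np.array([0, 1, 1, 1])
# ident  = np.array(["black", "black", "white", "white"])
# worst_group_accuracy(y_true, y_pred, ident)  # -> 0.0 (the ("white", not toxic) group)
```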
C.5 IWILDCAM-WILDS

Task: Animal species classification (186 classes). Domains: 324 camera locations.

The dataset consists of over 200,000 photos of animals in the wild, captured by stationary cameras at 324 locations. Classifying animal species from these heat- or motion-activated photos is especially challenging: methods can easily rely on the background information of photos from the same camera setup. Fish models are pretrained with ERM until convergence, and for each inner loop we sample from N = 10 domains.

Evaluation: Macro F1 score. Across the 186 class labels, we report average accuracy and both weighted and macro F1 scores (F1-w and F1-m, respectively) in Table 11. We run 3 random seeds for each model.

Results: All models reported in Table 11 are trained using a ResNet-50. We find that Fish outperforms the baselines on both test accuracy and weighted F1, with a 1% improvement on both metrics over the best-performing model (ERM). However, this comes at the cost of a lower macro F1 score, where Fish performs 1% worse than the ERM models we trained and 3% worse than the ERM reported in WILDS. This suggests that Fish is less good at classifying rarer species, although the overall accuracy on the test dataset is improved. Although Fish does not outperform the ERM baseline on the primary evaluation metric proposed in Koh et al. (2020), we found the improvement of Fish in both accuracy and weighted F1 to be robust across a range of hyperparameters. See more details on this in Appendix D.

C.6 AMAZON-WILDS

Task: Sentiment analysis (5 classes). Domains: 7,676 Amazon reviewers.

The dataset contains 1.4 million customer reviews on Amazon from 7,676 customers, and the task is to predict the score (1-5 stars) given the review. Similarly, we pretrain the model with ERM until convergence, and due to the large number of domains (S = 5,008 in train) we sample N = 5 reviewers for each inner loop.

Evaluation: 10th percentile accuracy. Reporting the accuracy of the 10th percentile reviewer helps us assess whether the model's performance is consistent across different reviewers (a small sketch of this metric follows below). The results in Table 12 are reported over 3 random seeds.

Results: The model is trained using a DistilBERT (Sanh et al., 2019) backbone. While Fish has lower average accuracy than ERM, its 10th percentile accuracy matches that of ERM, outperforming all other baselines.
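A minimal NumPy sketch of the 10th percentile reviewer accuracy used above; the array names and the percentile interpolation behaviour of `np.percentile` are assumptions about how one might implement it, not the benchmark's code.

```python
import numpy as np

def tenth_percentile_accuracy(y_true, y_pred, reviewer_id):
    """10th-percentile reviewer accuracy (sketch): compute the accuracy for each
    reviewer separately, then report the 10th percentile over reviewers."""
    per_reviewer = [
        float(np.mean(y_pred[reviewer_id == r] == y_true[reviewer_id == r]))
        for r in np.unique(reviewer_id)
    ]
    return float(np.percentile(per_reviewer, 10))
```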
Table 11: Results on IWILDCAM-WILDS.

| Method | Test ID Macro F1 (%) | Test ID Avg Acc (%) | Test OOD Macro F1 (%) | Test OOD Avg Acc (%) |
|---|---|---|---|---|
| Fish | 40.3 (± 0.6) | 73.8 (± 0.1) | 22.0 (± 1.8) | 64.7 (± 2.6) |
| Group DRO | 37.5 (± 1.7) | 71.6 (± 2.7) | 23.9 (± 2.1) | 72.7 (± 2.0) |
| ERM | 47.0 (± 1.4) | 75.7 (± 0.3) | 31.0 (± 1.3) | 71.6 (± 2.5) |
| Coral | 43.5 (± 3.5) | 73.7 (± 0.4) | 32.8 (± 0.1) | 73.3 (± 4.3) |
| IRM | 22.4 (± 7.7) | 59.9 (± 8.1) | 15.1 (± 4.9) | 59.8 (± 3.7) |

Table 12: Results on AMAZON-WILDS.

| Method | Val. Avg (%) | Val. 10th perc. (%) | Test Avg (%) | Test 10th perc. (%) |
|---|---|---|---|---|
| Fish | 72.5 (± 0.0) | 54.0 (± 0.0) | 71.7 (± 0.0) | 53.3 (± 0.0) |
| ERM | 72.7 (± 0.1) | 55.2 (± 0.7) | 71.9 (± 0.1) | 53.8 (± 0.0) |
| IRM | 71.5 (± 0.3) | 54.2 (± 0.8) | 70.5 (± 0.3) | 52.4 (± 0.8) |
| Reweighted | 69.1 (± 0.0) | 52.1 (± 0.2) | 68.6 (± 0.6) | 52.0 (± 0.0) |

D HYPERPARAMETERS

In Table 13 we list the hyperparameters used to train ERM. The same hyperparameters were used both for producing the ERM baseline results and for the pretrained models used by Fish. In "Val. metric" we report the validation metric used for model selection, and in "Cut-off" we specify when to stop training when using ERM to generate pretrained models.

Table 13: Hyperparameters for ERM. We follow the hyperparameters used in the WILDS benchmark. Note that we did not use a pretrained model for POVERTY, therefore its cut-off condition is not reported.

| Dataset | Model | Learning rate | Batch size | Weight decay | Optimizer | Val. metric | Cut-off |
|---|---|---|---|---|---|---|---|
| CAMELYON17 | DenseNet-121 | 1e-3 | 32 | 0 | SGD | acc. avg. | iter 500 |
| CIVILCOMMENTS | DistilBERT | 1e-5 | 16 | 0.01 | Adam | acc. wg. | Best val. metric |
| FMOW | DenseNet-121 | 1e-4 | 64 | 0 | Adam | acc. avg. | Best val. metric |
| IWILDCAM | ResNet-50 | 1e-4 | 16 | 0 | Adam | F1-macro (all) | Best val. metric |
| POVERTY | ResNet-18 | 1e-3 | 64 | 0 | Adam | Worst-U/R Pearson (r) | - |
| AMAZON | DistilBERT | 2e-6 | 8 | 0.01 | Adam | 10th percentile acc. | - |

In Table 14 we list the hyperparameters used to train Fish. Note that we train Fish using the same model, batch size, validation metric and optimizer as ERM; these are not repeated in Table 14 to avoid repetition. Weight decay is always set to 0.

Table 14: Hyperparameters for Fish.

| Dataset | Group by | α | ϵ | # domains | Meta steps |
|---|---|---|---|---|---|
| CAMELYON17 | Hospitals | 1e-3 | 0.01 | 3 | 3 |
| CIVILCOMMENTS | Demographics × toxicity | 1e-5 | 0.05 | 16 | 5 |
| FMOW | Time × regions | 1e-4 | 0.01 | 80 | 5 |
| IWILDCAM | Trap locations | 1e-4 | 0.01 | 324 | 10 |
| POVERTY | Countries | 1e-3 | 0.1 | 23 | 5 |
| AMAZON | Reviewers | 2e-6 | 0.01 | 7,676 | 5 |

E ABLATION STUDIES ON PRE-TRAINED MODELS

In this section we perform an ablation study on the level of convergence of the pretrained ERM models. Note that the pretraining is done on the same domain generalization dataset that Fish is subsequently trained on, not on ImageNet. We study the performance of Fish with the following three configurations of pretrained ERM models:

1. Model trained on 10% of the data (epoch 1);
2. Model trained on 50% of the data (epoch 1);
3. Model at convergence.

By comparing the results of these three settings, we demonstrate how the level of convergence affects Fish's training performance. See results in Table 15. Note that POVERTY is excluded here because the dataset is small enough that we are able to train Fish from scratch.

Table 15: Ablation study on pretrained ERM models.

| Model | FMOW (Test Avg Acc) | CAMELYON17 (Test Avg Acc) | IWILDCAM (Test Macro F1) | CIVILCOMMENTS (Test Worst Acc) |
|---|---|---|---|---|
| 10% data | 21.7 (± 2.5) | 79.1 (± 12.3) | 13.7 (± 0.5) | 71.8 (± 1.3) |
| 50% data | 31.0 (± 0.8) | 64.6 (± 12.3) | 19.0 (± 0.06) | 74.2 (± 0.5) |
| Converged | 32.7 (± 1.2) | 63.5 (± 8.2) | 23.7 (± 0.9) | 73.8 (± 1.8) |

We see that CIVILCOMMENTS sustains good performance using pretrained models at different convergence levels. FMOW and IWILDCAM, on the other hand, seem to have a strong preference for converged models, with results worsening as the amount of data seen during pretraining decreases. CAMELYON17 achieves the best performance when only 10% of the data has been seen, and the test accuracy decreases when starting from models with higher levels of convergence.

F ABLATION STUDIES ON HYPERPARAMETERS

α and ϵ. We study the effect of changing Fish's inner-loop learning rate α and outer-loop learning rate ϵ. To make the comparisons more meaningful, we keep the product α·ϵ constant while changing their respective values. See results in Figure 8.

Meta steps N. For most of the datasets we studied (all apart from CAMELYON17, where T = 3), we sample an N-sized subset of all T domains available for training (see Table 14 for the value of T for each dataset). Here we study N = 5, 10, 20.

Table 16: Ablation study on meta steps N.

| N | FMOW (Test Avg Acc) | POVERTY (Test Pearson r) | IWILDCAM (Test Macro F1) | CIVILCOMMENTS (Test Worst Acc) |
|---|---|---|---|---|
| 5 | 33.0 (± 1.6) | 80.3 (± 1.7) | 23.7 (± 0.9) | 74.3 (± 1.5) |
| 10 | 32.7 (± 1.2) | 80.0 (± 0.8) | 23.7 (± 0.5) | 73.4 (± 1.0) |
| 20 | 33.3 (± 2.1) | 77.7 (± 2.1) | 23.7 (± 0.9) | 72.6 (± 2.3) |

In general, altering these hyperparameters does not have a large impact on the model's performance; however, when N = 20 the performance on some datasets (POVERTY and CIVILCOMMENTS) degrades slightly.

Figure 8: Ablation studies on α and ϵ (panels include (a) CAMELYON17, (b) CIVILCOMMENTS, (d) IWILDCAM and (e) POVERTY). Note that α·ϵ remains constant in all experiments, and the midpoint of each plot is the hyperparameter setting we chose for reporting our experimental results.
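For illustration, a tiny sketch of how a constant-product (α, ϵ) sweep around a chosen midpoint could be generated; the scaling factors below are arbitrary examples, not the grid actually used for Figure 8.

```python
def constant_product_grid(alpha_mid, eps_mid, factors=(0.25, 0.5, 1.0, 2.0, 4.0)):
    """Return (alpha, eps) pairs whose product alpha * eps stays constant."""
    return [(alpha_mid * f, eps_mid / f) for f in factors]

# constant_product_grid(1e-3, 0.01) ->
# [(0.00025, 0.04), (0.0005, 0.02), (0.001, 0.01), (0.002, 0.005), (0.004, 0.0025)]
```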
G TRACKING GRADIENT INNER PRODUCT

In Figure 9, we show the progression of inter-domain gradient inner products during training under different objectives. We train both Fish (blue) and ERM (yellow) until convergence while recording the normalized gradient inner products (i.e. cosine similarities) between minibatches from different domains used in each inner loop. The gradient inner products are computed both before (dotted) and after (solid) the model's update. To ensure a fair comparison, we use the exact same sequence of data for Fish and ERM (see Appendix I for more details).

Figure 9: Gradient inner product values during training for CDSPRITES-N (N = 15) and 5 different WILDS datasets (panels include (a) CDSPRITES-N, (b) CAMELYON17, (c) CIVILCOMMENTS, (e) POVERTY and (f) IWILDCAM).

Inevitably, the gradient inner product trends differ across datasets, since the data, the types of domain splits and the choice of architecture are very different. In fact, the plots for CDSPRITES-N and POVERTY are significantly different from the others, with a dip in the gradient inner product at the beginning of training; this is because these are the two datasets that we train from scratch. On all other datasets, the gradient inner products are recorded while fine-tuning with Fish. Despite their differences, there are some important commonalities between these plots: if we compare the pre-update (dotted) to the post-update (solid) curves, we can see that ERM updates often result in a decrease of the gradient inner product, while for Fish it either increases significantly (Figure 9c and Figure 9e) or at least stays at the same level (Figure 9a, Figure 9b, Figure 9d and Figure 9f). As a result, the post-update gradient inner product of Fish is always greater than that of ERM at convergence. These observations show that Fish is effective at increasing or maintaining the level of inter-domain gradient inner product, and shed some light on its effectiveness on the datasets we studied.

H SUMMARY OF DOMAINBED

DOMAINBED is a testbed for domain generalization that implements consistent experimental protocols across state-of-the-art methods to ensure fair comparison. It contains 7 popular domain generalization datasets, including Colored MNIST (Arjovsky et al., 2019), Rotated MNIST (Ghifary et al., 2015), VLCS (Fang et al., 2013), PACS (Li et al., 2017), OfficeHome (Venkateswara et al., 2017), Terra Incognita (Beery et al., 2018) and DomainNet (Peng et al., 2019), and offers comparison to a variety of state-of-the-art domain generalization methods, including IRM (Arjovsky et al., 2019), Group DRO (Hu et al., 2018; Sagawa et al., 2019), Mixup (Yan et al., 2020), MLDG (Li et al., 2018a), Coral (Sun and Saenko, 2016), MMD (Li et al., 2018b), DANN (Ganin et al., 2016) and CDANN (Li et al., 2018).

I ALGORITHM FOR TRACKING GRADIENT INNER PRODUCT

To make sure that the gradients we record for ERM and Fish are comparable, we use the same sequence of S minibatches to train both algorithms. See Algorithm 6 for details.

Algorithm 5 Function GIP
1: function GIP({d_1, d_2, ..., d_N}, θ):
2:   for d_n ∈ {d_1, d_2, ..., d_N} do
3:     g_n = ∂l(d_n; θ)/∂θ
4:   end for
5:   g = 1/(N(N−1)) Σ_{i≠j} ⟨g_i, g_j⟩
6:   return g
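A minimal PyTorch sketch of the GIP function in Algorithm 5; the list-of-(x, y)-batches interface is an assumption, and the returned value is the unnormalized average pairwise inner product (the curves in Figure 9 use the cosine-normalized version).

```python
import torch

def gip(batches, model, loss_fn):
    """Average pairwise gradient inner product over per-domain minibatches (sketch).
    `batches` is a list of (x, y) minibatches, one per domain."""
    params = [p for p in model.parameters() if p.requires_grad]
    grads = []
    for x, y in batches:
        g = torch.autograd.grad(loss_fn(model(x), y), params)
        grads.append(torch.cat([g_i.reshape(-1) for g_i in g]))   # flatten per-domain gradient
    n = len(grads)
    total = sum(torch.dot(grads[i], grads[j])
                for i in range(n) for j in range(n) if i != j)
    return total / (n * (n - 1))
```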
Algorithm 6 Algorithm for collecting the gradient inner product g for Fish and ERM, both before and after updates. See GIP in Algorithm 5.
1: Initialize Fish θ_f ← θ, ERM θ_e ← θ
2: for i = 1, 2, ... do
3:   // Get all minibatches
4:   for D_n ∈ {D_1, D_2, ..., D_N} do
5:     Sample batch d_n ∼ D_n
6:   end for
7:   // Gradient inner product before update
8:   g_Fb = GIP({d_1, d_2, ..., d_N}, θ_f)
9:   g_Eb = GIP({d_1, d_2, ..., d_N}, θ_e)
10:  // Fish training
11:  θ̃ ← θ_f
12:  for d_n ∈ {d_1, d_2, ..., d_N} do
13:    g_n = ∂l(d_n; θ̃)/∂θ̃
14:    Update θ̃ ← θ̃ − α g_n
15:  end for
16:  θ_f ← θ_f + ϵ(θ̃ − θ_f)
17:  // Rearrange minibatches
18:  d = shuffle(concat(d_1, d_2, ..., d_N))
19:  {d̃_1, d̃_2, ..., d̃_N} = split(d)
20:  // ERM training
21:  for d̃_n ∈ {d̃_1, d̃_2, ..., d̃_N} do
22:    g_n = ∂l(d̃_n; θ_e)/∂θ_e
23:    Update θ_e ← θ_e − α g_n
24:  end for
25:  // Gradient inner product after update
26:  g_Fa = GIP({d_1, d_2, ..., d_N}, θ_f)
27:  g_Ea = GIP({d_1, d_2, ..., d_N}, θ_e)
28: end for
29: Return g_Fb, g_Fa, g_Eb, g_Ea

J T-SNE VISUALISATION OF LEARNED REPRESENTATIONS

In this section we visualize the representations learned by Fish and ERM with t-SNE (van der Maaten and Hinton, 2008) for the PACS, VLCS and OfficeHome datasets. See Figure 10. We can see that Fish and ERM are both capable of forming distinctive label clusters on PACS (first row) and VLCS (second row); however, with ERM we observe sub-clusters of domains forming within each label cluster. This is particularly the case for the red and yellow clusters in Figure 10a and the cyan cluster in Figure 10c. This domain-clustering phenomenon is not observed for Fish. On the other hand, for the OfficeHome dataset, where Fish outperforms ERM by more than 1%, we clearly see that Fish exhibits better label clustering than ERM.

Figure 10: t-SNE plots for PACS, VLCS and OfficeHome ((a) PACS, ERM; (b) PACS, Fish; (c) VLCS, ERM; (d) VLCS, Fish; (e) OfficeHome, ERM; (f) OfficeHome, Fish). Colors represent labels; markers (the shape of each data point) represent domains.
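A minimal scikit-learn/matplotlib sketch of how such a visualization can be produced; the `features`, `labels` and `domains` inputs are assumed to be extracted beforehand from the trained backbone, and all plotting choices here are illustrative rather than those used for Figure 10.

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(features, labels, domains, title):
    """2-D t-SNE of penultimate-layer features; color = class label, marker = domain.
    `features` is an (n_samples, d) array; `labels` and `domains` are length-n arrays."""
    emb = TSNE(n_components=2, init="pca", random_state=0).fit_transform(features)
    markers = ["o", "s", "^", "D", "v", "P", "*"]
    for i, dom in enumerate(sorted(set(domains))):
        idx = [k for k, d in enumerate(domains) if d == dom]
        plt.scatter(emb[idx, 0], emb[idx, 1], c=[labels[k] for k in idx],
                    cmap="tab10", marker=markers[i % len(markers)], s=10, label=str(dom))
    plt.title(title)
    plt.legend(markerscale=2, fontsize=8)
    plt.show()
```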