# Topology-Aware Robust Optimization for Out-of-Distribution Generalization

Published as a conference paper at ICLR 2023

Fengchun Qiao (University of Delaware, fengchun@udel.edu) and Xi Peng (University of Delaware, xipeng@udel.edu)

Out-of-distribution (OOD) generalization is a challenging machine learning problem, yet it is highly desirable in many high-stakes applications. Existing methods suffer from overly pessimistic modeling with low generalization confidence. As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial in developing strong OOD resilience. To this end, we propose topology-aware robust optimization (TRO), which seamlessly integrates distributional topology into a principled optimization framework. More specifically, TRO solves two optimization objectives: (1) Topology Learning, which explores the data manifold to uncover the distributional topology; and (2) Learning on Topology, which exploits the topology to constrain robust optimization for tightly bounded generalization risks. We theoretically demonstrate the effectiveness of our approach, and empirically show that it significantly outperforms the state of the art in a wide range of tasks including classification, regression, and semantic segmentation. Moreover, we empirically find that the data-driven distributional topology is consistent with domain knowledge, enhancing the explainability of our approach.

## 1 Introduction

Recent years have witnessed a surge of applying machine learning (ML) in high-stakes and safety-critical applications. Such applications pose an unprecedented out-of-distribution (OOD) generalization challenge: ML models are constantly exposed to unseen distributions that lie outside their training space.
Despite well-documented success at interpolation, modern ML models (e.g., deep neural networks) are notoriously weak at extrapolation. A model that is highly accurate on average can fail catastrophically when presented with rare or unseen distributions (Arjovsky et al., 2019). For example, a flood predictor trained on data from all 89 major flood events in the U.S. from 2000 to 2020 would make erroneous predictions for Hurricane Ida in 2021. Without addressing this challenge, it is unclear when and where a model can be applied and how much risk is associated with its use.

A promising solution for OOD generalization is distributionally robust optimization (DRO) (Namkoong & Duchi, 2016; Staib & Jegelka, 2019; Levy et al., 2020). DRO minimizes the worst-case expected risk over an uncertainty set of potential test distributions. The uncertainty set is typically formulated as a divergence ball surrounding the training distribution, endowed with a distance metric such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). Empirical risk minimization (ERM) (Vapnik, 1998) minimizes the average loss. Compared to ERM, DRO is more robust against distributional drifts arising from spurious correlations, adversarial attacks, subpopulations, or naturally occurring variation (Robey et al., 2021). However, it is non-trivial to build a realistic uncertainty set that truly approximates unseen distributions. On the one hand, to confer robustness against extensive distributional drifts, the uncertainty set has to be sufficiently large. This increases the risk of including implausible distributions, e.g., outliers, and thus yields overly pessimistic models with low prediction confidence (Hu et al., 2018; Frogner et al., 2021).
On the other hand, the worst-case distributions are not necessarily the influential ones that are truly connected to unseen distributions. Optimizing over worst-case rather than influential distributions would yield compromised OOD resilience.

As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial in constructing a realistic uncertainty set. More specifically, we propose topology-aware robust optimization (TRO) by integrating two optimization objectives:

- (1) Topology learning: we model the data distributions as many discrete groups lying on a common low-dimensional manifold. On this manifold, we can explore the distributional topology either by using physical priors or by measuring the multiscale Earth Mover's Distance (EMD) among distributions.
- (2) Learning on topology: the acquired distributional topology is then exploited to construct a realistic uncertainty set. Robust optimization is constrained to bound the generalization risk within a topology graph, rather than blindly generalizing to unseen distributions.

Our contributions include:

1. A new principled optimization method that seamlessly integrates topological information to develop strong OOD resilience.
2. A theoretical analysis proving that our method enjoys fast convergence for both convex and non-convex loss functions while the generalization risk is tightly bounded.
3. Empirical results in a wide range of tasks, including classification, regression, and semantic segmentation, that demonstrate the superior performance of our method over state-of-the-art methods.
4. A data-driven distributional topology that is consistent with domain knowledge and facilitates the explainability of our approach.

The source code and pre-trained models are available at: https://github.com/joffery/TRO.
## 2 Problem Formulation and Preliminary Works

The problem of out-of-distribution (OOD) generalization is defined by a pair of random variables $(X, Y)$ over instances $x \in \mathcal{X} \subseteq \mathbb{R}^d$ and corresponding labels $y \in \mathcal{Y}$. These follow an unknown joint probability distribution $P(X, Y)$. The objective is to learn a predictor $f \in \mathcal{F}$ such that $f(x) \approx y$ for any $(x, y) \sim P(X, Y)$. Here $\mathcal{F}$ is a function class that is model-agnostic for a prediction task.

However, unlike in typical supervised learning, OOD generalization is complicated because one cannot sample directly from $P(X, Y)$. Instead, it is assumed that we can only measure $(X, Y)$ under different environmental conditions $e$. Data is therefore drawn from a set of groups $\mathcal{E}_{all}$ such that $(x, y) \sim P_e(X, Y)$. For example, in flood prediction, these environmental conditions denote the latent factors (e.g., stressors, precipitation, terrain) that underlie different flood events. Let $\mathcal{E}_{train} \subset \mathcal{E}_{all}$ be a finite subset of training groups (distributions). Given the loss function $\ell$, an OOD-resilient model $f$ can be learned by solving a minimax optimization:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \sup_{e \in \mathcal{E}_{all}} \mathbb{E}_{(x,y) \sim P_e(X,Y)}\big[\ell(f(x), y)\big] \Big\}. \tag{1}$$

Intuitively, Eq. 1 aims to learn a model that minimizes the worst-case risk over the entire family $\mathcal{E}_{all}$. This is nontrivial because we do not have access to data from any unseen distributions $\mathcal{E}_{test} = \mathcal{E}_{all} \setminus \mathcal{E}_{train}$.

**Empirical Risk Minimization (ERM).** Classic supervised learning typically employs ERM (Vapnik, 1998) to find a model $f$ that minimizes the average risk under the training distribution $P_{tr}$:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \mathbb{E}_{(x,y) \sim P_{tr}}\big[\ell(f(x), y)\big] \Big\}.$$

ERM has proved effective in i.i.d. settings. However, models trained via ERM rely heavily on spurious correlations that do not always hold under distributional drifts (Arjovsky et al., 2019).

**Distributionally Robust Optimization (DRO).** To develop OOD resilience, DRO (Namkoong & Duchi, 2016) minimizes the worst-case risk over distributions $Q$ in an uncertainty set $\mathcal{P}(P_{tr})$ by solving:

$$\min_{f \in \mathcal{F}} \Big\{ R(f) := \sup_{Q \in \mathcal{P}(P_{tr})} \mathbb{E}_{(x,y) \sim Q}\big[\ell(f(x), y)\big] \Big\}. \tag{2}$$

Here the uncertainty set $\mathcal{P}(P_{tr})$ approximates potential test distributions. It is usually formulated as a divergence ball of radius $\rho$ surrounding the training distribution, $\mathcal{P}(P_{tr}) = \{Q : D(Q, P_{tr}) \le \rho\}$. The ball is endowed with a distance metric $D(\cdot, \cdot)$ such as f-divergence (Namkoong & Duchi, 2016) or Wasserstein distance (Shafieezadeh Abadeh et al., 2018). To construct a realistic uncertainty set without being overly conservative, Group DRO was further developed; it formulates the uncertainty set as a mixture of training groups (Hu et al., 2018; Sagawa et al., 2019).

Despite their well-documented success, existing DRO methods suffer from two critical limitations:

- (1) To endow robustness against a wide range of potential test distributions, the radius of the divergence ball has to be sufficiently large. This carries a high risk of containing implausible distributions. Optimizing for implausible distributions fundamentally damages OOD resilience by yielding overly pessimistic models with low prediction confidence.
- (2) The worst-case groups are not necessarily the influential ones that are truly connected to unseen distributions. Optimizing over worst-case rather than influential groups yields compromised OOD resilience.

## 3 Topology-Aware Robust Optimization

We propose a new principled optimization method (TRO) to develop OOD resilience. It integrates topology and optimization via a two-phase scheme: Topology Learning and Learning on Topology.

### 3.1 Topology Learning: Explore the Distributional Topology

Figure 1: Overview of topology-aware distributionally robust optimization (TRO). (Panels: distributed datasets; physical graph and physical-based topology; multiscale diffusion density estimates; diffusion EMD and data-driven topology.)

We model the data groups $\mathcal{E}_{all}$ as many discrete distributions lying on a common low-dimensional manifold in a high-dimensional data measurement space. In such a case, their structure, i.e.,
the distributional topology, can be naturally captured by a graph $G = (V, E)$. The entities $V = \bigcup_{e \in \mathcal{E}_{all}} X_e$ symbolize the groups, and the edges $E$ represent interactions among groups. The topology graph is constructed in two steps:

- (1) Identifying entities: we assume the entities are defined by the given group identities.
- (2) Uncovering interactions: we consider two scenarios for measuring the connectivity between discrete distributions, as illustrated in Fig. 1.

**Physical-based distributional topology.** When distributional adjacency information is available, we can directly acquire the topology $G_{physic}$ by imposing the predefined neighborhood information. For example, to capture the similarity of weather events in the U.S., one can construct a graph in which each state is an entity and the physical adjacency between two states results in an edge (see Fig. 1). In this case, $G_{physic}$ functions as a physical prior that constrains the robust optimization introduced in Sec. 3.2. We empirically find that $G_{physic}$ yields an accuracy improvement of 9.56 percentage points over the best baseline on DG-15 (Sec. 5.1).

**Data-driven distributional topology.** In the absence of $G_{physic}$, we propose a data-driven approach to learn the topology $G_{data}$ from training data. Specifically, we embed the individual groups onto a shared data graph based on an affinity matrix of the combined data. Following Leeb & Coifman (2016), such a data graph can be viewed as a discretization of an underlying closed Riemannian manifold. We simulate a time-dependent diffusion process over the graph to obtain density estimates at multiple scales for each group. These estimates are then used to calculate $\ell_1$ distances between pairs of groups. Tong et al. (2021) proved that this multiscale $\ell_1$ distance is topologically equivalent to the Earth Mover's Distance (EMD) on the manifold geodesic. It also reduces the computational complexity from $O(m^2 n^3)$ to $O(mn)$ for $m$ distributions over $n$ data points.
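The overview above combines several pieces: an affinity over learned features, a diffusion operator, multiscale densities, and an $\ell_1$ comparison across scales. A small dense sketch in plain Python can illustrate them; it follows the three steps described below and Eq. 3. This is a hedged illustration, not the authors' implementation, and several choices are assumptions:

- The helper names (`rbf_affinity`, `diffusion_operator`, `multiscale_densities`, `diffusion_emd`) are hypothetical.
- The 1-D toy features stand in for ERM-model features $f(x)$.
- The values of $\sigma^2$, $K$, and $\alpha$ are arbitrary.
- The dense lists cost $O(n^2)$ per operator application, so this sketch does not achieve the $O(mn)$ complexity quoted above.
- The text does not specify how the pairwise distances become the graph $G_{data}$ (e.g., thresholding or nearest neighbors), so the sketch stops at the distance table.

```python
import math


def rbf_affinity(feats, sigma2):
    """K_ij = exp(-||f(x_i) - f(x_j)||^2 / sigma^2) over precomputed feature vectors."""
    return [[math.exp(-sum((a - b) ** 2 for a, b in zip(fi, fj)) / sigma2) for fj in feats]
            for fi in feats]


def diffusion_operator(aff):
    """P = D^{-1} M with M = Q^{-1} K Q^{-1}, where Q and D hold row sums of K and M."""
    n = len(aff)
    q = [sum(row) for row in aff]
    M = [[aff[i][j] / (q[i] * q[j]) for j in range(n)] for i in range(n)]
    d = [sum(row) for row in M]
    return [[M[i][j] / d[i] for j in range(n)] for i in range(n)]


def multiscale_densities(P, members, K):
    """mu_e^(t) = (1/n_e) P^t 1_{X_e} for t = 2^0, ..., 2^K (K is the maximum scale).

    Entry k of the returned list holds the estimate at diffusion time t = 2^k.
    """
    v = [1.0 / len(members) if i in members else 0.0 for i in range(len(P))]
    out, t = [], 0
    for k in range(K + 1):
        while t < 2 ** k:  # keep applying P until the diffusion time reaches 2^k
            v = [sum(p_ij * v_j for p_ij, v_j in zip(row, v)) for row in P]
            t += 1
        out.append(v)
    return out


def diffusion_emd(mu_a, mu_b, K, alpha):
    """W_{alpha,K} of Eq. 3: l1 distances between the T_{alpha,k} terms, summed over k."""
    def T(mu, k):
        if k == K:
            return mu[K]
        w = 2.0 ** (-(K - k - 1) * alpha)
        return [w * (hi - lo) for hi, lo in zip(mu[k + 1], mu[k])]
    return sum(sum(abs(x - y) for x, y in zip(T(mu_a, k), T(mu_b, k)))
               for k in range(K + 1))


# Toy 1-D "features": groups A and B interleave, group C lies far away.
feats = [[0.0], [0.2], [0.1], [0.3], [5.0], [5.2]]
groups = {"A": {0, 1}, "B": {2, 3}, "C": {4, 5}}
P = diffusion_operator(rbf_affinity(feats, sigma2=1.0))
mu = {g: multiscale_densities(P, idx, K=3) for g, idx in groups.items()}
W = {(a, b): diffusion_emd(mu[a], mu[b], K=3, alpha=0.5) for a in groups for b in groups}
print(round(W[("A", "B")], 4), round(W[("A", "C")], 4))
```

With two interleaved groups (A, B) and one distant group (C), the A-B distance should come out much smaller than the A-C distance. That ordering is the behavior a topology graph built from these distances relies on.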
We obtain the data-driven topology in three steps.

(1) **Data graph construction.** We construct a data graph through an affinity matrix $K$ of the combined data. $K$ can be implemented with kernel functions (e.g., the RBF kernel) that capture the similarity of data. We do not compute this similarity on raw data. Instead, we compute it on features extracted from an ERM-trained model, because such features capture spurious correlations that preserve group identity (Creager et al., 2021). Specifically, we define the affinity matrix as $K_{i,j} = \exp\big(-\|f(x_i) - f(x_j)\|^2 / \sigma^2\big)$, where $\sigma^2$ is the kernel scale.

(2) **Multiscale diffusion density estimation.** To simulate the diffusion process over the graph, we obtain a Markov diffusion operator $P$ from $K$. Following Coifman & Lafon (2006), we normalize the affinity matrix as $M = Q^{-1} K Q^{-1}$, where $Q$ is a diagonal matrix with $Q_{i,i} = \sum_j K_{i,j}$. The diffusion operator is defined as $P = D^{-1} M$, where $D$ is a diagonal matrix with $D_{i,i} = \sum_j M_{i,j}$. The operator $P$ is used to approximate the multiscale density estimates $\mu_e$ for each data group $X_e$:

$$\mu_e^t = \frac{1}{n_e} P^t \mathbf{1}_{X_e},$$

where $t$ is the diffusion time, $P^t$ denotes the $t$-th power of $P$, and $\mathbf{1}_{X_e}$ is the indicator function for group $e$. Intuitively, $P^t_{i,j}$ sums the probabilities of all possible paths of length $t$ between $x_i$ and $x_j$. By taking multiple powers of $P$, $\mu_e$ reveals the topological structure of $X_e$ at multiple scales.

(3) **Diffusion EMD measurement.** Following Tong et al. (2021), we measure the geodesic distance $W_{\alpha,K}(X_e, X_{e'})$ between $X_e$ and $X_{e'}$ by aggregating the $\ell_1$ distances between the multiscale density estimates:

$$W_{\alpha,K}(X_e, X_{e'}) = \sum_{k=0}^{K} \big\| T_{\alpha,k}(X_e) - T_{\alpha,k}(X_{e'}) \big\|_1, \tag{3}$$

where $\alpha$ balances long- and short-range distances, $K$ is the maximum scale, and

$$T_{\alpha,k}(X_e) = \begin{cases} 2^{-(K-k-1)\alpha}\big(\mu_e^{(2^{k+1})} - \mu_e^{(2^k)}\big), & k < K, \\ \mu_e^{(2^K)}, & k = K. \end{cases}$$

$G_{data}$ is computationally more expensive than $G_{physic}$. Even so, our experimental results in Sec. 5.2 indicate that optimizing with $G_{data}$ can yield improved OOD resilience. In addition, the ablation study in Sec. 5.3 indicates that $G_{data}$ is consistent with domain knowledge and enhances the explainability of TRO. Last but not least, the data-driven method is fully differentiable. This makes it amenable to conducting topology learning and learning on topology jointly in an end-to-end manner, which we leave as future work.

### 3.2 Learning on Topology: Exploit Topology for Robust Optimization

**Algorithm 1: TRO.** Input: data of $\mathcal{E}_{train}$; step sizes $\eta_\theta$ and $\eta_q$. Output: learned model $f$.

1. *Topology Learning.*
   - If $G_{physic}$ exists, set $G \leftarrow G_{physic}$.
   - Otherwise:
     - obtain the affinity matrix $K$ from data;
     - compute $Q = \mathrm{Diag}(\sum_j K_{ij})$, $M = Q^{-1} K Q^{-1}$, $D = \mathrm{Diag}(\sum_j M_{ij})$, and $P = D^{-1} M$;
     - obtain $G_{data}$ via Eq. 3 and set $G \leftarrow G_{data}$.
2. *Learning on Topology.*
   - Calculate the topological prior $p$ from $G$.
   - While not converged:
     - sample $(x, y) \sim P_e(X, Y)$ for $e \in \mathcal{E}_{train}$;
     - calculate $R(f, q)$ via Eq. 5;
     - update $\theta$ and $q$ via Eq. 6.

Next, we propose a principled method that integrates distributional topology to develop TRO. The key challenge is how to leverage $G$ to construct an uncertainty set that approximates unseen distributions with bounded generalization risk. Our main idea is to assess the group centrality of training distributions. Graph centrality is widely used in social network analysis (Newman, 2005) to measure how much information is propagated through each entity. Here we introduce group centrality to identify influential groups that are truly connected to unseen distributions. It can be calculated using graph measurements (Tian et al., 2019) such as degree, betweenness, and closeness. More specifically, we first calculate the centrality of each entity in $G$ as a topological prior $p$ to identify influential groups. Then, we construct the uncertainty set as an arbitrary mixture of training groups, $\mathcal{Q} := \{\sum_{e \in \mathcal{E}_{train}} q_e P_e \mid q \in \Delta_m\}$. Here $q_e$ denotes the weight of group $e$, $P_e$ is the distribution of group $e$, and $\Delta_m$ is the $(m-1)$-dimensional probability simplex.
Finally, we use the prior $p$ to constrain the uncertainty set $\mathcal{Q}$ by solving the following minimax optimization problem:

$$\min_{f \in \mathcal{F}} \Big\{ R(f, q) := \max_{q \in \Delta_m} \sum_{e \in \mathcal{E}_{train}} q_e \, \mathbb{E}_{(x,y) \sim P_e(X,Y)}\big[\ell(f(x), y)\big] \Big\}, \quad \text{s.t. } D(q \,\|\, p) \le \tau. \tag{4}$$

Intuitively, groups with both high training loss and high centrality are assigned large weights, which tightly bounds the OOD generalization risk within the topology graph. $D$ can be any distributional distance metric; we implement it with the $\ell_2$ distance because of its strong convexity and simplicity. However, solving Eq. 4 often leads to a non-convex problem, in which methods such as stochastic gradient descent (SGD) cannot guarantee constraint satisfaction (Robey et al., 2021). To address this issue, we leverage the Karush–Kuhn–Tucker (KKT) conditions (Boyd et al., 2004). We introduce a Lagrange multiplier to convert the constrained problem into its unconstrained counterpart:

$$\min_{f \in \mathcal{F}} \Big\{ R(f, q) := \max_{q \in \Delta_m} \sum_{e \in \mathcal{E}_{train}} q_e \, \mathbb{E}_{(x,y) \sim P_e(X,Y)}\big[\ell(f(x), y)\big] - \lambda D(q \,\|\, p) \Big\}, \tag{5}$$

where $\lambda$ is the dual variable. Let $\theta \in \Theta$ be the model parameters of $f$. We can then solve the primal-dual problem effectively by alternately updating

$$\theta^{t+1} = \theta^t - \eta_\theta^t \nabla_\theta R(f, q), \qquad q^{t+1} = \mathcal{P}_{\Delta_m}\big(q^t + \eta_q^t \nabla_q R(f, q)\big), \tag{6}$$

where $\eta_\theta^t$ ($\eta_q^t$) is the gradient descent (ascent) step size. $\mathcal{P}_{\Delta_m}(q)$ projects $q$ back onto the simplex $\Delta_m$. The overall TRO procedure is shown in Alg. 1. In Sec. 4, we show that TRO enjoys fast convergence for both convex and non-convex loss functions, while the generalization risk is tightly bounded by the topological constraint. We also demonstrate empirically that TRO achieves strong OOD resilience by striking a good balance between the worst-case and influential groups (see Sec. 5.2).

**Calculation of group centrality.** We use betweenness centrality to measure the centrality of groups. Betweenness centrality measures how often an entity lies on the shortest path between two other entities in the topology.
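A minimal plain-Python sketch can make the learning-on-topology pieces concrete. It is a toy illustration, not the authors' code, and all helper names (`bfs_counts`, `betweenness`, `softmax_prior`, `project_simplex`, `primal_dual`) are hypothetical. The sketch makes the following choices:

- **Betweenness centrality.** It is computed by breadth-first shortest-path counting on an unweighted graph. Contributions are summed over an explicit list of $(s, t)$ pairs, which covers both variants defined next: $c_e^{physic}$ (train × test pairs) and $c_e^{data}$ (train pairs).
- **Assumptions in the centrality.** Pairs where $e$ is an endpoint are skipped, and the data-driven variant uses unordered pairs.
- **Prior.** $p$ is the softmax of the centralities.
- **Eqs. 5 and 6.**
  - A scalar $\theta$ with per-group losses $(\theta - c_e)^2$ stands in for a network.
  - Full-batch gradients replace sampled minibatches.
  - $D(q\|p)$ is taken to be the squared $\ell_2$ distance.
  - $\mathcal{P}_{\Delta_m}$ is the standard sort-based Euclidean projection onto the simplex.

```python
import itertools
import math
from collections import deque


def bfs_counts(adj, src):
    """Hop distances and shortest-path counts from src in an unweighted, undirected graph."""
    dist, cnt = {src: 0}, {src: 1}
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:  # first visit: w is one hop further than u
                dist[w], cnt[w] = dist[u] + 1, 0
                queue.append(w)
            if dist[w] == dist[u] + 1:  # u precedes w on some shortest path
                cnt[w] += cnt[u]
    return dist, cnt


def betweenness(adj, v, pairs):
    """Sum of sigma(s, t | v) / sigma(s, t) over the given (s, t) pairs (v not an endpoint)."""
    dist_v, cnt_v = bfs_counts(adj, v)
    total = 0.0
    for s, t in pairs:
        if v in (s, t):
            continue
        dist_s, cnt_s = bfs_counts(adj, s)
        if t in dist_s and v in dist_s and dist_s[v] + dist_v[t] == dist_s[t]:
            total += cnt_s[v] * cnt_v[t] / cnt_s[t]
    return total


def softmax_prior(c):
    """p_e = exp(c_e) / sum_e' exp(c_e'), shifted by max(c) for numerical stability."""
    top = max(c)
    z = [math.exp(x - top) for x in c]
    s = sum(z)
    return [x / s for x in z]


def project_simplex(v):
    """Euclidean projection onto the probability simplex (sort-based algorithm)."""
    u = sorted(v, reverse=True)
    css, shift = 0.0, 0.0
    for i, ui in enumerate(u, start=1):
        css += ui
        if ui - (css - 1.0) / i > 0:
            shift = (css - 1.0) / i
    return [max(x - shift, 0.0) for x in v]


def primal_dual(centers, p, lam, eta_theta=0.05, eta_q=0.05, steps=2000):
    """Toy version of Eq. 6: both gradients are taken at (theta^t, q^t) each step.

    Assumed objective: R(theta, q) = sum_e q_e (theta - c_e)^2 - lam * ||q - p||_2^2.
    """
    theta, q = 0.0, list(p)
    for _ in range(steps):
        losses = [(theta - c) ** 2 for c in centers]  # per-group risks
        g_theta = sum(qe * 2.0 * (theta - c) for qe, c in zip(q, centers))
        g_q = [loss - 2.0 * lam * (qe - pe) for loss, qe, pe in zip(losses, q, p)]
        theta -= eta_theta * g_theta  # descent on the model parameter
        q = project_simplex([qe + eta_q * g for qe, g in zip(q, g_q)])  # projected ascent
    return theta, q


# Toy physical topology: a path of five groups; groups 0-2 train, groups 3-4 test.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
train, test = [0, 1, 2], [3, 4]
c_physic = [betweenness(adj, e, [(s, t) for s in train for t in test]) for e in train]

# Data-driven variant: pairs drawn from training groups only, on a training-only graph.
adj_train = {0: [1], 1: [0, 2], 2: [1]}
c_data = [betweenness(adj_train, e, list(itertools.combinations(train, 2))) for e in train]

p = softmax_prior(c_physic)
theta, q = primal_dual(centers=[0.0, 1.0, 3.0], p=p, lam=1.0)
print(c_physic, c_data)  # [0.0, 2.0, 4.0] [0.0, 1.0, 0.0]
```

On the five-group path, the training group adjacent to the test region (group 2) receives the largest centrality and hence the largest prior weight. In a real run, the loop body would instead estimate each training group's loss on a minibatch and back-propagate through the model.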
Freeman (1977) shows that entities with higher betweenness centrality have more control over the topology, because more information passes through them. For the physical-based topology $G_{physic}$, we define the centrality of group $e$ as the fraction of shortest paths that pass through it:

$$c_e^{physic} = \sum_{s \in \mathcal{E}_{train},\, t \in \mathcal{E}_{test}} \frac{\sigma(s, t \mid e)}{\sigma(s, t)},$$

where $\sigma(s, t)$ is the number of shortest paths between groups $s$ and $t$ in the graph ($(s,t)$-paths), and $\sigma(s, t \mid e)$ is the number of $(s,t)$-paths that pass through group $e$. Intuitively, $c_e^{physic}$ measures how much information is propagated through $e$ from the start (training) to the end (test).

For the data-driven topology $G_{data}$, the underlying assumption is that training groups with high centrality also exert a strong influence on unseen groups. Instead of sampling group pairs from two separate sets, we sample $(s, t)$ from $\mathcal{E}_{train}$. The centrality is modified as:

$$c_e^{data} = \sum_{s, t \in \mathcal{E}_{train}} \frac{\sigma(s, t \mid e)}{\sigma(s, t)}.$$

We normalize $c_e$ with the softmax function. The prior probability for group $e \in \mathcal{E}_{train}$ is defined as $p_e = \exp(c_e) / \sum_{e' \in \mathcal{E}_{train}} \exp(c_{e'})$.

## 4 Theoretical Analysis

### 4.1 Convergence Analysis

In this section, we show that TRO achieves fast convergence rates for both convex and non-convex loss functions, given appropriate step sizes $\eta_\theta^t$ and $\eta_q^t$. We first state the assumptions. We then give the convergence rate for convex loss functions in Theorem 1 and for non-convex loss functions in Theorem 2.

**Definition 1 (Lipschitz continuity).** A mapping $f : \mathcal{X} \to \mathbb{R}^m$ is $G$-Lipschitz continuous if for any $x, y \in \mathcal{X}$, $\|f(x) - f(y)\| \le G\|x - y\|$.

**Definition 2 (Smoothness).** A function $f : \mathcal{X} \to \mathbb{R}$ is $L$-smooth if it is differentiable on $\mathcal{X}$ and its gradient $\nabla f$ is $L$-Lipschitz continuous, i.e., $\|\nabla f(x) - \nabla f(y)\| \le L\|x - y\|$ for all $x, y \in \mathcal{X}$.

**Assumption 1.** Throughout the paper, we assume that, given $\theta$, the loss function $\ell(f_\theta(x), y)$ is $G$-Lipschitz continuous and $L$-smooth with respect to $x$.

**Convex loss.**
We estimate the convergence rate by the expected number of stochastic gradient computations. To reach a duality gap of $\epsilon$ (Nemirovski et al., 2009), the optimal rate of convergence for solving stochastic min-max problems is $O(1/\epsilon^2)$ if the problem is convex-concave. The duality gap of the pair $(\bar\theta, \bar q)$ is defined as $\max_{q \in \Delta_m} R(\bar\theta, q) - \min_{\theta \in \Theta} R(\theta, \bar q)$. In the case of strong duality, $(\bar\theta, \bar q)$ is optimal iff the duality gap is zero. We show that TRO achieves the optimal $O(1/\epsilon^2)$ rate.

**Theorem 1.** Consider the dual problem in Eq. 5 when the loss function is convex and Assumption 1 holds. Let $\Theta$ be bounded by $R_\Theta$, with $\mathbb{E}\big[\|\nabla_\theta R(\theta, q)\|_2^2\big] \le \hat G_\theta^2$ and $\mathbb{E}\big[\|\nabla_q R(\theta, q)\|_2^2\big] \le \hat G_q^2$. With step sizes $\eta_\theta = 2R_\Theta / (\hat G_\theta \sqrt{T})$ and $\eta_q = 2 / (\hat G_q \sqrt{T})$, the output of TRO satisfies:

$$\mathbb{E}\Big[\max_{q \in \Delta_m} R(\theta_T, q) - \min_{\theta \in \Theta} R(\theta, q_T)\Big] \le \frac{3R_\Theta \hat G_\theta + 3\hat G_q}{\sqrt{T}}.$$

Theorem 1 shows that our method requires $T = O(1/\epsilon^2)$ iterations to reach a duality gap within $\epsilon$. To derive the convergence rate for non-convex functions, we define $\epsilon$-stationary points as follows.

**Definition 3 ($\epsilon$-stationary point).** For a differentiable function $f : \mathcal{X} \to \mathbb{R}$, a point $x \in \mathcal{X}$ is said to be first-order $\epsilon$-stationary if $\|\nabla f(x)\| \le \epsilon$.

**Nonconvex loss.** The loss function $\ell(f_\theta(x), y)$ can be nonconvex, in which case $R(\theta, q)$ is nonconvex in $\theta$. Following Collins et al. (2020), we say that $(\bar\theta, \bar q)$ is an $(\epsilon, \delta)$-stationary point of $R$ if $\|\nabla_\theta R(\bar\theta, \bar q)\|_2 \le \epsilon$ and $R(\bar\theta, \bar q) \ge \max_{q \in \Delta_m} R(\bar\theta, q) - \delta$, where $\epsilon, \delta > 0$.

**Theorem 2.** If Assumption 1 holds and the loss function is bounded by $B$ and is $M$-smooth, the output of Alg. 1 satisfies:

$$\mathbb{E}\big[\|\nabla_\theta R(\theta_T, q_T)\|_2^2\big] \le \frac{R(\theta^0, q^0) + B}{T\eta_\theta} + \frac{2\eta_q \sqrt{n}\, B \hat G_q}{\eta_\theta} + \frac{\eta_\theta M \hat G_\theta^2}{2},$$

$$\mathbb{E}\big[R(\theta_T, q_T)\big] \ge \max_{q \in \Delta_m} \mathbb{E}\big[R(\theta_T, q)\big] - \frac{1}{\eta_q T} - \frac{\eta_q \hat G_q^2}{2}.$$

Theorem 2 shows that our method converges in expectation to an $(\epsilon, \delta)$-stationary point of $R$ in $O(1/\epsilon^4)$ stochastic gradient evaluations.

### 4.2 Generalization Bounds

In this section, we provide learning guarantees for TRO.
Compared to DRO, TRO achieves a lower upper bound on the generalization risk over unseen distributions, thanks to the topological constraint. Let $\mathcal{H}$ denote the family of losses associated with a hypothesis set $\mathcal{F}$: $\mathcal{H} = \{(x, y) \mapsto \ell(f(x), y) : f \in \mathcal{F}\}$. Let $\mathbf{n} = (n_1, \ldots, n_m)$ denote the vector of sample sizes for all training groups. Following Mohri et al. (2019), we define the weighted Rademacher complexity for any $\mathcal{F}$ as:

$$\mathfrak{R}_{\mathbf{n}}(\mathcal{H}, q) = \mathbb{E}_{S_e \sim P_e^{n_e}}\, \mathbb{E}_{\sigma}\Big[\sup_{f \in \mathcal{F}} \sum_{e=1}^{m} \frac{q_e}{n_e} \sum_{i=1}^{n_e} \sigma_{e,i}\, \ell(f(x_{e,i}), y_{e,i})\Big],$$

where:

- $e$ denotes the group index;
- $S_e$ is a sample of size $n_e$;
- $P_e$ is the distribution of group $e$;
- $\sigma = (\sigma_{e,i})_{e \in [m], i \in [n_e]}$ is a collection of Rademacher variables.

The minimax weighted Rademacher complexity for a subset $\Lambda \subseteq \Delta_m$ is defined as $\mathfrak{R}_{\mathbf{n}}(\mathcal{H}, \Lambda) = \max_{q \in \Lambda} \mathfrak{R}_{\mathbf{n}}(\mathcal{H}, q)$, where $n = \sum_{e=1}^{m} n_e$. Let $P_\Lambda$ be the distribution over the mixture of training groups and $\hat P_\Lambda$ its empirical counterpart. Let $P$ be the distribution of some test group. The learning guarantee for $P$ is given in Theorem 3.

**Theorem 3.** Assume the loss function $\ell$ is bounded by $B$. For any $\epsilon \ge 0$ and $\delta > 0$, with probability at least $1 - \delta$, the following inequality holds for all $f \in \mathcal{F}$:

$$R_P(f) \le R_{\hat P_\Lambda}(f, q) + 2\mathfrak{R}_{\mathbf{n}}(\mathcal{H}, \Lambda) + B\, D(P \,\|\, P_\Lambda) + B\sqrt{\frac{1}{2m}\log\frac{|\Lambda|}{\delta}}.$$

The upper bound on the generalization risk on $P$ is mainly determined by the distance from $P$ to $P_\Lambda$, $D(P \,\|\, P_\Lambda)$. With the topological prior, the risk on $P$ can be tightly bounded by minimizing $D(P \,\|\, P_\Lambda)$, as long as $P$ falls into the convex hull of the training groups. We empirically verify the effectiveness of the topological prior in minimizing the generalization risk over unseen distributions (see Sec. 5).

## 5 Experiments

We evaluate TRO on a wide range of tasks, including classification, regression, and semantic segmentation. We compare TRO with SOTA baselines for OOD generalization and conduct ablation studies on the key components of TRO. Following Gulrajani & Lopez-Paz (2021), we perform model selection based on a validation set constructed from training groups only.
We provide implementation details in Appendix 7.2 and results on DomainBed (Gulrajani & Lopez-Paz, 2021) in Appendix 7.3.

**Baselines.** We compare TRO with the following methods:

1. Empirical Risk Minimization (ERM) (Vapnik, 1998);
2. Group distributionally robust optimization (DRO) (Sagawa et al., 2019);
3. Invariant Risk Minimization (IRM) (Arjovsky et al., 2019);
4. Risk Extrapolation (REx) (Krueger et al., 2021);
5. Spectral Decoupling (SD) (Pezeshki et al., 2021).

Table 1: Accuracy (%) on DG-15 and DG-60. TRO sets the new SOTA on both DG-15 and DG-60.

| Dataset | ERM | IRM | REx | SD | DRO | TRO (physical) | TRO (data) |
|---|---|---|---|---|---|---|---|
| DG-15 | 58.00 | 57.87 | 57.22 | 57.56 | 43.22 | 67.56 | 67.89 |
| DG-60 | 76.02 | 76.61 | 86.89 | 81.04 | 79.59 | 89.19 | 90.72 |

Figure 2: Illustration of data groups in (a) DG-15 and (b) DG-60 datasets.

Figure 3: Group importance of DRO and TRO on DG-15 (panels: group importance of DRO, ACC 43.22%; group importance of TRO, ACC 67.56%). DRO assigns the highest weight to the worst-case group 1, which is the furthest group from the test groups. TRO instead focuses on the influential groups 2, 5, and 6, which are truly connected to the test groups.

### 5.1 Classification

**Datasets.** DG-15 (Xu et al., 2022) is a synthetic binary classification dataset with 15 groups, each containing 100 data points. In this dataset, adjacent groups have similar decision boundaries. Following Xu et al. (2022), we use six connected groups as training groups and the others as test groups. Xu et al. (2022) focus on domain adaptation; unlike in their setting, the data of test groups are unseen in OOD generalization. DG-60 (Xu et al., 2022) is another synthetic dataset generated with the same procedure as DG-15, except that it contains 60 groups, with 6,000 data points in total. We randomly select six groups as training groups and use the others as test groups. Visualizations of DG-15 and DG-60 are shown in Fig. 2 (a) and (b), respectively.

**Results.**
The results on DG-15 and DG-60 are summarized in Tab. 1. On both datasets, our method yields the highest accuracy. For DG-15, we show the detailed results of all groups in Fig. 8. We visualize the decision boundaries on DG-15 and DG-60 in Appendix 7.3.

**Ablation study.** TRO significantly improves generalization performance by discovering influential groups. To investigate why TRO outperforms DRO, we show the group weights $q$ of DRO and TRO on DG-15 in Fig. 3. DRO assigns the highest weight to group 1, which is the furthest group from the test groups. TRO instead prioritizes the influential groups 2, 5, and 6, which are truly connected to the test groups, and thereby improves performance on unseen distributions.

### 5.2 Regression

**Datasets.** TPT-48 (Vose et al., 2014) contains the monthly average temperature for the 48 contiguous states in the US from 2008 to 2019. We focus on the regression task of predicting the temperature of the next 6 months from the temperature of the previous 6 months. We consider two generalization tasks:

1. E(24) → W(24): the 24 eastern states are training groups and the 24 western states are test groups.
2. N(24) → S(24): the 24 northern states are training groups and the 24 southern states are test groups.

Test groups are grouped by hop distance from the closest training group:

- Hop-1 test groups are one hop away.
- Hop-2 test groups are two hops away.
- Hop-3 test groups are all remaining test groups.

The visualization of N(24) → S(24) on TPT-48 is shown in Fig. 4 (left).

Table 2: Mean Squared Error (MSE) for both tasks E(24) → W(24) and N(24) → S(24) on TPT-48. TRO with the data-driven topology consistently outperforms TRO with the physical-based topology in both tasks. This indicates that the data-driven topology captures the distributional relations more accurately.

| Task | Test groups | ERM | IRM | REx | SD | DRO | TRO (physical) | TRO (data) |
|---|---|---|---|---|---|---|---|---|
| E(24) → W(24) | Hop-1 avg. | 1.693 | 1.699 | 1.577 | 1.701 | 1.678 | 1.445 | 1.435 |
| | Hop-2 avg. | 1.800 | 1.811 | 1.702 | 1.806 | 1.762 | 1.576 | 1.569 |
| | Hop-3 avg. | 1.672 | 1.679 | 1.584 | 1.674 | 1.628 | 1.400 | 1.392 |
| | All test groups avg. | 1.716 | 1.724 | 1.616 | 1.722 | 1.684 | 1.466 | 1.458 |
| N(24) → S(24) | Hop-1 avg. | 1.084 | 1.133 | 0.487 | 1.169 | 0.931 | 0.889 | 0.855 |
| | Hop-2 avg. | 1.265 | 1.312 | 0.944 | 1.354 | 1.170 | 0.991 | 0.950 |
| | Hop-3 avg. | 1.975 | 2.021 | 2.266 | 2.091 | 2.027 | 1.678 | 1.604 |
| | All test groups avg. | 1.426 | 1.474 | 1.194 | 1.523 | 1.356 | 1.177 | 1.129 |

Figure 4:
- Left: generalization task of North → South on TPT-48.
- Middle: group centrality of the physical-based topology (MSE: 1.177).
- Right: group centrality of the data-driven topology (MSE: 1.129).

PA is identified by TRO as the influential group in the physical-based topology; NY, PA, and MA are identified as influential groups in the data-driven topology. The data-driven topology yields lower MSE than the physical-based topology.

Figure 5: Group importance of DRO and TRO on the North → South task. Panels:
- group importance of DRO (MSE: 1.356), highlighting the worst-case group;
- group importance after topology learning, highlighting the influential group;
- group importance after learning on topology (MSE: 1.129).

TRO significantly reduces the generalization risk by prioritizing not only the worst-case groups but also the influential ones.

**Results.** We show the results on TPT-48 in Tab. 2. TRO yields the lowest average MSE on both tasks. We also report the average MSE of Hop-1, Hop-2, and Hop-3 test groups for both tasks. In N(24) → S(24), REx yields the lowest error on Hop-1 and Hop-2 groups, but it yields the highest prediction error on Hop-3 groups. This indicates that REx may perform poorly under large distributional drifts.
TRO yields the best performance on Hop-3 groups, indicating its strong generalization capability under large distributional drifts.

**Ablation study.**

(1) *Data-driven topology yields better performance than physical-based topology.* We show the group centrality of both the physical-based and data-driven topologies on the North → South task in Fig. 4. PA is identified by TRO as the influential group in the physical-based topology; NY, PA, and MA are identified as influential groups in the data-driven topology. The data-driven topology achieves lower MSE (1.129 vs. 1.177). This indicates that its influential groups are more effective in minimizing the generalization error.

Table 3: MSE on N(24) → S(24) of TPT-48. Ignoring either the worst-case groups (IW-ERM) or the influential groups (DRO) yields compromised performance.

| Method | Hop-1 avg. | Hop-2 avg. | Hop-3 avg. | Avg. |
|---|---|---|---|---|
| ERM | 1.084 | 1.265 | 1.975 | 1.426 |
| IW-ERM | 1.320 | 1.604 | 2.635 | 1.829 |
| DRO | 0.931 | 1.170 | 2.027 | 1.356 |
| TRO | 0.855 | 0.950 | 1.604 | 1.129 |

(2) *The strong OOD resilience of TRO comes from the synergy of the worst-case and influential groups.* To investigate which components contribute to the superior performance of TRO, we build a simple baseline based on ERM. It directly uses the group importance acquired from the topology to weight the training groups, and these weights are fixed during training. We name this baseline importance-weighted ERM (IW-ERM). We show the results of N(24) → S(24) on TPT-48 in Tab. 3. IW-ERM is inferior to both ERM and DRO, possibly because it considers only the influential groups. We further show the group importance of DRO and TRO in Fig. 5. TRO significantly reduces the generalization risk by prioritizing not only the worst-case groups but also the influential ones.

Figure 6:
- Left: locations of the 11 flood events in Sen1Floods11. We use the event BOL for testing and the other events for training.
- Right: data-driven distributional topology on Sen1Floods11.

Two observations follow from the right panel:
1. IND and NGA are identified by TRO as the most influential groups. A possible explanation is that both floods were caused by heavy rainfall, the most prevalent cause of flooding.
2. GHA and KHM are identified by TRO as the least influential groups. A possible explanation is that both were caused by edge cases such as dam collapse.

The data-driven distributional topology is consistent with domain knowledge and facilitates the explainability of TRO.

### 5.3 Semantic Segmentation

**Datasets.** Sen1Floods11 (Bonafilia et al., 2020) is a public dataset for flood mapping at the global scale. The dataset provides global coverage with 4,831 chips of 512 × 512 10 m satellite images across 11 distinct flood events, covering 120,406 km². Each image is associated with its pixel-wise label. The locations of the 11 flood events are shown in Fig. 6 (left). Flood events vary in boundary conditions, terrain, and other latent factors. This variation poses significant OOD challenges to existing models in terms of reliability and explainability. Following Bonafilia et al. (2020), event BOL is held out for testing. The data of the other events are split into training and validation sets with a random 80/20 split.

Table 4: Segmentation results (IoU) on Sen1Floods11. TRO yields better performance than the other baselines on unseen flood events.

| Split | ERM | IRM | REx | SD | DRO | TRO (data) |
|---|---|---|---|---|---|---|
| Val | .489 | .387 | .484 | .449 | .480 | .485 |
| Test | .430 | .338 | .357 | .400 | .433 | .450 |

Figure 7: Ablation study on λ (x-axis: Lagrange multiplier λ from 0.001 to 100; y-axis: IoU on the validation set). IoU remains stable over a wide range of λ.

**Results.** We show the results on Sen1Floods11 in Tab. 4. ERM achieves the highest IoU on the validation set, while TRO achieves the highest IoU on the test set. This indicates that TRO outperforms the other baselines on unseen flood events.

**Ablation study.** (1) *Data-driven distributional topology is consistent with domain knowledge.*
We visualize the distributional topology and the group centrality in Fig. 6 (right). The learned distributional topology is consistent with domain knowledge, enhancing the explainability of TRO. (2) Ablation study on λ. We report IoU under different values of λ on Sen1Floods11 in Fig. 7. IoU remains stable for a wide range of λ, and λ = 0.01 yields the best performance.

6 CONCLUSION

In this paper, we proposed a new principled optimization method that seamlessly integrates topological information to develop strong OOD resilience. Empirical results on a wide range of tasks, including classification, regression, and semantic segmentation, demonstrate the superior performance of our method over the state of the art. Moreover, the data-driven distributional topology is consistent with domain knowledge and facilitates the explainability of our approach.

ACKNOWLEDGEMENTS

This work is partially supported by National Science Foundation (NSF) CMMI-2039857, General University Research (GUR), and University of Delaware Research Foundation (UDRF). The authors would like to thank Kien X. Nguyen for helping with the experiments on DomainBed.

REFERENCES

Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
Sara Beery, Grant Van Horn, and Pietro Perona. Recognition in terra incognita. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 456–473, 2018.
Derrick Bonafilia, Beth Tellman, Tyler Anderson, and Erica Issenberg. Sen1Floods11: A georeferenced dataset to train and test deep learning flood algorithms for Sentinel-1. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 210–211, 2020.
Stephen Boyd and Lieven Vandenberghe. Convex optimization. Cambridge University Press, 2004.
Ronald R Coifman and Stéphane Lafon. Diffusion maps. Applied and Computational Harmonic Analysis, 21(1):5–30, 2006.
Liam Collins, Aryan Mokhtari, and Sanjay Shakkottai. Task-robust model-agnostic meta-learning. Advances in Neural Information Processing Systems, 33:18860–18871, 2020.
Elliot Creager, Jörn-Henrik Jacobsen, and Richard Zemel. Environment inference for invariant learning. In International Conference on Machine Learning, pp. 2189–2200. PMLR, 2021.
Erick Delage and Yinyu Ye. Distributionally robust optimization under moment uncertainty with application to data-driven problems. Operations Research, 58(3):595–612, 2010.
John C Duchi and Hongseok Namkoong. Learning models with uniform performance via distributionally robust optimization. The Annals of Statistics, 49(3):1378–1406, 2021.
Chen Fang, Ye Xu, and Daniel N Rockmore. Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1657–1664, 2013.
Linton C Freeman. A set of measures of centrality based on betweenness. Sociometry, pp. 35–41, 1977.
Charlie Frogner, Sebastian Claici, Edward Chien, and Justin Solomon. Incorporating unlabeled data into distributionally robust learning. Journal of Machine Learning Research, 22(56):1–46, 2021.
Ishaan Gulrajani and David Lopez-Paz. In search of lost domain generalization. In International Conference on Learning Representations, 2021.
Weihua Hu, Gang Niu, Issei Sato, and Masashi Sugiyama. Does distributionally robust supervised learning give robust classifiers? In International Conference on Machine Learning, pp. 2029–2037. PMLR, 2018.
Pang Wei Koh, Shiori Sagawa, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, et al. WILDS: A benchmark of in-the-wild distribution shifts. In International Conference on Machine Learning, pp. 5637–5664. PMLR, 2021.
Masanori Koyama and Shoichiro Yamaguchi. Out-of-distribution generalization with maximal invariant predictor. 2020.
David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, and Aaron Courville. Out-of-distribution generalization via risk extrapolation (REx). In International Conference on Machine Learning, pp. 5815–5826. PMLR, 2021.
William Leeb and Ronald Coifman. Hölder-Lipschitz norms and their duals on spaces with semigroups, with applications to earth mover's distance. Journal of Fourier Analysis and Applications, 22(4):910–953, 2016.
Daniel Levy, Yair Carmon, John C Duchi, and Aaron Sidford. Large-scale methods for distributionally robust optimization. Advances in Neural Information Processing Systems, 33:8847–8860, 2020.
Da Li et al. Deeper, Broader and Artier Domain Generalization. In ICCV, 2017.
Jiashuo Liu, Zheyuan Hu, Peng Cui, Bo Li, and Zheyan Shen. Heterogeneous risk minimization. In International Conference on Machine Learning, pp. 6804–6814. PMLR, 2021.
Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In International Conference on Machine Learning, pp. 4615–4625. PMLR, 2019.
Hongseok Namkoong and John C Duchi. Stochastic gradient methods for distributionally robust optimization with f-divergences. Advances in Neural Information Processing Systems, 29, 2016.
Arkadi Nemirovski, Anatoli Juditsky, Guanghui Lan, and Alexander Shapiro. Robust stochastic approximation approach to stochastic programming. SIAM Journal on Optimization, 19(4):1574–1609, 2009.
Mark EJ Newman. A measure of betweenness centrality based on random walks. Social Networks, 27(1):39–54, 2005.
Mohammad Pezeshki, Oumar Kaba, Yoshua Bengio, Aaron C Courville, Doina Precup, and Guillaume Lajoie. Gradient starvation: A learning proclivity in neural networks. Advances in Neural Information Processing Systems, 34, 2021.
Qi Qian, Shenghuo Zhu, Jiasheng Tang, Rong Jin, Baigui Sun, and Hao Li. Robust optimization over multiple domains. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pp. 4739–4746, 2019.
Alexander Robey, George J Pappas, and Hamed Hassani. Model-based domain generalization. Advances in Neural Information Processing Systems, 34:20210–20229, 2021.
Elan Rosenfeld, Pradeep Ravikumar, and Andrej Risteski. The risks of invariant risk minimization. In International Conference on Learning Representations, volume 9, 2021.
Shiori Sagawa, Pang Wei Koh, Tatsunori B Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations, 2019.
Soroosh Shafieezadeh Abadeh, Viet Anh Nguyen, Daniel Kuhn, and Peyman M Mohajerin Esfahani. Wasserstein distributionally robust Kalman filtering. Advances in Neural Information Processing Systems, 31, 2018.
Shai Shalev-Shwartz and Yonatan Wexler. Minimizing the maximal loss: How and why. In International Conference on Machine Learning, pp. 793–801. PMLR, 2016.
Yuge Shi, Jeffrey Seely, Philip Torr, N Siddharth, Awni Hannun, Nicolas Usunier, and Gabriel Synnaeve. Gradient matching for domain generalization. In International Conference on Learning Representations, 2022.
Matthew Staib and Stefanie Jegelka. Distributionally robust optimization and generalization in kernel methods. Advances in Neural Information Processing Systems, 32, 2019.
Yu Tian, Long Zhao, Xi Peng, and Dimitris Metaxas. Rethinking kernel methods for node representation learning on graphs. Advances in Neural Information Processing Systems, 32, 2019.
Alexander Y Tong, Guillaume Huguet, Amine Natik, Kincaid MacDonald, Manik Kuchroo, Ronald Coifman, Guy Wolf, and Smita Krishnaswamy. Diffusion earth mover's distance and distribution embeddings. In International Conference on Machine Learning, pp. 10336–10346. PMLR, 2021.
Vladimir Vapnik. Statistical learning theory, 1998.
R Vose, S Applequist, M Squires, I Durre, MJ Menne, CN Williams Jr, C Fenimore, K Gleason, and D Arndt. Gridded 5km GHCN-Daily temperature and precipitation dataset (nClimGrid) version 1. Information, NNCfE, editor. Maximum Temperature, Minimum Temperature, Average Temperature, and Precipitation, 2014.
Zihao Xu, Guang-He Lee, Yuyang Wang, Hao Wang, et al. Graph-relational domain adaptation. In International Conference on Learning Representations, 2022.

7.1 RELATED WORK

Distributionally Robust Optimization. In the context of distributionally robust optimization (DRO), Duchi & Namkoong (2021) and Shalev-Shwartz & Wexler (2016) argued that minimizing the maximal loss over a set of possible distributions can provide better generalization performance than minimizing the average loss. The robustness guarantee of DRO relies heavily on the quality of the uncertainty set, which is typically constructed using moment constraints (Delage & Ye, 2010), f-divergences (Namkoong & Duchi, 2016), or the Wasserstein distance (Shafieezadeh Abadeh et al., 2018). To avoid overly pessimistic models, group DRO (Hu et al., 2018; Sagawa et al., 2019) leverages pre-defined data groups and formulates the uncertainty set as the set of mixtures of these groups. Although this uncertainty set can have a wider radius without being too conservative, our preliminary results show that group DRO recklessly prioritizes the worst-case groups, i.e., those that incur higher losses than the others. Such worst-case groups are not necessarily the influential ones that are truly connected to unseen distributions; optimizing over the worst-case rather than the influential groups can yield mediocre OOD generalization performance.

Out-of-Distribution Generalization. The goal of OOD generalization is to generalize a model from source distributions to unseen target distributions.
There are mainly two branches of methods for OOD generalization: group-invariant learning (Arjovsky et al., 2019; Koyama & Yamaguchi, 2020; Liu et al., 2021) and distributionally robust optimization. Group-invariant learning aims to exploit correlations that are causally invariant across multiple distributions. Invariant Risk Minimization (IRM) (Arjovsky et al., 2019), one of the most representative methods, learns a data representation on top of which the same classifier is simultaneously optimal for all source distributions. However, Rosenfeld et al. (2021) show that IRM can fail catastrophically unless the test data are sufficiently similar to the training distribution.

7.2 IMPLEMENTATION DETAILS

In Sec. 3.1, for all hyperparameters, such as the kernel scale σ² and the maximum scale K, we use the default values from the official implementation¹ of Tong et al. (2021). In Sec. 3.2, for the learning rate of the model parameters η_θ, we use the default values from Xu et al. (2022) (DG-15/-60 and TPT-48) and Bonafilia et al. (2020) (Sen1Floods11). Therefore, we only tune the learning rate of the mixture distribution η_q and the dual variable λ. We select λ from {1e-3, 1e-2, 1e-1, 1, 10, 100} and η_q from {1e-4, 1e-3, 1e-2, 1e-1, 1}. All results are reported over 3 random seeds, which is consistent with Koh et al. (2021) and Shi et al. (2022).

¹ https://github.com/KrishnaswamyLab/DiffusionEMD

7.3 ADDITIONAL RESULTS

DG-15 and DG-60. We provide detailed classification results for each group in Fig. 8. Compared to the other baselines, TRO significantly improves the generalization performance on groups that are far from the training groups, such as group 13. We further visualize the decision boundaries on DG-15 and DG-60 in Fig. 9 and Fig. 10, respectively.

Figure 8: Detailed results on DG-15 (panels: ERM 58.00%, IRM 57.87%, REx 57.22%, DRO 43.22%, TRO 67.56%). Our method outperforms ERM by 9.56%, while the other baselines are inferior to ERM.
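Several analyses in this section (group centrality on TPT-48 and Sen1Floods11, and the PACS chain topology in Fig. 11) identify influential groups by their centrality in the distributional topology. The sketch below is a minimal illustration, not TRO's implementation: it assumes Freeman's shortest-path betweenness (Freeman, 1977) as the centrality measure, and the graphs `pacs_chain` and `fully_connected` are hypothetical, hand-written edge lists rather than learned topologies (the chain assumes Art sits in the middle). On a three-node chain, only the middle node lies on a shortest path between the other two, so it alone receives nonzero centrality; on a fully connected graph every node scores zero, mirroring the VLCS case in which all groups share the same centrality.

```python
from collections import deque
from itertools import combinations
from typing import Dict, List, Tuple

Graph = Dict[str, List[str]]


def shortest_path_counts(adj: Graph, source: str) -> Tuple[Dict[str, int], Dict[str, int]]:
    """BFS from source: hop distances and number of shortest paths to each reachable node."""
    dist = {source: 0}
    sigma = {source: 1}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:  # first time w is reached
                dist[w] = dist[u] + 1
                sigma[w] = 0
                queue.append(w)
            if dist[w] == dist[u] + 1:  # u precedes w on a shortest path
                sigma[w] += sigma[u]
    return dist, sigma


def betweenness(adj: Graph) -> Dict[str, float]:
    """Unnormalized shortest-path betweenness summed over unordered node pairs."""
    nodes = list(adj)
    bfs = {v: shortest_path_counts(adj, v) for v in nodes}
    score = {v: 0.0 for v in nodes}
    for s, t in combinations(nodes, 2):
        dist_s, sigma_s = bfs[s]
        if t not in dist_s:
            continue  # s and t are disconnected
        for v in nodes:
            if v in (s, t) or v not in dist_s:
                continue
            dist_v, sigma_v = bfs[v]
            # v lies on a shortest s-t path iff d(s, v) + d(v, t) == d(s, t)
            if dist_s[v] + dist_v[t] == dist_s[t]:
                score[v] += sigma_s[v] * sigma_v[t] / sigma_s[t]
    return score


# Hypothetical topologies (edges chosen by hand, not learned):
pacs_chain: Graph = {"Photo": ["Art"], "Art": ["Photo", "Sketch"], "Sketch": ["Art"]}
fully_connected: Graph = {"A": ["B", "C"], "B": ["A", "C"], "C": ["A", "B"]}

print(betweenness(pacs_chain))
print(betweenness(fully_connected))
```

Running it prints `{'Photo': 0.0, 'Art': 1.0, 'Sketch': 0.0}` and `{'A': 0.0, 'B': 0.0, 'C': 0.0}`. Other centrality measures (e.g., the random-walk betweenness of Newman, 2005, also in the reference list) would rank the chain's middle node highest as well, but their exact values differ.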
Figure 9: Visualization of the decision boundary on DG-15 (panels: Ground Truth, ERM 58.00%, IRM 57.87%, REx 57.22%, DRO 43.22%, TRO 67.56%; legend: Class 0, Class 1).

Figure 10: Visualization of the decision boundary on DG-60 (panels: Ground Truth, ERM 76.02%, IRM 76.61%, REx 86.89%, DRO 79.59%, TRO 89.19%; legend: Class 0, Class 1). TRO correctly classifies the data of most groups even though the training groups make up only one-tenth of all groups.

Figure 11: Left: Image samples from DomainBed (Terra Incognita: L100, L38, L43, L46; VLCS: Caltech101, LabelMe, SUN09, VOC2007; PACS: Art, Cartoon, Photo, Sketch). Right: Data-driven topology on PACS when Cartoon is the test group and the other three are training groups. We hypothesize that Art is the most influential group because Art may contain more information than Photo and Sketch, as Art combines photos with various kinds of styles.

Table 5: Accuracy (%) on PACS. "Art": Art is the test group while the other three groups are training groups. In average accuracy, TRO outperforms the SOTA by 0.5% and outperforms ERM and DRO by 0.8% and 2.4%, respectively.

Method        Art        Cartoon    Photo      Sketch     Average
ERM           88.1(0.1)  77.9(1.3)  97.8(0)    79.1(0.9)  85.7(0.5)
Group DRO     86.4(0.3)  79.9(0.8)  98.0(0.3)  72.1(0.7)  84.1(0.4)
CORAL (SOTA)  87.7(0.6)  79.2(1.1)  97.6(0)    79.4(0.7)  86.0(0.2)
TRO (ours)    87.7(0.5)  82.1(0.5)  98.0(0.2)  78.2(1.9)  86.5(0.4)

Table 6: Accuracy (%) on Terra. TRO achieves results comparable to the SOTA and outperforms ERM and DRO by 1.8% and 2.0%, respectively.

Method      L100       L38        L43        L46        Average
ERM         50.8(1.8)  42.5(0.7)  57.9(0.6)  37.6(1.2)  47.2(0.4)
Group DRO   47.2(1.6)  40.1(1.6)  57.6(0.9)  43.0(0.7)  47.0(0.3)
MMD (SOTA)  52.2(5.8)  47.0(0.6)  57.8(1.3)  40.3(0.5)  49.3(1.4)
TRO (ours)  53.3(2.4)  42.2(1.3)  59.0(0.8)  41.3(0.5)  49.0(0.6)

DomainBed.
Following the official DomainBed implementation (Gulrajani & Lopez-Paz, 2021), we conduct experiments on PACS (Li et al., 2017), Terra (Beery et al., 2018), and VLCS (Fang et al., 2013). Image samples of the three datasets are shown in Fig. 11 (left). (1) PACS is one of the most popular datasets for out-of-distribution generalization. It consists of images from four groups: Art, Cartoon, Photo, and Sketch. Results on PACS are shown in Tab. 5; results of the other baselines are taken from Appendix B.4 of Gulrajani & Lopez-Paz (2021). In average accuracy, TRO outperforms the SOTA by 0.5%. To further investigate the results, we visualize the learned topology in Fig. 11 (right). When Cartoon is the test group, the topology is a chain graph of three nodes in which Art is the most influential group. A possible explanation is that Art may contain more information than Photo and Sketch, since Art can be viewed as a combination of photos and various kinds of styles. Even though this topology is simple, it enables our method to outperform ERM and DRO by 0.8% and 2.4% on average, respectively. These results empirically demonstrate the explainability of our method even when the number of training groups is quite limited (here, 3). We note that when the distributional shift across groups is small (see the discussion of the VLCS results), an influential group may not exist and all groups share the same centrality. In this special case, TRO aims to strike a good balance between the average (ERM) risk and the worst-case (DRO) risk.

Table 7: Accuracy (%) on VLCS. DRO and TRO achieve the same average accuracy. We assume the reason is that the distributional shift across groups is small (Li et al., 2017); therefore, an influential group may not exist and all groups share the same centrality.
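Method       Caltech101  LabelMe    SUN09      VOC2007    Average
ERM          97.6(1.0)   63.3(0.9)  72.2(0.5)  76.4(1.5)  77.4(0.3)
Group DRO    97.7(0.4)   62.5(1.1)  70.1(0.7)  78.4(0.9)  77.2(0.6)
DANN (SOTA)  98.5(0.2)   64.9(1.1)  73.1(0.7)  78.3(0.3)  78.7(0.3)
TRO (ours)   96.9(0.2)   65.0(0.8)  71.3(0.9)  75.5(0.9)  77.2(0.5)

(2) Terra consists of images of wild animals captured by camera traps at four locations. Results on Terra are shown in Tab. 6; results of the other baselines are taken from Appendix B.6 of Gulrajani & Lopez-Paz (2021). In average accuracy, TRO achieves results comparable to the SOTA and outperforms ERM and DRO by 1.8% and 2.0%, respectively.

(3) Results on VLCS are shown in Tab. 7; results of the other baselines are taken from Appendix B.3 of Gulrajani & Lopez-Paz (2021). DRO and TRO achieve the same average accuracy. We assume the reason is that the distributional shift across groups is small (Li et al., 2017), so an influential group may not exist and all groups share the same centrality. In this special case, TRO aims to strike a good balance between the average (ERM) risk and the worst-case (DRO) risk. The images of VLCS are all photos, and the distributional shift is not as significant as in PACS (e.g., Photo vs. Sketch). As stated in Sec. 2.1 of Li et al. (2017), despite the famous analysis of dataset bias that motivated the creation of the VLCS benchmark, it was later shown that the domain shift is much smaller with recent deep features, and PACS (Li et al., 2017) was proposed to address this limitation.

7.4 PROOF OF THEOREM 1

Proof. Since R is convex in θ and linear in q, the duality gap of the averaged iterates satisfies:

$$\max_{q\in\Delta_m} R(\bar\theta_T, q) - \min_{\theta\in\Theta} R(\theta, \bar q_T) \le \frac{1}{T}\max_{q\in\Delta_m,\,\theta\in\Theta}\sum_{t=1}^{T}\big[R(\theta^t, q) - R(\theta, q^t)\big],$$

where, for each t,

$$R(\theta^t, q) - R(\theta, q^t) = R(\theta^t, q) - R(\theta^t, q^t) + R(\theta^t, q^t) - R(\theta, q^t) \le \langle q - q^t, \nabla_q R(\theta^t, q^t)\rangle + \langle \theta^t - \theta, \nabla_\theta R(\theta^t, q^t)\rangle.$$

By rearranging the terms, and using that the stochastic gradients are unbiased (so the cross terms involving $q^t$ and $\theta^t$ vanish in expectation), we obtain:

$$\mathbb{E}\Big[\max_{q\in\Delta_m,\,\theta\in\Theta}\sum_{t=1}^{T}\big(R(\theta^t, q) - R(\theta, q^t)\big)\Big] \le \mathbb{E}\Big[\max_{\theta\in\Theta}\sum_{t=1}^{T}\langle \theta^t - \theta, \hat g_\theta^t\rangle\Big] + \mathbb{E}\Big[\max_{q\in\Delta_m}\sum_{t=1}^{T}\langle q - q^t, \hat g_q^t\rangle\Big] + \mathbb{E}\Big[\max_{q\in\Delta_m}\sum_{t=1}^{T}\langle q, \nabla_q R(\theta^t, q^t) - \hat g_q^t\rangle\Big] + \mathbb{E}\Big[\max_{\theta\in\Theta}\sum_{t=1}^{T}\langle \theta, \hat g_\theta^t - \nabla_\theta R(\theta^t, q^t)\rangle\Big].$$

Following Collins et al. (2020), we derive the combined bound by bounding the expectation of each term above. For the first term, using the telescoping sum, we obtain:

$$\mathbb{E}\Big[\max_{\theta\in\Theta}\sum_{t=1}^{T}\langle \theta^t - \theta, \hat g_\theta^t\rangle\Big] = \mathbb{E}\Big[\max_{\theta\in\Theta}\sum_{t=1}^{T}\frac{1}{2\eta_\theta}\Big(\|\theta^t - \theta\|_2^2 + \eta_\theta^2\|\hat g_\theta^t\|_2^2 - \|\theta^t - \eta_\theta\hat g_\theta^t - \theta\|_2^2\Big)\Big] \le \frac{2R_\Theta^2}{\eta_\theta} + \frac{\eta_\theta}{2}\sum_{t=1}^{T}\mathbb{E}\|\hat g_\theta^t\|_2^2 \le \frac{2R_\Theta^2}{\eta_\theta} + \frac{\eta_\theta T\hat G_\theta^2}{2}.$$

Similarly, for the second term:

$$\mathbb{E}\Big[\max_{q\in\Delta_m}\sum_{t=1}^{T}\langle q - q^t, \hat g_q^t\rangle\Big] \le \frac{2}{\eta_q} + \frac{\eta_q T\hat G_q^2}{2}.$$

The third and last terms are bounded by $\sqrt{T}\sigma_q$ and $R_\Theta\sqrt{T}\sigma_\theta$, respectively. Combining these bounds, we obtain:

$$\mathbb{E}\Big[\max_{q\in\Delta_m} R(\bar\theta_T, q) - \min_{\theta\in\Theta} R(\theta, \bar q_T)\Big] \le \frac{2R_\Theta^2}{\eta_\theta T} + \frac{\eta_\theta\hat G_\theta^2}{2} + \frac{2}{\eta_q T} + \frac{\eta_q\hat G_q^2}{2} + \frac{R_\Theta\sigma_\theta + \sigma_q}{\sqrt{T}}.$$

The above bound is minimized by setting the step sizes $\eta_\theta = 2R_\Theta/(\hat G_\theta\sqrt{T})$ and $\eta_q = 2/(\hat G_q\sqrt{T})$, which gives $(2R_\Theta\hat G_\theta + 2\hat G_q + R_\Theta\sigma_\theta + \sigma_q)/\sqrt{T}$.

7.5 PROOF OF THEOREM 2

Proof. Inspired by Qian et al. (2019) and Collins et al. (2020), we use the M-smoothness of the losses to start the proof:

$$\mathbb{E}\Big[\sum_i q_i^t\ell_i(\theta^{t+1})\Big] \le \mathbb{E}\Big[\sum_i q_i^t\ell_i(\theta^t)\Big] - \Big(\eta_\theta - \frac{\eta_\theta^2 M}{2}\Big)\mathbb{E}\big[\|g_\theta^t\|^2\big] + \frac{\eta_\theta^2 M\sigma_\theta^2}{2}.$$

We rearrange these terms as follows:

$$\Big(\eta_\theta - \frac{\eta_\theta^2 M}{2}\Big)\mathbb{E}\big[\|g_\theta^t\|^2\big] \le \mathbb{E}\Big[\sum_i q_i^t\ell_i(\theta^t) - \sum_i q_i^t\ell_i(\theta^{t+1})\Big] + \frac{\eta_\theta^2 M\sigma_\theta^2}{2} = \mathbb{E}\Big[\sum_i q_i^t\ell_i(\theta^t) - \sum_i q_i^{t+1}\ell_i(\theta^{t+1})\Big] + \mathbb{E}\Big[\sum_i q_i^{t+1}\ell_i(\theta^{t+1}) - \sum_i q_i^t\ell_i(\theta^{t+1})\Big] + \frac{\eta_\theta^2 M\sigma_\theta^2}{2}.$$

The second term of the above equation can be bounded by:

$$\mathbb{E}\Big[\sum_i (q_i^{t+1} - q_i^t)\,\ell_i(\theta^{t+1})\Big] \le \mathbb{E}\Big[\|q^{t+1} - q^t\|_2\Big(\sum_i \ell_i(\theta^{t+1})^2\Big)^{1/2}\Big] \le \eta_q\sqrt{n}\,\hat B\,\hat G_q.$$

Summing over t and using the law of iterated expectations, we obtain:

$$\Big(\eta_\theta - \frac{\eta_\theta^2 M}{2}\Big)\sum_{t=1}^{T}\mathbb{E}\big[\|g_\theta^t\|^2\big] \le \mathbb{E}\Big[\sum_i q_i^1\ell_i(\theta^1)\Big] - \mathbb{E}\Big[\sum_i q_i^{T+1}\ell_i(\theta^{T+1})\Big] + 2T\eta_q\sqrt{n}\,\hat B\,\hat G_q + \frac{TM\eta_\theta^2\sigma_\theta^2}{2} \le R(\theta^1, q^1) + \hat B + 2T\eta_q\sqrt{n}\,\hat B\,\hat G_q + \frac{T\eta_\theta^2 M\sigma_\theta^2}{2}.$$

Next, we investigate the convergence of q:

$$\mathbb{E}\big[R(\theta^t, q) - R(\theta^t, q^t)\big] = \mathbb{E}\Big[\frac{1}{2\eta_q}\Big(\|q - q^t\|_2^2 + \eta_q^2\|\hat g_q^t\|_2^2 - \|q - (q^t + \eta_q\hat g_q^t)\|_2^2\Big)\Big] \le \frac{1}{2\eta_q}\mathbb{E}\Big[\|q - q^t\|_2^2 + \eta_q^2\|\hat g_q^t\|_2^2 - \|q - q^{t+1}\|_2^2\Big] \le \frac{1}{2\eta_q}\mathbb{E}\Big[\|q - q^t\|_2^2 + \eta_q^2\hat G_q^2 - \|q - q^{t+1}\|_2^2\Big].$$

By aggregating the difference over all time steps, we obtain:

$$\sum_{t=1}^{T}\mathbb{E}\big[R(\theta^t, q) - R(\theta^t, q^t)\big] \le \sum_{t=1}^{T}\Big(\frac{1}{2\eta_q}\mathbb{E}\|q - q^t\|_2^2 - \frac{1}{2\eta_q}\mathbb{E}\|q - q^{t+1}\|_2^2\Big) + \frac{\eta_q T\hat G_q^2}{2} \le \frac{1}{2\eta_q}\mathbb{E}\|q - q^1\|_2^2 + \frac{\eta_q T\hat G_q^2}{2} \le \frac{1}{\eta_q} + \frac{\eta_q T\hat G_q^2}{2}.$$

Since the above inequality holds for all $q \in \Delta_m$, we maximize the right-hand side of the following over $q \in \Delta_m$:

$$\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}\big[R(\theta^t, q^t)\big] \ge \max_{q\in\Delta_m}\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}\big[R(\theta^t, q)\big] - \frac{1}{\eta_q T} - \frac{\eta_q\hat G_q^2}{2}.$$

Eqs. 9 and 10 show that TRO converges in expectation to an (ϵ, δ)-stationary point of R in O(1/ϵ⁴) stochastic gradient evaluations.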
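The proofs above analyze simultaneous updates of the form θ^{t+1} = θ^t − η_θ ĝ_θ^t and q^{t+1} = Π_Δ(q^t + η_q ĝ_q^t) on R(θ, q) = Σ_i q_i ℓ_i(θ), with averaged iterates θ̄_T and q̄_T. The sketch below illustrates these updates under loudly stated assumptions: the update form is inferred from the proofs rather than taken from the TRO implementation; exact gradients stand in for the stochastic estimates ĝ; the topology-based constraint weighted by λ is omitted, so this is plain descent-ascent on the group mixture; θ is an unconstrained scalar; and the three quadratic group losses ℓ_i(θ) = (θ − c_i)² with hypothetical centers `centers` are a toy problem. The simplex projection is the standard sort-based Euclidean projection.

```python
from typing import Callable, List, Sequence, Tuple


def project_to_simplex(v: Sequence[float]) -> List[float]:
    """Euclidean projection of v onto the probability simplex (sort-based)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for j, u_j in enumerate(u, start=1):
        cumsum += u_j
        t = (cumsum - 1.0) / j
        if u_j - t > 0:  # holds for a prefix of the sorted entries; keep the last one
            tau = t
    return [max(x - tau, 0.0) for x in v]


def sgda(
    losses: Sequence[Callable[[float], float]],
    grads: Sequence[Callable[[float], float]],
    theta0: float,
    eta_theta: float,
    eta_q: float,
    T: int,
) -> Tuple[float, List[float]]:
    """Gradient descent on theta, projected gradient ascent on q; returns averaged iterates."""
    n = len(losses)
    theta, q = theta0, [1.0 / n] * n  # start from the uniform mixture
    theta_sum, q_sum = 0.0, [0.0] * n
    for _ in range(T):
        # Accumulate theta^t and q^t (t = 1..T) for the averages theta_bar_T, q_bar_T.
        theta_sum += theta
        q_sum = [s + qi for s, qi in zip(q_sum, q)]
        # grad_theta R(theta, q) = sum_i q_i * l_i'(theta)
        g_theta = sum(qi * g(theta) for qi, g in zip(q, grads))
        # grad_q R(theta, q) = (l_1(theta), ..., l_n(theta)) because R is linear in q
        g_q = [f(theta) for f in losses]
        # Both gradients were evaluated at (theta^t, q^t) before either update.
        theta = theta - eta_theta * g_theta
        q = project_to_simplex([qi + eta_q * gi for qi, gi in zip(q, g_q)])
    return theta_sum / T, [s / T for s in q_sum]


# Hypothetical toy problem: three groups with losses (theta - c_i)^2.
centers = [0.0, 1.0, 4.0]
losses = [lambda th, c=c: (th - c) ** 2 for c in centers]
grads = [lambda th, c=c: 2.0 * (th - c) for c in centers]
theta_bar, q_bar = sgda(losses, grads, theta0=0.0, eta_theta=0.05, eta_q=0.05, T=5000)
print(f"theta_bar = {theta_bar:.3f}, q_bar = {[round(x, 3) for x in q_bar]}")
```

For this toy problem the minimax solution is θ = 2 with q = (0.5, 0, 0.5): groups 1 and 3 then tie at the worst-case loss 4, while group 2 (loss 1) receives zero weight, so the averaged iterates should land close to that point. This is the worst-case prioritization that the related-work discussion attributes to group DRO.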