# SURROGATE GAP MINIMIZATION IMPROVES SHARPNESS-AWARE TRAINING

Published as a conference paper at ICLR 2022

Juntang Zhuang1 (j.zhuang@yale.edu)
Boqing Gong2, Liangzhe Yuan2, Yin Cui2, Hartwig Adam2 ({bgong, lzyuan, yincui, hadam}@google.com)
Nicha C. Dvornek1, Sekhar Tatikonda1, James S. Duncan1 ({nicha.dvornek, sekhar.tatikonda, james.duncan}@yale.edu)
Ting Liu2 (liuti@google.com)
1 Yale University, 2 Google Research
(Work was done during an internship at Google.)

ABSTRACT

The recently proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss, defined as the maximum loss within a neighborhood in the parameter space. However, we show that both sharp and flat minima can have a low perturbed loss, implying that SAM does not always prefer flat minima. Instead, we define a surrogate gap, a measure equivalent to the dominant eigenvalue of the Hessian at a local minimum when the radius of the neighborhood (used to derive the perturbed loss) is small. The surrogate gap is easy to compute and feasible for direct minimization during training. Based on the above observations, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), a novel improvement over SAM with negligible computation overhead. Conceptually, GSAM consists of two steps: 1) a gradient descent step like SAM to minimize the perturbed loss, and 2) an ascent step in the orthogonal direction (after gradient decomposition) to minimize the surrogate gap without affecting the perturbed loss. GSAM seeks a region with both small loss (by step 1) and low sharpness (by step 2), giving rise to a model with high generalization capability. Theoretically, we show the convergence of GSAM and its provably better generalization than SAM. Empirically, GSAM consistently improves generalization (e.g., +3.2% over SAM and +5.4% over AdamW on ImageNet top-1 accuracy for ViT-B/32). Code is released at https://sites.google.com/view/gsam-iclr22/home.

1 INTRODUCTION

Modern neural networks are typically highly over-parameterized and easy to overfit to training data, yet their generalization performance on unseen data (the test set) often suffers a gap from the training performance (Zhang et al., 2017a). Many studies try to understand the generalization of machine learning models, including the Bayesian perspective (McAllester, 1999; Neyshabur et al., 2017), the information perspective (Liang et al., 2019), the loss-surface geometry perspective (Hochreiter & Schmidhuber, 1995; Jiang et al., 2019) and the kernel perspective (Jacot et al., 2018; Wei et al., 2019). Besides analyzing the properties of a model after training, some works study the influence of training and the optimization process, such as the implicit regularization of stochastic gradient descent (SGD) (Bottou, 2010; Zhou et al., 2020), the learning rate's regularization effect (Li et al., 2019), and the influence of the batch size (Keskar et al., 2016). These studies have led to various modifications of the training process to improve generalization. Keskar & Socher (2017) proposed to use Adam in early training phases for fast convergence and then switch to SGD in late phases for better generalization. Izmailov et al. (2018) proposed to average weights to reach a wider local minimum, which is expected to generalize better than sharp minima. A similar idea was later used in Lookahead (Zhang et al., 2019).
Entropy-SGD (Chaudhari et al., 2019) derived the gradient of local entropy to avoid solutions in sharp valleys. Entropy-SGD has a nested Langevin iteration, inducing much higher computation costs than vanilla training.

The recently proposed Sharpness-Aware Minimization (SAM) (Foret et al., 2020) is a generic training scheme that improves generalization and has been shown especially effective for Vision Transformers (Dosovitskiy et al., 2020) when large-scale pre-training is unavailable (Chen et al., 2021). Suppose vanilla training minimizes a loss $f(w)$ (e.g., the cross-entropy loss for classification), where $w$ is the parameter. SAM minimizes a perturbed loss defined as $f_p(w) \triangleq \max_{\|\delta\| \le \rho} f(w+\delta)$, which is the maximum loss within radius $\rho$ centered at the model parameter $w$. Intuitively, vanilla training seeks a single point with a low loss, while SAM searches for a neighborhood within which the maximum loss is low. However, we show that a low perturbed loss $f_p$ could appear in both flat and sharp minima, implying that only minimizing $f_p$ is not always sharpness-aware.

Although the perturbed loss $f_p(w)$ might disagree with sharpness, we find that a surrogate gap defined as $h(w) \triangleq f_p(w) - f(w)$ agrees with sharpness: Lemma 3.3 shows that the surrogate gap $h$ is an equivalent measure of the dominant eigenvalue of the Hessian at a local minimum. Inspired by this observation, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), which jointly minimizes the perturbed loss $f_p$ and the surrogate gap $h$: a low perturbed loss $f_p$ indicates a low training loss within the neighborhood, and a small surrogate gap $h$ avoids solutions in sharp valleys and hence narrows the generalization gap between training and test performances (Thm. 5.3). When both criteria are satisfied, we obtain a generalizable model with good performance. GSAM consists of two steps for each update: 1) a descent step along the gradient $\nabla f_p(w)$ to minimize the perturbed loss $f_p$ (this step is exactly the same as SAM), and 2) a decomposition of the gradient $\nabla f(w)$ of the original loss into components that are parallel and orthogonal to $\nabla f_p(w)$, i.e., $\nabla f(w) = \nabla f_\parallel(w) + \nabla f_\perp(w)$, followed by an ascent step along $\nabla f_\perp(w)$ to minimize the surrogate gap $h(w)$. Note that this ascent step does not change the perturbed loss $f_p$ because $\nabla f_\perp(w) \perp \nabla f_p(w)$ by construction. We summarize our contributions as follows:

- We define the surrogate gap, which measures the sharpness at local minima and is easy to compute.
- We propose the GSAM method to improve the generalization of neural networks. GSAM is widely applicable and incurs negligible computation overhead compared to SAM.
- We demonstrate the convergence of GSAM and its provably better generalization than SAM.
- We empirically validate GSAM on image classification tasks with various neural architectures, including ResNets (He et al., 2016), Vision Transformers (Dosovitskiy et al., 2020), and MLP-Mixers (Tolstikhin et al., 2021).

2 PRELIMINARIES

2.1 NOTATIONS

- $f(w)$: A loss function $f$ with parameter $w \in \mathbb{R}^k$, where $k$ is the parameter dimension.
- $\rho_t \in \mathbb{R}$: A scalar controlling the amplitude of the perturbation at step $t$.
- $\epsilon \in \mathbb{R}$: A small positive constant to avoid division by 0 ($\epsilon = 10^{-12}$ by default).
- $w_t^{adv} \triangleq w_t + \rho_t \frac{\nabla f(w_t)}{\|\nabla f(w_t)\| + \epsilon}$: The solution to $\max_{\|w' - w_t\| \le \rho_t} f(w')$ when $\rho_t$ is small.
- $f_p(w_t) \triangleq \max_{\|\delta\| \le \rho_t} f(w_t + \delta) \approx f(w_t^{adv})$: The perturbed loss induced by $f(w_t)$. For each $w_t$, $f_p(w_t)$ returns the worst possible loss $f$ within a ball of radius $\rho_t$ centered at $w_t$. When $\rho_t$ is small, by Taylor expansion, the solution to the maximization problem is equivalent to a gradient ascent from $w_t$ to $w_t^{adv}$.
- $h(w) \triangleq f_p(w) - f(w)$: The surrogate gap, defined as the difference between $f_p(w)$ and $f(w)$.
- $\eta_t \in \mathbb{R}$: Learning rate at step $t$.
- $\alpha \in \mathbb{R}$: A constant that controls the scaled learning rate of the ascent step in GSAM.
- $g^{(t)}, g_p^{(t)} \in \mathbb{R}^k$: At the $t$-th step, the noisy observations of the gradients $\nabla f(w_t)$ and $\nabla f_p(w_t)$ of the original loss and the perturbed loss, respectively.
- $\nabla f(w_t) = \nabla f_\parallel(w_t) + \nabla f_\perp(w_t)$: Decomposition of $\nabla f(w_t)$ into a parallel component $\nabla f_\parallel(w_t)$ and an orthogonal component $\nabla f_\perp(w_t)$ by projecting $\nabla f(w_t)$ onto $\nabla f_p(w_t)$.

Figure 1: Consider the original loss $f$ (solid line), the perturbed loss $f_p \triangleq \max_{\|\delta\| \le \rho} f(w+\delta)$ (dashed line), and the surrogate gap $h(w) \triangleq f_p(w) - f(w)$. Intuitively, $f_p$ is approximately a max-pooled version of $f$ with a pooling kernel of width $2\rho$, and SAM minimizes $f_p$. From left to right are local minima centered at $w_1, w_2, w_3$, and the valleys become flatter. Since $f_p(w_1) = f_p(w_3) < f_p(w_2)$, SAM prefers $w_1$ and $w_3$ to $w_2$. However, a low $f_p$ could appear in both sharp ($w_1$) and flat ($w_3$) minima, so $f_p$ might disagree with sharpness. On the contrary, a smaller surrogate gap $h$ indicates a flatter loss surface (Lemma 3.3). From $w_1$ to $w_3$, the loss surface is flatter, and $h$ is smaller.

2.2 SHARPNESS-AWARE MINIMIZATION

Conventional training of neural networks typically minimizes the training loss $f(w)$ by gradient descent w.r.t. $\nabla f(w)$ and searches for a single point $w$ with a low loss. However, this vanilla training often falls into a sharp valley of the loss surface, resulting in inferior generalization performance (Chaudhari et al., 2019). Instead of searching for a single-point solution, SAM seeks a region with low losses so that a small perturbation to the model weights does not cause significant performance degradation. SAM formulates the problem as

$$\min_w f_p(w), \quad \text{where } f_p(w) \triangleq \max_{\|\delta\| \le \rho} f(w + \delta) \qquad (1)$$

where $\rho$ is a predefined constant controlling the radius of the neighborhood. The perturbed loss $f_p$ induced by $f(w)$ is the maximum loss within the neighborhood. When the perturbed loss is minimized, the neighborhood corresponds to low losses (below the perturbed loss). For a small $\rho$, using a Taylor expansion around $w$, the inner maximization in Eq. 1 turns into a linearly constrained optimization with solution

$$\arg\max_{\|\delta\| \le \rho} f(w + \delta) \approx \arg\max_{\|\delta\| \le \rho} \big[f(w) + \delta^\top \nabla f(w)\big] + O(\rho^2) = \rho \frac{\nabla f(w)}{\|\nabla f(w)\|} \qquad (2)$$

As a result, the optimization problem of SAM reduces to

$$\min_w f_p(w) \approx \min_w f(w^{adv}), \quad \text{where } w^{adv} \triangleq w + \rho \frac{\nabla f(w)}{\|\nabla f(w)\| + \epsilon} \qquad (3)$$

where $\epsilon$ is a small scalar (default: 1e-12) to avoid division by 0, and $w^{adv}$ is the perturbed weight with the highest loss within the neighborhood. Equivalently, SAM seeks a solution on the surface of the perturbed loss $f_p(w)$ rather than the original loss $f(w)$ (Foret et al., 2020).

3 THE SURROGATE GAP MEASURES THE SHARPNESS AT A LOCAL MINIMUM

3.1 THE PERTURBED LOSS IS NOT ALWAYS SHARPNESS-AWARE

Although SAM searches for a region of low losses, we show that a solution found by SAM is not guaranteed to be flat. Throughout this paper we measure the sharpness at a local minimum of the loss $f(w)$ by the dominant eigenvalue $\sigma_{max}$ (the eigenvalue with the largest absolute value) of the Hessian. For simplicity, we do not consider the influence of reparameterization on the geometry of loss surfaces, which is thoroughly discussed in (Laurent & Massart, 2000; Kwon et al., 2021).
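To make Eqs. 2–3 concrete, the following is a minimal PyTorch-style sketch of the SAM inner step: it computes the perturbation $\rho\,\nabla f(w)/(\|\nabla f(w)\|+\epsilon)$, evaluates the gradient of the perturbed loss at $w^{adv}$, and then restores the original weights. The function name `sam_perturbed_grad`, the `loss_fn` closure, and the use of PyTorch are illustrative assumptions for this writeup, not the paper's released implementation.

```python
import torch

def sam_perturbed_grad(params, loss_fn, rho, eps=1e-12):
    """Compute the SAM quantities of Eqs. 2-3: the perturbation
    delta = rho * grad f(w) / (||grad f(w)|| + eps), and the gradient of the
    perturbed loss f_p evaluated at w_adv = w + delta."""
    loss = loss_fn()                                   # f(w) on the current batch
    grads = torch.autograd.grad(loss, params)          # noisy observation g of grad f(w)
    grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    deltas = [rho * g / (grad_norm + eps) for g in grads]
    with torch.no_grad():                              # move to the worst-case point w_adv
        for p, d in zip(params, deltas):
            p.add_(d)
    loss_adv = loss_fn()                               # f(w_adv)
    grads_p = torch.autograd.grad(loss_adv, params)    # noisy observation g_p of grad f_p(w)
    with torch.no_grad():                              # restore the original weights w
        for p, d in zip(params, deltas):
            p.sub_(d)
    return grads, grads_p
```

Vanilla SAM would simply feed `grads_p` to the base optimizer; GSAM additionally uses `grads` for the gradient decomposition introduced in Section 4.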
Figure 2: $\nabla f$ is decomposed into parallel ($\nabla f_\parallel$) and orthogonal ($\nabla f_\perp$) components by projection onto $\nabla f_p$. The GSAM update direction is $\nabla f^{GSAM} = \nabla f_p - \alpha \nabla f_\perp$.

Algorithm 1: GSAM Algorithm
For $t = 1$ to $T$:
0) $\rho_t$ schedule: $\rho_t = \rho_{min} + (\rho_{max} - \rho_{min}) \frac{lr - lr_{min}}{lr_{max} - lr_{min}}$
1a) $\delta_t = \rho_t \frac{g^{(t)}}{\|g^{(t)}\| + \epsilon}$
1b) $w_t^{adv} = w_t + \delta_t$
2) Obtain $g_p^{(t)}$ by back-propagation at $w_t^{adv}$.
3) $g^{(t)} = g_\parallel^{(t)} + g_\perp^{(t)}$: decompose $g^{(t)}$ into components that are parallel and orthogonal to $g_p^{(t)}$.
4) Update weights:
   Vanilla: $w_{t+1} = w_t - \eta_t g^{(t)}$
   SAM: $w_{t+1} = w_t - \eta_t g_p^{(t)}$
   GSAM: $w_{t+1} = w_t - \eta_t (g_p^{(t)} - \alpha g_\perp^{(t)})$

Lemma 3.1. For some fixed $\rho$, consider two local minima $w_1$ and $w_2$; $f_p(w_1) \le f_p(w_2)$ does not imply $\sigma_{max}(w_1) \le \sigma_{max}(w_2)$, nor does the converse hold, where $\sigma_{max}$ is the dominant eigenvalue of the Hessian.

We leave the proof to the appendix. Fig. 1 illustrates Lemma 3.1 with an example. Consider three local minima denoted $w_1$ to $w_3$, and suppose the corresponding loss surfaces become flatter from $w_1$ to $w_3$. For some fixed $\rho$, we plot the perturbed loss $f_p$ and the surrogate gap $h \triangleq f_p - f$ around each solution. Comparing $w_2$ with $w_3$: suppose their vanilla losses are equal, $f(w_2) = f(w_3)$; then $f_p(w_2) > f_p(w_3)$ because the loss surface is flatter around $w_3$, implying that SAM will prefer $w_3$ to $w_2$. Comparing $w_1$ with $w_2$: $f_p(w_1) < f_p(w_2)$, so SAM will favor $w_1$ over $w_2$ because it only cares about the perturbed loss $f_p$, even though the loss surface is sharper around $w_1$ than around $w_2$.

3.2 THE SURROGATE GAP AGREES WITH SHARPNESS

We introduce the surrogate gap, which agrees with sharpness, defined as

$$h(w) \triangleq \max_{\|\delta\| \le \rho} f(w + \delta) - f(w) \approx f(w^{adv}) - f(w) \qquad (4)$$

Intuitively, the surrogate gap represents the difference between the maximum loss within the neighborhood and the loss at the center point. The surrogate gap has the following properties.

Lemma 3.2. Suppose the perturbation amplitude $\rho$ is sufficiently small. Then the approximation to the surrogate gap in Eq. 4 is always non-negative: $h(w) \approx f(w^{adv}) - f(w) \ge 0, \ \forall w$.

Lemma 3.3. For a local minimum $w^*$, consider the dominant eigenvalue $\sigma_{max}$ of the Hessian of the loss $f$ as a measure of sharpness. For a neighborhood centered at $w^*$ with a small radius $\rho$, the surrogate gap $h(w^*)$ is an equivalent measure of the sharpness: $\sigma_{max} \approx 2h(w^*)/\rho^2$.

The proofs are in the appendix. Lemma 3.2 tells us that the surrogate gap is non-negative, and Lemma 3.3 shows that the loss surface is flatter as $h$ gets closer to 0. The two lemmas together indicate that we can find a region with a flat loss surface by minimizing the surrogate gap $h(w)$.

4 SURROGATE GAP GUIDED SHARPNESS-AWARE MINIMIZATION

4.1 GENERAL IDEA: SIMULTANEOUSLY MINIMIZE THE PERTURBED LOSS AND SURROGATE GAP

Inspired by the analysis in Section 3, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM) to simultaneously minimize two objectives, the perturbed loss $f_p$ and the surrogate gap $h$:

$$\min_w \ \big(f_p(w),\ h(w)\big) \qquad (5)$$

Intuitively, by minimizing $f_p$ we search for a region with a low perturbed loss, similar to SAM, and by minimizing $h$ we search for a local minimum with a flat surface. A low perturbed loss implies low training losses within the neighborhood, and a flat loss surface reduces the generalization gap between training and test performances (Chaudhari et al., 2019). When both are minimized, the solution gives rise to high accuracy and good generalization.
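As a concrete illustration of Eq. 4 and Lemma 3.3, the sketch below evaluates the surrogate gap at the current weights and converts it into the sharpness estimate $\sigma_{max} \approx 2h/\rho^2$. It reuses the same first-order approximation of $w^{adv}$ as SAM; the helper name `surrogate_gap_sharpness`, the `loss_fn` closure, and PyTorch itself are assumptions made here for illustration rather than the paper's released code.

```python
import torch

def surrogate_gap_sharpness(params, loss_fn, rho, eps=1e-12):
    """Estimate h(w) = f(w_adv) - f(w) (Eq. 4) and the sharpness proxy
    sigma_max ~= 2 h / rho^2 (Lemma 3.3) at the current weights."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params)
    grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    deltas = [rho * g / (grad_norm + eps) for g in grads]
    with torch.no_grad():
        for p, d in zip(params, deltas):
            p.add_(d)                     # move to w_adv
        loss_adv = loss_fn()              # f(w_adv)
        for p, d in zip(params, deltas):
            p.sub_(d)                     # restore w
        h = (loss_adv - loss).item()      # surrogate gap h(w)
    return h, 2.0 * h / rho ** 2          # (gap, estimated dominant eigenvalue)
```

This is the quantity monitored in the experiments of Section 6.2; the potential caveat of optimizing it jointly with $f_p$ is discussed next.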
Potential caveat in optimization. It is tempting and yet sub-optimal to combine the objectives in Eq. 5 into $\min_w f_p(w) + \lambda h(w)$, where $\lambda$ is some positive scalar. One caveat when solving this weighted combination is the potential conflict between the gradients of the two terms, i.e., $\nabla f_p(w)$ and $\nabla h(w)$. We illustrate this conflict in Fig. 2, where $\nabla h(w) = \nabla f_p(w) - \nabla f(w)$ (the grey dashed arrow) has a negative inner product with $\nabla f_p(w)$ and $\nabla f(w)$. Hence, gradient descent on the surrogate gap could potentially increase the loss $f_p$, harming the model's performance. We empirically validate this argument in Sec. 6.4.

4.2 GRADIENT DECOMPOSITION AND ASCENT FOR THE MULTI-OBJECTIVE OPTIMIZATION

Our primary goal is to minimize $f_p$, because otherwise a flat solution with a high loss is meaningless, and the minimization of $h$ should not increase $f_p$. We propose to decompose $\nabla f(w_t)$ and $\nabla h(w_t)$ into components that are parallel and orthogonal to $\nabla f_p(w_t)$, respectively (see Fig. 2):

$$\nabla f(w_t) = \nabla f_\parallel(w_t) + \nabla f_\perp(w_t), \qquad \nabla h(w_t) = \nabla h_\parallel(w_t) + \nabla h_\perp(w_t), \qquad \nabla h_\perp(w_t) = -\nabla f_\perp(w_t) \qquad (6)$$

The key point is that updating in the direction of $\nabla h_\perp(w_t)$ does not change the value of the perturbed loss $f_p(w_t)$, because $\nabla h_\perp \perp \nabla f_p$ by construction. Therefore, we propose to perform a descent step in the $\nabla h_\perp(w_t)$ direction, which is equivalent to an ascent step in the $\nabla f_\perp(w_t)$ direction (because $\nabla h_\perp = -\nabla f_\perp$ by the definition of $h$). This achieves two goals simultaneously: it keeps the value of $f_p(w_t)$ intact and meanwhile decreases the surrogate gap $h(w_t) = f_p(w_t) - f(w_t)$ (by increasing $f(w_t)$ without affecting $f_p(w_t)$).

The full GSAM algorithm is shown in Algo. 1 and Fig. 2, where $g^{(t)}, g_p^{(t)}$ are noisy observations of $\nabla f(w_t)$ and $\nabla f_p(w_t)$, respectively, and $g_\parallel^{(t)}, g_\perp^{(t)}$ are noisy observations of $\nabla f_\parallel(w_t)$ and $\nabla f_\perp(w_t)$, respectively, obtained by projecting $g^{(t)}$ onto $g_p^{(t)}$. We introduce a constant $\alpha$ to scale the stepsize of the ascent step. Steps 1) and 2) are the same as SAM: at the current point $w_t$, step 1) takes a gradient ascent to $w_t^{adv}$, followed by step 2), which evaluates the gradient $g_p^{(t)}$ at $w_t^{adv}$. Step 3) projects $g^{(t)}$ onto $g_p^{(t)}$, which requires negligible computation compared to the forward and backward passes. In step 4), $-\eta_t g_p^{(t)}$ is the same as in SAM and minimizes the perturbed loss $f_p(w_t)$ by gradient descent, while $\alpha \eta_t g_\perp^{(t)}$ performs an ascent step in the direction orthogonal to $g_p^{(t)}$ to minimize the surrogate gap $h(w_t)$ (equivalently, to increase $f(w_t)$ while keeping $f_p(w_t)$ intact). In code, GSAM feeds the surrogate gradient $\nabla f_t^{GSAM} \triangleq g_p^{(t)} - \alpha g_\perp^{(t)}$ to first-order gradient optimizers such as SGD and Adam (see the sketch at the end of this section).

The ascent step along $g_\perp^{(t)}$ does not harm convergence. SAM demonstrates that minimizing $f_p$ makes the network generalize better than minimizing $f$. Even though our ascent step along $g_\perp^{(t)}$ increases $f(w)$, it does not affect $f_p(w)$, so GSAM still decreases the perturbed loss $f_p$ in a way similar to SAM. In Thm. 5.1, we formally prove the convergence of GSAM. In Sec. 6 and Appendix C, we empirically validate that the loss decreases and the accuracy increases during training.

Illustration with a toy example. We demonstrate different algorithms with a numerical toy example shown in Fig. 3. The trajectory of GSAM stays closer to the ridge and tends to find a flat minimum. Intuitively, since the loss surface is smoother along the ridge than in sharp local minima, the surrogate gap $h(w)$ is small near the ridge, and the ascent step in GSAM minimizes $h$ and pushes the trajectory closer to the ridge. More concretely, $\nabla f(w_t)$ points to a sharp local solution and deviates from the ridge; in contrast, $w_t^{adv}$ is closer to the ridge and $\nabla f(w_t^{adv})$ is closer to the ridge descent direction than $\nabla f(w_t)$. Note that $\nabla f_t^{GSAM}$ and $\nabla f(w_t)$ always lie on different sides of $\nabla f_p(w_t)$ by construction (see Fig. 2), hence $\nabla f_t^{GSAM}$ pushes the trajectory closer to the ridge than $\nabla f_p(w_t)$ does. The trajectory of GSAM thus resembles a descent along the ridge and tends to find flat minima.
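The sketch below implements steps 3) and 4) of Algorithm 1: it projects the noisy original-loss gradient $g^{(t)}$ onto $g_p^{(t)}$ and forms the surrogate gradient $g_p^{(t)} - \alpha g_\perp^{(t)}$ that is handed to the base optimizer. The function name `gsam_update_direction` and the list-of-tensors interface are illustrative assumptions (PyTorch shown for concreteness); the released implementation may organize this differently.

```python
import torch

def gsam_update_direction(grads, grads_p, alpha, eps=1e-12):
    """Steps 3)-4) of Algorithm 1: project the original-loss gradient g onto the
    perturbed-loss gradient g_p and return the surrogate gradient
    g_p - alpha * g_perp that is fed to the base optimizer."""
    # inner product and squared norm accumulated over all parameter tensors
    dot = sum((g * gp).sum() for g, gp in zip(grads, grads_p))
    gp_sq = sum((gp ** 2).sum() for gp in grads_p)
    coef = dot / (gp_sq + eps)
    update = []
    for g, gp in zip(grads, grads_p):
        g_parallel = coef * gp       # component of g parallel to g_p
        g_perp = g - g_parallel      # orthogonal component, used for the ascent step
        update.append(gp - alpha * g_perp)
    return update
```

In a training loop, this direction can be written into each parameter's `.grad` field before calling `optimizer.step()`, so SGD, AdamW, and similar first-order optimizers can be used unchanged as the base optimizer.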
Figure 3: Consider a loss surface with a few sharp local minima. Left: overview of the procedures of SGD, SAM and GSAM. SGD takes a descent step at $w_t$ using $\nabla f(w_t)$ (orange), which points to a sharp local minimum. SAM first performs gradient ascent in the direction of $\nabla f(w_t)$ to reach $w_t^{adv}$ with a higher loss, followed by descent with the gradient $\nabla f(w_t^{adv})$ (green) at the perturbed weights. Based on $\nabla f(w_t)$ and $\nabla f(w_t^{adv})$, GSAM updates in a new direction (red) that points to a flatter region. Right: trajectories of different methods. SGD and SAM fall into different sharp local minima, while GSAM reaches a flat region. A video is included in the supplement for better visualization.

5 THEORETICAL PROPERTIES OF GSAM

5.1 CONVERGENCE DURING TRAINING

Theorem 5.1. Consider a non-convex function $f(w)$ with Lipschitz-smoothness constant $L$ and lower bound $f_{min}$. Suppose we can access a noisy, bounded observation $g^{(t)}$ ($\|g^{(t)}\|_2 \le G, \ \forall t$) of the true gradient $\nabla f(w_t)$ at the $t$-th step. For some constant $\alpha$, with learning rate $\eta_t = \eta_0/\sqrt{t}$ and perturbation amplitude $\rho_t$ proportional to the learning rate, e.g., $\rho_t = \rho_0/\sqrt{t}$, we have

$$\frac{1}{T}\sum_{t=1}^{T} \mathbb{E}\,\|\nabla f_p(w_t)\|_2^2 \le \frac{C_1 + C_2 \log T}{\sqrt{T}}, \qquad \frac{1}{T}\sum_{t=1}^{T} \mathbb{E}\,\|\nabla f(w_t)\|_2^2 \le \frac{C_3 + C_4 \log T}{\sqrt{T}}$$

where $C_1, C_2, C_3, C_4$ are some constants. Thm. 5.1 implies that both $f_p$ and $f$ converge in GSAM at rate $O(\log T/\sqrt{T})$ for non-convex stochastic optimization, matching the convergence rate of first-order gradient optimizers like Adam.

5.2 GENERALIZATION OF GSAM

In this section, we show that the surrogate gap in GSAM is provably lower than SAM's, so GSAM is expected to find a smoother minimum with better generalization.

Theorem 5.2 (PAC-Bayesian Theorem (McAllester, 2003)). Suppose the training set has $m$ elements drawn i.i.d. from the true distribution, and denote the loss on the training set as $\hat{f}(w) = \frac{1}{m}\sum_{i=1}^{m} f(w, x_i)$, where $x_i$ denotes the (input, target) pair of the $i$-th element. Let $w$ be learned from the training set, and suppose $w$ is drawn from a posterior distribution $Q$. Denote the prior distribution (independent of training) as $P$. Then, with probability at least $1 - a$,

$$\mathbb{E}_{w \sim Q}\mathbb{E}_{x} f(w, x) \le \mathbb{E}_{w \sim Q} \hat{f}(w) + 4\sqrt{\frac{KL(Q\|P) + \log \frac{2m}{a}}{m}}$$

Corollary 5.2.1. Suppose the perturbation $\delta$ is drawn from $\delta \sim \mathcal{N}(0, b^2 I_k)$, $\delta \in \mathbb{R}^k$, where $k$ is the dimension of $w$. Then, with probability at least $(1-a)\big[1 - e^{-(\rho/b - \sqrt{k})^2}\big]$,

$$\mathbb{E}_{w \sim Q}\mathbb{E}_{x} f(w, x) \le \hat{h} + C + 4\sqrt{\frac{KL(Q\|P) + \log \frac{2m}{a}}{m}} \qquad (7)$$

$$\hat{h} \triangleq \max_{\|\delta\|_2 \le \rho} \hat{f}(w + \delta) - \hat{f}(w) = \frac{1}{m}\sum_{i=1}^{m}\Big[\max_{\|\delta\|_2 \le \rho} f(w + \delta, x_i) - f(w, x_i)\Big] \qquad (8)$$

where $C = \hat{f}(w)$ is the empirical training loss and $\hat{h}$ is the surrogate gap evaluated on the training set. Corollary 5.2.1 implies that minimizing $\hat{h}$ (on the right-hand side of Eq. 7) is expected to yield a tighter upper bound on the generalization performance (the left-hand side of Eq. 7). The third term on the right of Eq. 7 is typically hard to analyze and is often simplified to L2 regularization (Foret et al., 2020). Note that $\hat{f}_p = C + \hat{h}$ only holds when $\rho_{train}$ (the perturbation amplitude specified by the user during training) equals $\rho_{true}$ (the ground-truth value determined by the underlying data distribution); when $\rho_{train} \neq \rho_{true}$, minimizing $(f_p, \hat{h})$ is more effective than minimizing $f_p$ alone in terms of the generalization loss. A detailed discussion is in Appendix A.7.
Theorem 5.3 (Unlike SAM, GSAM decreases the surrogate gap). Under the assumptions of Thm. 5.1, Thm. 5.2 and Corollary 5.2.1, further assume the Hessian has a lower bound $|\sigma|_{min}$ on the absolute value of its eigenvalues, and that the variance of the noisy observation $g^{(t)}$ is lower-bounded by $c^2$. The surrogate gap $h$ can be minimized by the ascent step along the orthogonal direction $g_\perp^{(t)}$. During training we minimize the sample estimate of $h$; let $\hat{h}_t$ denote the amount by which the ascent step in GSAM decreases $\hat{h}$ at the $t$-th step. Compared to SAM, the proposed method generates a total decrease in the surrogate gap $\sum_{t=1}^{T} \hat{h}_t$, which is bounded by

$$\frac{\alpha c^2 \rho_0^2 \eta_0 |\sigma|_{min}^2}{G^2} \ \le\ \lim_{T \to \infty} \sum_{t=1}^{T} \hat{h}_t \ \le\ 2.7\,\alpha L^2 \eta_0 \rho_0^2 \qquad (9)$$

We provide the proof in the appendix. The lower bound on $\sum_{t=1}^{T} \hat{h}_t$ indicates that GSAM achieves a provably non-trivial decrease in the surrogate gap. Combined with Corollary 5.2.1, GSAM provably improves the generalization performance over SAM.

6 EXPERIMENTS

6.1 GSAM IMPROVES TEST PERFORMANCE ON VARIOUS MODEL ARCHITECTURES

We conduct experiments with ResNets (He et al., 2016), Vision Transformers (ViTs) (Dosovitskiy et al., 2020) and MLP-Mixers (Tolstikhin et al., 2021). Following the settings of Chen et al. (2021), we train on the ImageNet-1k (Deng et al., 2009) training set using Inception-style (Szegedy et al., 2015) pre-processing without extra training data or strong augmentation. For all models, we search for the best learning rate and weight decay for vanilla training, and then use the same values for the experiments with SAM and GSAM. For ResNets, we search for $\rho$ from 0.01 to 0.05 with a stepsize of 0.01. For ViTs and Mixers, we search for $\rho$ from 0.05 to 0.6 with a stepsize of 0.05. In GSAM, we search for $\alpha$ in {0.01, 0.02, 0.03} for ResNets and $\alpha$ in {0.1, 0.2, 0.3} for ViTs and Mixers. Considering that each step of SAM and GSAM requires twice the computation of vanilla training, we also run vanilla training for twice the epochs of SAM and GSAM, but we observe no significant improvement from the longer training (Table 5 in the appendix). We summarize the best hyper-parameters for each model in Appendix B.

We report the performances on ImageNet (Deng et al., 2009), ImageNet-v2 (Recht et al., 2019) and ImageNet-Real (Beyer et al., 2020) in Table 1. GSAM consistently improves over SAM and vanilla training (with SGD or AdamW): on ViT-B/32, GSAM achieves a +5.4% improvement over AdamW and +3.2% over SAM in top-1 accuracy; on Mixer-B/32, GSAM achieves +11.1% over AdamW and +1.2% over SAM. We omit the standard deviation since it is typically negligible (< 0.1%) compared to the improvements. We also test the generalization performance on out-of-distribution data (ImageNet-R and ImageNet-C), and the observation is consistent with that on ImageNet, e.g., +5.1% on ImageNet-R and +5.9% on ImageNet-C for Mixer-B/32.

Table 1: Top-1 accuracy (%) on ImageNet datasets for ResNets, ViTs and MLP-Mixers trained with vanilla SGD or AdamW, SAM, and GSAM.

| Model | Training | ImageNet-v1 | ImageNet-Real | ImageNet-v2 | ImageNet-R | ImageNet-C |
|---|---|---|---|---|---|---|
| ResNet-50 | Vanilla (SGD) | 76.0 | 82.4 | 63.6 | 22.2 | 44.6 |
| | SAM | 76.9 | 83.3 | 64.4 | 23.8 | 46.5 |
| | GSAM | 77.2 | 83.9 | 64.6 | 23.6 | 47.6 |
| ResNet-101 | Vanilla (SGD) | 77.8 | 83.9 | 65.3 | 24.4 | 48.5 |
| | SAM | 78.6 | 84.8 | 66.7 | 25.9 | 51.3 |
| | GSAM | 78.9 | 85.2 | 67.3 | 26.3 | 51.8 |
| ResNet-152 | Vanilla (SGD) | 78.5 | 84.2 | 66.3 | 25.3 | 50.0 |
| | SAM | 79.3 | 84.9 | 67.3 | 25.7 | 52.2 |
| | GSAM | 80.0 | 85.9 | 68.6 | 27.3 | 54.1 |
| ViT-S/32 | Vanilla (AdamW) | 68.4 | 75.2 | 54.3 | 19.0 | 43.3 |
| | SAM | 70.5 | 77.5 | 56.9 | 21.4 | 46.2 |
| | GSAM | 73.8 | 80.4 | 60.4 | 22.5 | 48.2 |
| ViT-S/16 | Vanilla (AdamW) | 74.4 | 80.4 | 61.7 | 20.0 | 46.5 |
| | SAM | 78.1 | 84.1 | 65.6 | 24.7 | 53.0 |
| | GSAM | 79.5 | 85.3 | 67.3 | 25.3 | 53.3 |
| ViT-B/32 | Vanilla (AdamW) | 71.4 | 77.5 | 57.5 | 23.4 | 44.0 |
| | SAM | 73.6 | 80.3 | 60.0 | 24.0 | 50.7 |
| | GSAM | 76.8 | 82.7 | 63.0 | 25.1 | 51.7 |
| ViT-B/16 | Vanilla (AdamW) | 74.6 | 79.8 | 61.3 | 20.1 | 46.6 |
| | SAM | 79.9 | 85.2 | 67.5 | 26.4 | 56.5 |
| | GSAM | 81.0 | 86.5 | 69.2 | 27.1 | 55.7 |
| Mixer-S/32 | Vanilla (AdamW) | 63.9 | 70.3 | 49.5 | 16.9 | 35.2 |
| | SAM | 66.7 | 73.8 | 52.4 | 18.6 | 39.3 |
| | GSAM | 68.6 | 75.8 | 55.0 | 22.6 | 44.6 |
| Mixer-S/16 | Vanilla (AdamW) | 68.8 | 75.1 | 54.8 | 15.9 | 35.6 |
| | SAM | 72.9 | 79.8 | 58.9 | 20.1 | 42.0 |
| | GSAM | 75.0 | 81.7 | 61.9 | 23.7 | 48.5 |
| Mixer-S/8 | Vanilla (AdamW) | 70.2 | 76.2 | 56.1 | 15.4 | 34.6 |
| | SAM | 75.9 | 82.5 | 62.3 | 20.5 | 42.4 |
| | GSAM | 76.8 | 83.4 | 64.0 | 24.6 | 47.8 |
| Mixer-B/32 | Vanilla (AdamW) | 62.5 | 68.1 | 47.6 | 14.6 | 33.8 |
| | SAM | 72.4 | 79.0 | 58.0 | 22.8 | 46.2 |
| | GSAM | 73.6 | 80.2 | 59.9 | 27.9 | 52.1 |
| Mixer-B/16 | Vanilla (AdamW) | 66.4 | 72.1 | 50.8 | 14.5 | 33.8 |
| | SAM | 77.4 | 83.5 | 63.9 | 24.7 | 48.8 |
| | GSAM | 77.8 | 84.0 | 64.9 | 28.3 | 54.4 |
6.2 GSAM FINDS A MINIMUM WHOSE HESSIAN HAS SMALL DOMINANT EIGENVALUES

Lemma 3.3 indicates that the surrogate gap $h$ is an equivalent measure of the dominant eigenvalue of the Hessian, so minimizing $h$ equivalently searches for a flat minimum. We empirically validate this in Fig. 4. As shown in the left subfigure, for a fixed $\rho$, increasing $\alpha$ decreases the dominant eigenvalue and improves generalization (test accuracy). In the middle subfigure, we plot the dominant eigenvalues estimated from the surrogate gap, $\sigma_{max} \approx 2h/\rho^2$ (Lemma 3.3). In the right subfigure, we directly calculate the dominant eigenvalues using power iteration (Mises & Pollaczek-Geiringer, 1929). The estimated dominant eigenvalues (middle) match the measured eigenvalues $\sigma_{max}$ (right) in the trend that $\sigma_{max}$ decreases with $\alpha$ and $\rho$. Note that the surrogate gap $h$ is derived over the whole training set, while the measured eigenvalues are computed on a subset to save computation. These results show that the ascent step in GSAM reduces the dominant eigenvalue by minimizing the surrogate gap, validating Thm. 5.3.

Figure 4: Influence of $\rho$ (set as a constant for ease of comparison; other experiments use the decayed $\rho_t$ schedule) and $\alpha$ on the training of ViT-B/32. Left: top-1 accuracy on ImageNet. Middle: estimate of the dominant eigenvalue from the surrogate gap, $\sigma_{max} \approx 2h/\rho^2$. Right: dominant eigenvalue of the Hessian calculated via power iteration. The middle and right figures match in the trend of the curves, validating that the surrogate gap can be viewed as a proxy for the dominant eigenvalue of the Hessian.
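For reference, the dominant eigenvalue in the right subfigure can be computed with power iteration on Hessian-vector products, which never forms the Hessian explicitly. The sketch below is a PyTorch-style illustration (the function name, the `loss_fn` closure, and the iteration count are assumptions rather than the paper's code); it assumes `loss_fn` evaluates the loss on a fixed batch so that repeated calls are deterministic.

```python
import torch

def dominant_hessian_eigenvalue(params, loss_fn, n_iters=20, eps=1e-12):
    """Estimate the dominant (largest-magnitude) eigenvalue of the Hessian of
    the loss w.r.t. params via power iteration on Hessian-vector products."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]           # random start direction
    v_norm = torch.sqrt(sum((x ** 2).sum() for x in v))
    v = [x / (v_norm + eps) for x in v]
    eigenvalue = torch.tensor(0.0)
    for _ in range(n_iters):
        # Hessian-vector product: Hv = d/dw <grad f(w), v>
        hv = torch.autograd.grad(
            sum((g * x).sum() for g, x in zip(grads, v)), params, retain_graph=True)
        eigenvalue = sum((x * h).sum() for x, h in zip(v, hv))   # Rayleigh quotient v^T H v
        hv_norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
        v = [h / (hv_norm + eps) for h in hv]
    return eigenvalue.item()
```

Power iteration converges to the eigenvalue of largest magnitude, which matches the sharpness measure $\sigma_{max}$ used throughout the paper.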
6.3 COMPARISON WITH METHODS IN THE LITERATURE

Section 6.1 compares GSAM to SAM and vanilla training. In this subsection, we further compare GSAM against Entropy-SGD (Chaudhari et al., 2019) and Adaptive-SAM (ASAM) (Kwon et al., 2021), which are designed to improve generalization. Note that Entropy-SGD uses SGD in the inner Langevin iteration and can be combined with other base optimizers such as AdamW in the outer loop. For Entropy-SGD, we search for the hyper-parameter "scope" between 0.0 and 0.9, and search for the number of inner-loop iterations between 1 and 14. For ASAM, we search for $\rho$ between 1 and 7 (10x larger than in SAM), as recommended by the ASAM authors. Note that the only difference between ASAM and SAM is the derivation of the perturbation, so both can be combined with the proposed ascent step. As shown in Fig. 5, the proposed ascent step increases test accuracy when combined with both SAM and ASAM, and it outperforms Entropy-SGD and vanilla training.

Figure 5: Top-1 accuracy of Mixer-S/32 trained with different methods on ImageNet, ImageNet-Real and ImageNet-v2. "+ascent" denotes applying the ascent step in Algo. 1 to an optimizer. Note that our GSAM is listed as SAM+ascent (=GSAM) for consistency.

6.4 ADDITIONAL STUDIES

GSAM outperforms a weighted combination of the perturbed loss and surrogate gap. With the example in Fig. 2, we argued that directly minimizing $f_p(w) + \lambda h(w)$, as discussed in Sec. 4.1, is sub-optimal because $\nabla h(w)$ could conflict with $\nabla f_p(w)$ and $\nabla f(w)$. We empirically validate this argument on ViT-B/32. We search for $\lambda$ between 0.0 and 0.5 with a step of 0.1 and search for $\rho$ on the same grid as for SAM and GSAM, and report the best accuracy of each method. The top-1 accuracies in Table 2 show the superior performance of GSAM, validating our analysis.

Table 2: Results (%) of GSAM and $\min(f_p + \lambda h)$ on ViT-B/32.

| Dataset | $\min(f_p + \lambda h)$ | GSAM |
|---|---|---|
| ImageNet | 75.4 | 76.8 |
| ImageNet-Real | 81.1 | 82.7 |
| ImageNet-v2 | 60.9 | 63.0 |
| ImageNet-R | 23.9 | 25.1 |

$\min(f_p, h)$ vs. $\min(f, h)$. GSAM solves $\min(f_p, h)$ by descending $f_p$, decomposing $\nabla f$ onto $\nabla f_p$, and taking an ascent step in the orthogonal direction to increase $f$ while keeping $f_p$ intact. Alternatively, we could optimize $\min(f, h)$ by descending $f$, decomposing $\nabla f_p$ onto $\nabla f$, and taking a descent step in the orthogonal direction to decrease $f_p$ while keeping $f$ intact. The two GSAM variations perform similarly (see Fig. 6, right). We choose $\min(f_p, h)$ mainly to make the minimal change to SAM.

GSAM benefits transfer learning. Using weights trained on ImageNet-1k, we finetune models with SGD on downstream tasks including CIFAR-10/CIFAR-100 (Krizhevsky et al., 2009), Oxford-Flowers (Nilsback & Zisserman, 2008) and Oxford-IIIT Pets (Parkhi et al., 2012). The results in Table 3 show that GSAM leads to better transfer performance than vanilla training and SAM.

Table 3: Transfer learning results (top-1 accuracy, %).

| Dataset | ViT-B/16 Vanilla | ViT-B/16 SAM | ViT-B/16 GSAM | ViT-S/16 Vanilla | ViT-S/16 SAM | ViT-S/16 GSAM |
|---|---|---|---|---|---|---|
| CIFAR-10 | 98.1 | 98.6 | 98.8 | 97.6 | 98.2 | 98.4 |
| CIFAR-100 | 87.6 | 89.1 | 89.7 | 85.7 | 87.6 | 88.1 |
| Flowers | 88.5 | 91.8 | 91.2 | 86.4 | 91.5 | 90.3 |
| Pets | 91.9 | 93.1 | 94.4 | 90.4 | 92.9 | 93.5 |
| Mean | 91.5 | 93.2 | 93.5 | 90.0 | 92.6 | 92.6 |

GSAM remains effective under various data augmentations. We plot the top-1 accuracy of a ViT-B/32 model under various Mixup (Zhang et al., 2017b) augmentations in Fig. 6 (left subfigure). Under different augmentations, GSAM consistently outperforms SAM and vanilla training.

Figure 6: Top-1 accuracy of ViT-B/32 for the additional studies (Section 6.4). Left: performance under different data augmentations (light, medium, strong; details in Appendix B.3), where the vanilla method is trained for 2x the epochs. Middle: performance with different base optimizers (Adam, AdaBelief). Right: comparison between $\min(f_p, h)$ and $\min(f, h)$.
GSAM is compatible with different base optimizers GSAM is generic and applicable to various base optimizers. We compare vanilla training, SAM and GSAM using Adam W (Loshchilov & Hutter, 2017) and Ada Belief (Zhuang et al., 2020) with default hyper-parameters. Fig. 6 (middle subfigure) shows that GSAM performs the best, and SAM improves over vanilla training. 7 CONCLUSION We propose the surrogate gap as an equivalent measure of sharpness which is easy to compute and feasible to optimize. We propose the GSAM method, which improves the generalization over SAM at negligible computation cost. We show the convergence and provably better generalization of GSAM compared to SAM, and validate the superior performance of GSAM on various models. ACKNOWLEDGEMENT We would like to thank Xiangning Chen (UCLA) and Hossein Mobahi (Google) for discussions, Yi Tay (Google) for help with datasets, and Yeqing Li, Xianzhi Du, and Shawn Wang (Google) for help with Tensor Flow implementation. ETHICS STATEMENT This paper focuses on the development of optimization methodologies and can be applied to the training of different deep neural networks for a wide range of applications. Therefore, the ethical impact of our work would primarily be determined by the specific models that are trained using our new optimization strategy. REPRODUCIBILITY STATEMENT We provide the detailed proof of theoretical results in Appendix A and provide the data preprocessing and hyper-parameter settings in Appendix B. Together with the references to existing works and public codebases, we believe the paper contains sufficient details to ensure reproducibility. We plan to release the models trained by using GSAM upon publication. Randall Balestriero, Jerome Pesenti, and Yann Le Cun. Learning in high dimension always amounts to extrapolation. ar Xiv preprint ar Xiv:2110.09485, 2021. Published as a conference paper at ICLR 2022 Lucas Beyer, Olivier J. Henaff, Alexander Kolesnikov, Xiaohua Zhai, and Aaron van den Oord. Are we done with imagenet? ar Xiv preprint ar Xiv:2002.05709, 2020. L eon Bottou. Large-scale machine learning with stochastic gradient descent. In Proceedings of COMPSTAT 2010, pp. 177 186. Springer, 2010. Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann Le Cun, Carlo Baldassi, Christian Borgs, Jennifer Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-sgd: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019(12): 124018, 2019. Xiangning Chen, Cho-Jui Hsieh, and Boqing Gong. When vision transformers outperform resnets without pretraining or strong data augmentations, 2021. Ekin D Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V Le. Autoaugment: Learning augmentation policies from data. ar Xiv preprint ar Xiv:1805.09501, 2018. Alex Damian, Tengyu Ma, and Jason Lee. Label noise sgd provably prefers flat global minimizers. ar Xiv preprint ar Xiv:2106.06530, 2021. Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248 255. Ieee, 2009. Terrance De Vries and Graham W Taylor. Improved regularization of convolutional neural networks with cutout. ar Xiv preprint ar Xiv:1708.04552, 2017. Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. 
An image is worth 16x16 words: Transformers for image recognition at scale. ar Xiv preprint ar Xiv:2010.11929, 2020. John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of machine learning research, 12(Jul):2121 2159, 2011. Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. ar Xiv preprint ar Xiv:2010.01412, 2020. Xavier Gastaldi. Shake-shake regularization. ar Xiv preprint ar Xiv:1705.07485, 2017. Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770 778, 2016. Byeongho Heo, Sanghyuk Chun, Seong Joon Oh, Dongyoon Han, Sangdoo Yun, Gyuwan Kim, Youngjung Uh, and Jung-Woo Ha. Adamp: Slowing down the slowdown for momentum optimizers on scale-invariant weights. ar Xiv preprint ar Xiv:2006.08217, 2020. Sepp Hochreiter and J urgen Schmidhuber. Simplifying neural nets by discovering flat minima. In Advances in neural information processing systems, pp. 529 536, 1995. Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. ar Xiv preprint ar Xiv:1803.05407, 2018. Arthur Jacot, Franck Gabriel, and Cl ement Hongler. Neural tangent kernel: Convergence and generalization in neural networks. ar Xiv preprint ar Xiv:1806.07572, 2018. Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. ar Xiv preprint ar Xiv:1912.02178, 2019. Nitish Shirish Keskar and Richard Socher. Improving generalization performance by switching from adam to sgd. ar Xiv preprint ar Xiv:1712.07628, 2017. Published as a conference paper at ICLR 2022 Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. ar Xiv preprint ar Xiv:1609.04836, 2016. Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009. Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. Asam: Adaptive sharpnessaware minimization for scale-invariant learning of deep neural networks. ar Xiv preprint ar Xiv:2102.11600, 2021. Beatrice Laurent and Pascal Massart. Adaptive estimation of a quadratic functional by model selection. Annals of Statistics, pp. 1302 1338, 2000. Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks. ar Xiv preprint ar Xiv:1907.04595, 2019. Tengyuan Liang, Tomaso Poggio, Alexander Rakhlin, and James Stokes. Fisher-rao metric, geometry, and complexity of neural networks. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 888 896. PMLR, 2019. Tao Lin, Lingjing Kong, Sebastian Stich, and Martin Jaggi. Extrapolation for large-batch training in deep learning. In International Conference on Machine Learning, pp. 6094 6104. PMLR, 2020. Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han. On the variance of the adaptive learning rate and beyond. ar Xiv preprint ar Xiv:1908.03265, 2019. Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. ar Xiv preprint ar Xiv:1711.05101, 2017. 
Liangchen Luo, Yuanhao Xiong, Yan Liu, and Xu Sun. Adaptive gradient methods with dynamic bound of learning rate. ar Xiv preprint ar Xiv:1902.09843, 2019. David Mc Allester. Simplified pac-bayesian margin bounds. In Learning theory and Kernel machines, pp. 203 215. Springer, 2003. David A Mc Allester. Pac-bayesian model averaging. In Proceedings of the twelfth annual conference on Computational learning theory, pp. 164 170, 1999. RV Mises and Hilda Pollaczek-Geiringer. Praktische verfahren der gleichungsaufl osung. ZAMMJournal of Applied Mathematics and Mechanics/Zeitschrift f ur Angewandte Mathematik und Mechanik, 9(1):58 77, 1929. Rafael M uller, Simon Kornblith, and Geoffrey Hinton. When does label smoothing help? ar Xiv preprint ar Xiv:1906.02629, 2019. Behnam Neyshabur, Srinadh Bhojanapalli, and Nathan Srebro. A pac-bayesian approach to spectrally-normalized margin bounds for neural networks. ar Xiv preprint ar Xiv:1707.09564, 2017. Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, pp. 722 729. IEEE, 2008. Omkar M Parkhi, Andrea Vedaldi, Andrew Zisserman, and CV Jawahar. Cats and dogs. In 2012 IEEE conference on computer vision and pattern recognition, pp. 3498 3505. IEEE, 2012. Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, and Vaishaal Shankar. Do imagenet classifiers generalize to imagenet? In International Conference on Machine Learning, pp. 5389 5400, 2019. Sashank J Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of adam and beyond. ar Xiv preprint ar Xiv:1904.09237, 2019. Published as a conference paper at ICLR 2022 David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. Learning internal representations by error propagation. Technical report, California Univ San Diego La Jolla Inst for Cognitive Science, 1985. Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research, 15(1):1929 1958, 2014. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1 9, 2015. Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, et al. Mlp-mixer: An all-mlp architecture for vision. ar Xiv preprint ar Xiv:2105.01601, 2021. Colin Wei, Jason Lee, Qiang Liu, and Tengyu Ma. Regularization matters: Generalization and optimization of neural nets vs their induced kernel. 2019. Zeke Xie, Li Yuan, Zhanxing Zhu, and Masashi Sugiyama. Positive-negative momentum: Manipulating stochastic gradient noise to improve generalization. ar Xiv preprint ar Xiv:2103.17182, 2021. Xubo Yue, Maher Nouiehed, and Raed Al Kontar. Salr: Sharpness-aware learning rates for improved generalization. ar Xiv preprint ar Xiv:2011.05348, 2020. Manzil Zaheer, Sashank Reddi, Devendra Sachan, Satyen Kale, and Sanjiv Kumar. Adaptive methods for nonconvex optimization. In Advances in neural information processing systems, pp. 9793 9803, 2018. Matthew D Zeiler. Adadelta: an adaptive learning rate method. ar Xiv preprint ar Xiv:1212.5701, 2012. Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. 
Understanding deep learning requires rethinking generalization. 2017a. Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. ar Xiv preprint ar Xiv:1710.09412, 2017b. Michael Zhang, James Lucas, Jimmy Ba, and Geoffrey E Hinton. Lookahead optimizer: k steps forward, 1 step back. In Advances in Neural Information Processing Systems, pp. 9593 9604, 2019. Yaowei Zheng, Richong Zhang, and Yongyi Mao. Regularizing neural networks via adversarial model perturbation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8156 8165, 2021. Pan Zhou, Jiashi Feng, Chao Ma, Caiming Xiong, Steven Hoi, et al. Towards theoretically understanding why sgd generalizes better than adam in deep learning. ar Xiv preprint ar Xiv:2010.05627, 2020. Juntang Zhuang, Tommy Tang, Yifan Ding, Sekhar Tatikonda, Nicha Dvornek, Xenophon Papademetris, and James S Duncan. Adabelief optimizer: Adapting stepsizes by the belief in observed gradients. ar Xiv preprint ar Xiv:2010.07468, 2020. Published as a conference paper at ICLR 2022 A.1 PROOF OF LEMMA. 3.1 Suppose ρ is small, perform Taylor expansion around the local minima w, we have: f(w + δ) = f(w) + f(w) δ + 1 2δ Hδ + O(||δ||3) (10) where H is the Hessian, and is positive semidefinite at a local minima. At a local minima, f(w) = 0, hence we have f(w + δ) = f(w) + 1 2δ Hδ + O(||δ||3) (11) fp(w) = max||δ|| ρ f(w + δ) = f(w) + 1 2ρ2σmax(H) + O(||δ||3) (12) where σmax is the dominate eigenvalue (eigenvalue with the largest absolute value). Now consider two local minima w1 and w2 with dominate eigenvalue σ1 and σ2 respectively, we have fp(w1) f(w1) + 1 2ρ2σ1 fp(w2) f(w2) + 1 We have fp(w1) > fp(w2) = σ1 > σ2 and σ1 > σ2 = fp(w1) > fp(w2) because the relation between f(w1) and f(w2) is undetermined. A.2 PROOF OF LEMMA. 3.2 Since ρ is small, we can perform Taylor expansion around w, h(w) = f(w + δ) f(w) = δ f(w) + O(ρ2) = ρ|| f(w)||2 + O(ρ2) > 0 (13) where the last line is because δ is approximated as δ = ρ f(w) || f(w)||2+ϵ, hence has the same direction as f(w). A.3 PROOF OF LEMMA. 3.3 Since ρ is small, we can approximate f(w) with a quadratic model around a local minima w: f(w + δ) = f(w) + 1 2δ Hδ + O(ρ3) where H is the Hessian at w, assumed to be positive semidefinite at local minima. Normalize δ such that ||δ||2 = ρ, Hence we have: h(w) = fp(w) f(w) = max||δ||2 ρ f(w + δ) f(w) = 1 2σmaxρ2 + O(ρ3) (14) where σmax is the dominate eigenvalue of the hessian H, and first order term is 0 because the gradient is 0 at local minima. Therefore, we have σmax 2h(w)/ρ2. A.4 PROOF OF THM. 5.1 For simplicity we consider the base optimizer is SGD. For other optimizers such as Adam, we can derive similar results by applying standard proof techniques in the literature to our proof. Published as a conference paper at ICLR 2022 STEP 1: CONVERGENCE W.R.T FUNCTION fp(w) For simplicity of notation, we denote the update at step t as dt = ηtg(t) p + ηtαg(t) (15) By L smoothness of f and the definition of fp(wt) = f(wadv t ), and definition of dt = wt+1 wt and wadv t = wt + δt we have fp(wt+1) = f(wadv t+1) f(wadv t ) + f(wadv t ), wadv t+1 wadv t + L wadv t+1 wadv t 2 (16) = f(wadv t ) + f(wadv t ), wt+1 + δt+1 wt δt wt+1 + δt+1 wt δt 2 (17) f(wadv t ) + f(wadv t ), dt + L dt 2 (18) + f(wadv t ), δt+1 δt + L δt+1 δt 2 (19) STEP 1.0: BOUND EQ. 18 We first bound Eq. 18. 
Take expectation conditioned on observation up to step t (for simplicity of notation, we use E short for Ex to denote expectation over all possible data points) conditioned on observations up to step t, also by definition of dt, we have Efp(wt+1) fp(wt) ηt fp(wt), Eg(t) p + αηt fp(wt), Eg(t) + Lη2 t E g(t) p + αg(t) 2 ηt E fp(wt) 2 2 + 0 + (α + 1)2G2η2 t (21) Since Eg(t) is orthogonal to fp(wt) by construction, ||g(t)|| G by assumption STEP 1.1: BOUND EQ. 19 By definition of δt, we have δt = ρt g(t) ||g(t)|| + ϵ (22) δt+1 = ρt+1 g(t+1) ||g(t+1)|| + ϵ (23) where g(t) is the gradient of f at wt evaluated with a noisy data sample. When learning rate ηt is small, the update in weight dt is small, and expected gradient is f(wt+1) = f(wt + dt) = f(wt) + Hdt + O(||dt||2) (24) where H is the Hessian at wt. Therefore, we have E f(wadv t ), δt+1 δt = f(wadv t ), ρt E g(t) ||g(t)|| + ϵ ρt+1E g(t+1) ||g(t+1)|| + ϵ (25) || f(wadv t )||ρt E g(t) ||g(t)|| + ϵ E g(t+1) ||g(t+1)|| + ϵ || f(wadv t )||ρtφt (27) where the first inequality is due to (1) ρt is monotonically decreasing with t, and (2) triangle inequality that a, b ||a|| ||b||. φt is the angle between the unit vector in the direction of f(wt) Published as a conference paper at ICLR 2022 and f(wt+1). The second inequality comes from that (1) g ||g||+ϵ < 1 strictly, so we can replace δt in Eq. 25 with a unit vector in corresponding directions multiplied by ρt and get the upper bound, (2) the norm of difference in unit vectors can be upper bounded by the arc length on a unit circle. When learning rate ηt and update stepsize dt is small, φt is also small. Using the limit that tan x = x + O(x2), sin x = x + O(x2), x 0 tan φt = || f(wt+1) f(wt)|| || f(wt)|| + O(φ2 t) (28) = ||Hdt + O(||dt||2)|| || f(wt)|| + O(φ2 t) (29) ηt L(1 + α) (30) where the last inequality is due to (1) max eigenvalue of H is upper bounded by L because f is L smooth, (2) ||dt|| = ||ηt(g + αg )|| and Egt = f(wt). Plug into Eq. 27, also note that the perturbation amplitude ρt is small so wt is close to wadv t , then we have E f(wadv t ), δt+1 δt L(1 + α)Gρtηt (31) Similarly, we have E δt+1 δt 2 ρ2 t E g(t) ||g(t)|| + ϵ g(t+1) ||g(t+1)|| + ϵ ρ2 tφ2 t (33) ρ2 tη2 t L2(1 + α)2 (34) STEP 1.2: TOTAL BOUND Reuse results from Eq. 21 (replace Lp with 2L) and plug into Eq. 18, and plug Eq. 31 and Eq. 34 into Eq. 19, we have Efp(wt+1) fp(wt) ηt E fp(wt) 2 2 + 2L(α + 1)2 + L(1 + α)Gρtηt + 2L3(1 + α)2 2 η2 t ρ2 t (35) Perform telescope sum, we have Efp(w T ) fp(w0) t=1 ηt E|| fp(wt)||2 + h L(1 + α)2G2η2 0 + L(1 + α)Gρ0η0 i T X + L3(1 + α)2η2 0ρ2 0 t=1 E|| fp(wt)||2 t=1 ηt E|| fp(wt)||2 fp(w0) Efp(w T ) + D log T + π2E where D = L(1 + α)2G2η2 0 + L(1 + α)Gρ0η0, E = L3(1 + α)2η2 0ρ2 0 (38) Note that ηT = η0 T , we have t=1 E|| fp(wt)||2 fp(w0) fmin + π2E/6 which implies that GSAM enables fp to converge at a rate of O(log T/ T), and all the constants here are well-bounded. Published as a conference paper at ICLR 2022 STEP 2: CONVERGENCE W.R.T. FUNCTION f(w) We prove the risk for f(w) convergences for non-convex stochastic optimization case using SGD. Denote the update at step t as dt = ηtg(t) p + αηtg(t) (40) By smoothness of f, we have f(wt+1) f(wt) + f(wt), dt + L = f(wt) + f(wt), ηtg(t) p + αηtg(t) + L For simplicity, we introduce a scalar βt such that f (wt) = βt fp(wt) (43) where f (wt) is the projection of f(wt) onto fp(wt). When perturbation amplitude ρ is small, we expect βt to be very close to 1. Take expectation conditioned on observations up to step t for both sides of Eq. 
42, we have: Ef(wt+1) f(wt) + f(wt) f (wt) + αηt Eg(t) βt + α ηt D f(wt), f (wt) E + L βt + α ηt D f(wt), f(wt) sin θt E + L 2 (46) θt is the angle between fp(wt) and f(wt) βt + α ηt f(wt) 2 2(| tan θt| + O(θ2 t )) + L 2 (47) sin x = x + O(x2), tan x = x + O(x2) when x 0. Also note when perturbation amplitude ρt is small, we have fp(wt) = f(wt + δt) = f(wt) + ρt || f(wt)||2 + ϵH(wt) f(wt) + O(ρ2 t) (48) where δt = ρt f(wt) || f(wt)||2 by definition, H(wt) is the Hessian. Hence we have | tan θt| || fp(wt) f(wt)|| || f(wt)|| ρt L || f(wt)|| (49) where L is the Lipschitz constant of f, and L smoothness of f indicates the maximum absolute eigenvalue of H is upper bounded by L. Plug Eq. 49 into Eq. 47, we have Ef(wt+1) f(wt) ηt βt + α ηt f(wt) 2 2| tan θt| + L βt + α Lρtηt f(wt) 2 + L βt + α Lρtηt G + L 2 (52) Assume gradient has bounded norm G. (53) f(wt) ηt βmax 2 + 1 βmin + α Lρtηt G + L 2 E(α + 1)2G2η2 t (54) βt is close to 1 assuming ρ is small, hence it s natural to assume 0 < βmin βt βmax Published as a conference paper at ICLR 2022 Re-arranging above formula, we have ηt βmax 2 f(wt) Ef(wt+1) + 1 βmin + α LGηtρt + L 2 (α + 1)2G2η2 t (55) perform telescope sum and taking expectations on each step, we have t=1 ηt f(wt) 2 2 f(w0) Ef(w T ) + 1 βmin + α LG t=1 ηtρt + L 2 (α + 1)2G2 T X (56) Take the schedule to be ηt = η0 t and ρt = ρ0 t, then we have f(w0) fmin + 1 βmin + α LGη0ρ0 2 (α + 1)2G2η2 0 f(w0) fmin + 1 βmin + α LGη0ρ0(1 + log T) 2 (α + 1)2G2η2 0(1 + log T) (60) T + C4 log T where C1, C4 are some constants. This implies the convergence rate w.r.t f(w) is O(log T/ STEP 3: CONVERGENCE W.R.T. SURROGATE GAP h(w) Note that we have proved convergence for fp(w) in step 1, and convergence for f(w) in step 3. Also note that h(wt) 2 2 = fp(wt) f(wt) 2 2 2 fp(wt) 2 2 + 2 f(wt) 2 also converges at rate O(log T/ T) because each item in the RHS converges at rate O(log T A.5 PROOF OF COROLLARY. 5.2.1 Using the results from Thm. 5.2, with probability at least 1 a, we have Ew QExf(w, x) Ew Q bf(w) + 4 KL(Q||P) + log 2m Assume δ N(0, b2Ik) where k is the dimension of model parameters, hence δ2 (element-wise square) follows a a Chi-square distribution. By Lemma.1 in Laurent & Massart (2000), we have P ||δ||2 2 kb2 2b2 kt + 2tb2 exp( t) (65) hence with probability at least 1 1/ n, we have ||δ||2 2 b2 2 log n + k + 2 q Published as a conference paper at ICLR 2022 Therefore, with probability at least 1 1/ n = 1 exp ρ Eδ bf(w + δ) max||δ||2 ρ bf(w + δ) (67) Combine Eq. 65 and Eq. 67, subtract the same constant C on both sides, and under the same assumption as in (Foret et al., 2020) that Ew QExf(w, x) Eδ N(0,b2Ik)Ew QExf(w + δ, x)we finish the proof. A.6 PROOF OF THM. 5.3 STEP 1: A SUFFICIENT CONDITION THAT THE LOSS GAP IS EXPECTED TO DECREASE FOR EACH STEP Take Taylor expansion, then the expected change of loss gap caused by descent step is E fp(wt) f(wt), ηt fp(wt) (68) where Eg = f (wt) fp(wt) 2 2 + fp(wt) 2 f(wt) 2 cos θt where θt is the angle between vector fp(wt) and f(wt). The expected change of loss gap caused by ascent step is E fp(wt) f(wt), αηt f (wt) = αηt f (wt) 2 2 < 0 (70) Above results demonstrate that ascent step decreases the loss gap, while descent step might increase the loss gap. A sufficient (but not necessary) condition for E h(wt), dt 0 requires α to be large or | f(wt) 2 cos θt fp(wt) . In practice, the perturbation amplitude ρ is small and we can assume θt is close to 0 and fp(wt) is close to f(wt) , we can also set the parameter α to be large in order to decrease the loss gap. 
STEP 2: UPPER AND LOWER BOUND OF DECREASE IN LOSS GAP (BY THE ASCENT STEP IN ORTHOGONAL GRADIENT DIRECTION) COMPARED TO SAM. Next we give an estimate of the decrease in bh caused by our ascent step. We refer to Eq. 69 and Eq. 70 to analyze the change in loss gap caused by the descent and ascent (orthogonally) respectively. It can be seen that gradient descent step might not decrease loss gap, in fact they often increase loss gap in practice; while the ascent step is guaranteed to decrease the loss gap. The decrease in loss gap is: bht = bfp(wt) bf(wt), αηt bf (wt) = αηt bf (wt) 2 2 (71) = αηt bf(wt) 2 2| tan θt|2 (72) t=1 αL2ηtρ2 t (73) By Eq. 49 (74) t=1 αL2η0ρ2 0 1 t3/2 (75) 2.7αL2η0ρ2 0 (76) Hence we derive an upper bound for PT t=1 bht. Published as a conference paper at ICLR 2022 Next we derive a lower bound for PT t=1 bht Note that when ρt is small, by Taylor expansion bfp(wt) = bf(wt + δt) = bf(wt) + ρt || bf(wt)|| b H(wt) bf(wt) + O(ρ2 t) (77) where b H(wt) is the Hessian evaluated on training samples. Also when ρt is small, the angle θt between bfp(wt) and bf(wt) is small, by the limit that tan x = x + O(x2), x 0 sin x = x + O(x2), x 0 We have | tan θt| = | sin θt| + O(θ2 t ) = |θt| + O(θ2 t ) Omitting high order term, we have | tan θt| |θt| = || bfp(wt) bf(wt)|| || bf(wt)|| = ||ρt b H(wt) + O(ρ2 t)|| || bf(wt)|| ρt|σ|min where G is the upper-bound on norm of gradient, |σ|min is the minimum absolute eigenvalue of the Hessian. The intuition is that as perturbation amplitude decreases, the angle θt decreases at a similar rate, though the scale constant might be different. Hence we have t=1 αηt bf(wt) 2 2| tan θt|2 + O(θ4 t ) (79) t=1 αηtc2 ρt|σ|min = αc2ρ2 0η0|σ|2 min G2 1 t3/2 (81) αc2ρ2 0η0|σ|2 min G2 (82) where c2 is the lower bound of || bf||2 (e.g. due to noise in data and gradient observation). Results above indicate that the decrease in loss gap caused by the ascent step is non-trivial, hence our proposed method efficiently improves generalization compared with SAM. A.7 DISCUSSION ON COROLLARY 5.2.1 The comment The corollary gives a bound on the risk in terms of the perturbed training loss if one removes C from both sides is correct. But there is a misunderstanding in the statement the perturbed training loss is small then the model has a small risk : it s only true when ρtrain for training equals its real value ρtrue determined by the data distribution; in practice, we never know ρtrue. In the following we show that the minimization of both h and fp is better than simply minimizing fp when ρtrue = ρtrain. 1. First, we re-write the conclusion of Corollary 5.2.1 as Ew Exf(w, x) fp + R = C + bh + R = C + ρ2σ/2 + R + O(ρ3) with probability (1 a)[1 e ( ρ where R is the regularization term, C is the training loss, σ is the dominant eigenvalue of Hessian. As in lemma 3.3, we perform Taylor-expansion and can ignore the high-order term O(ρ3). We focus on fp = C + bh = C + ρ2σ/2 2. When ρtrue = ρtrain, minimizing h achieves a lower risk than only minimizing fp. (1) Note that after training, C (training loss) is fixed, but h could vary with ρ (e.g. when training on dataset A and testing on an unrelated dataset B, the training loss remains unchanged, but the risk would be huge and a large ρ is required for a valid bound). (2) With an example, we show a low fp is insufficient for generalization, and a low σ is necessary: Published as a conference paper at ICLR 2022 A Suppose we use ρtrain for training, and consider two solutions with C1, σ1 (SAM) and C2, σ2 (GSAM). 
Suppose they have the same fp during training for some ρtrain, so fp1 = C1 + σ1/2 ρ2 train = C2 + σ2/2 ρ2 train = fp2 Suppose C1 < C2 so σ1 > σ2. B When ρtrue > ρtrain, we have risk bound 1 = C1 + σ1/2 ρ2 true + R > risk bound 2 = C2 + σ2/2 ρ2 true + R This implies that a small σ helps generalization, but only a low fp1 (caused by a low C1 and high σ1) is insufficient for a good generalization. C Note that ρtrain is fixed during training, so minimizing htrain during training is equivalently minimizing σ by Lemma 3.3 3. Why we are often unlucky to have ρtrue > ρtrain (1) First, the test sets are almost surely outside the convex hull of the training set because interpolation almost surely never occurs in high-dimensional (> 100) cases Balestriero et al. (2021). As a result, the variability of (train + test) sets is almost surely larger than the variability of (train) set. Since ρ increases with data variability (see point 4 below), we have ρtrue > ρtrain set almost surely. (2) Second, we don t know the value of ρtrue and can only guess it. In practice, we often guess a small value because training often diverges with large ρ (as observed in Foret et al. (2020); Chen et al. (2021)). 4. Why ρ increases with data variability. In Corollary 5.2.1, we assume weight perturbation δ N(0, b2Ik). The meaning of b is the following. If we can randomly sample a fixed number of samples from the underlying distribution, then training the model from scratch (with a fixed seed for random initialization) gives rise to a set of weights. Repeating this process, we get many sets of weights, and their standard deviation is b. Since the number of training samples is limited and fixed, the more variability in data, the more variability in weights, and the larger b. Note that Corollary stated that the bound holds with probability proportional to [1 e ( ρ k)2 ]. In order for the result to hold with a fixed probability, ρ must stay proportional to b, hence ρ also increases with the variability of data. Published as a conference paper at ICLR 2022 Table 4: Hyper-parameters to reproduce experimental results Model ρmax ρmin α lrmax lrmin Weight Decay Base Optimizer Epochs Warmup Steps LR schedule Res Net50 0.04 0.02 0.01 1.6 1.6e-2 0.3 SGD 90 5k Linear Res Net101 0.04 0.02 0.01 1.6 1.6e-2 0.3 SGD 90 5k Linear Res Net512 0.04 0.02 0.005 1.6 1.6e-2 0.3 SGD 90 5k Linear Vi T-S/32 0.6 0.0 0.4 3e-3 3e-5 0.3 Adam W 300 10k Linear Vi T-S/16 0.6 0.0 1.0 3e-3 3e-5 0.3 Adam W 300 10k Linear Vi T-B/32 0.6 0.1 0.6 3e-3 3e-5 0.3 Adam W 300 10k Linear Vi T-B/16 0.6 0.2 0.4 3e-3 3e-5 0.3 Adam W 300 10k Linear Mixer-S/32 0.5 0.0 0.2 3e-3 3e-5 0.3 Adam W 300 10k Linear Mixer-S/16 0.5 0.0 0.6 3e-3 3e-5 0.3 Adam W 300 10k Linear Mixer-S/8 0.5 0.1 0.1 3e-3 3e-5 0.3 Adam W 300 10k Linear Mixer-B/32 0.7 0.2 0.05 3e-3 3e-5 0.3 Adam W 300 10k Linear Mixer-B/16 0.5 0.2 0.01 3e-3 3e-5 0.3 Adam W 300 10k Linear B EXPERIMENTAL DETAILS B.1 TRAINING DETAILS For Vi T and Mixer, we search the learning rate in {1e-3, 3e-3, 1e-2, 3e-3}, and search weight decay in {0.003, 0.03, 0.3}. For Res Net, we search the learning rate in {1.6, 0.16, 0.016}, and search the weight decay in {0.001, 0.01,0.1}. For Vi T and Mixer, we use the Adam W optimizer with β1 = 0.9, β2 = 0.999; for Res Net we use SGD with momentum= 0.9. We train Res Nets for 90 epochs, and train Vi Ts and Mixers for 300 epochs following the settings in (Chen et al., 2021) and (Dosovitskiy et al., 2020). 
Table 4: Hyper-parameters to reproduce experimental results.

| Model | ρmax | ρmin | α | lrmax | lrmin | Weight Decay | Base Optimizer | Epochs | Warmup Steps | LR schedule |
|---|---|---|---|---|---|---|---|---|---|---|
| ResNet50 | 0.04 | 0.02 | 0.01 | 1.6 | 1.6e-2 | 0.3 | SGD | 90 | 5k | Linear |
| ResNet101 | 0.04 | 0.02 | 0.01 | 1.6 | 1.6e-2 | 0.3 | SGD | 90 | 5k | Linear |
| ResNet152 | 0.04 | 0.02 | 0.005 | 1.6 | 1.6e-2 | 0.3 | SGD | 90 | 5k | Linear |
| ViT-S/32 | 0.6 | 0.0 | 0.4 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| ViT-S/16 | 0.6 | 0.0 | 1.0 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| ViT-B/32 | 0.6 | 0.1 | 0.6 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| ViT-B/16 | 0.6 | 0.2 | 0.4 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| Mixer-S/32 | 0.5 | 0.0 | 0.2 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| Mixer-S/16 | 0.5 | 0.0 | 0.6 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| Mixer-S/8 | 0.5 | 0.1 | 0.1 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| Mixer-B/32 | 0.7 | 0.2 | 0.05 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |
| Mixer-B/16 | 0.5 | 0.2 | 0.01 | 3e-3 | 3e-5 | 0.3 | AdamW | 300 | 10k | Linear |

B EXPERIMENTAL DETAILS
B.1 TRAINING DETAILS
For ViT and Mixer, we search the learning rate in {1e-3, 3e-3, 1e-2, 3e-2} and the weight decay in {0.003, 0.03, 0.3}. For ResNet, we search the learning rate in {1.6, 0.16, 0.016} and the weight decay in {0.001, 0.01, 0.1}. For ViT and Mixer, we use the AdamW optimizer with β1 = 0.9, β2 = 0.999; for ResNet, we use SGD with momentum 0.9. We train ResNets for 90 epochs and train ViTs and Mixers for 300 epochs, following the settings in Chen et al. (2021) and Dosovitskiy et al. (2020).
Considering that SAM and GSAM use twice the computation of vanilla training per step, we also try 2× longer training for the vanilla baseline and do not find a significant improvement, as shown in Table 5. We first search for the optimal learning rate and weight decay with vanilla training, and keep these two hyper-parameters fixed for SAM and GSAM. For ViT and Mixer, we search ρ in {0.1, 0.2, 0.3, 0.4, 0.5, 0.6} for SAM and GSAM; for ResNet, we search ρ from 0.01 to 0.05 with a step size of 0.01. For ASAM, we amplify ρ by 10× compared to SAM, as recommended by Kwon et al. (2021). For GSAM, we search α in {0.1, 0.2, 0.3} throughout the paper. We report the best configuration for each model in Table 4.
B.2 TRANSFER LEARNING EXPERIMENTS
Using weights trained on ImageNet-1k, we fine-tune models with SGD on downstream tasks, including CIFAR-10/CIFAR-100 (Krizhevsky et al., 2009), Oxford Flowers (Nilsback & Zisserman, 2008), and Oxford-IIIT Pets (Parkhi et al., 2012). For all experiments, we use the SGD optimizer with no weight decay, a linear learning rate schedule, and gradient clipping with a global norm of 1. We search the maximum learning rate in {0.001, 0.003, 0.01, 0.03}. On the CIFAR datasets, we train for 10k steps with 500 warmup steps; on the Oxford datasets, we train for 500 steps with 100 warmup steps.
B.3 EXPERIMENTAL SETUP WITH ABLATION STUDIES ON DATA AUGMENTATION
We follow the settings in Tolstikhin et al. (2021) to perform ablation studies on data augmentation. In the left subfigure of Fig. 6, "Light" refers to Inception-style data augmentation with random flipping and cropping of images; "Medium" refers to mixup with probability 0.2 and RandAugment magnitude 10; "Strong" refers to mixup with probability 0.2 and RandAugment magnitude 15.
C ABLATION STUDIES AND DISCUSSIONS
C.1 INFLUENCE OF ρ AND α
We plot the performance of a ViT-B/32 model as a function of ρ (Fig. 7a) and of α (Fig. 7b). We empirically find that tuning ρ in SAM alone cannot achieve performance comparable to GSAM, as shown in Fig. 7a. Since GSAM introduces one extra hyper-parameter α, we plot the accuracy as a function of α in Fig. 7b and show that GSAM consistently outperforms SAM and vanilla training.

Figure 7: Performance of GSAM varying with ρ and α. (a) Top-1 accuracy of ViT-B/32 under different ρ for SAM and GSAM. (b) Top-1 accuracy of ViT-B/32 under different α, compared with vanilla training and SAM.

Table 5: Top-1 accuracy of ViT-B/32 on ImageNet with Inception-style data augmentation. For vanilla training we report results for 300 and 600 epochs; for GSAM we report results for 300 epochs.

| Method | Epochs | ImageNet | ImageNet-Real | ImageNet-v2 | ImageNet-R |
|---|---|---|---|---|---|
| Vanilla | 300 | 71.4 | 77.5 | 57.5 | 23.4 |
| Vanilla | 600 | 72.0 | 78.2 | 57.9 | 23.6 |
| GSAM | 300 | 76.8 | 82.7 | 63.0 | 25.1 |

C.2 CONSTANT ρ VS. DECAYED ρt SCHEDULE
Note that Thm. 5.1 assumes ρt decays with t in order to prove convergence, while SAM uses a constant ρ during training. To eliminate the influence of the ρt schedule, we conduct an ablation study in Table 6. The ascent step in GSAM can be applied to either a constant ρ or a decayed ρt schedule, and it improves accuracy in both cases. Without the ascent step, constant ρ and decayed ρt achieve similar performance. The results in Table 6 imply that the ascent step in GSAM is the main reason for the improvement in generalization performance.
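For the decayed-ρt variant in Table 6, one natural choice consistent with the ρmax/ρmin and lrmax/lrmin columns of Table 4 is to interpolate ρt between ρmax and ρmin as the learning rate moves from lrmax to lrmin. The sketch below assumes exactly this lr-synced linear rule with an illustrative step budget; it is meant as an illustration rather than an exact specification of the schedule used in our experiments.

```python
# Minimal sketch (an assumption, not necessarily the exact schedule used in the paper):
# decay rho_t in sync with a linear learning-rate schedule, interpolating between the
# rho_max/rho_min and lr_max/lr_min values listed in Table 4.

def linear_lr(step, total_steps, warmup_steps, lr_max, lr_min):
    """Linear warmup followed by linear decay from lr_max to lr_min."""
    if step < warmup_steps:
        return lr_max * step / max(warmup_steps, 1)
    frac = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return lr_max + frac * (lr_min - lr_max)

def rho_schedule(lr, lr_max, lr_min, rho_max, rho_min):
    """Map the current learning rate to a perturbation radius in [rho_min, rho_max]."""
    frac = (lr - lr_min) / (lr_max - lr_min)
    frac = min(max(frac, 0.0), 1.0)  # clamp so warmup does not push rho below rho_min
    return rho_min + (rho_max - rho_min) * frac

# Example with the ViT-B/32 row of Table 4 (rho_max=0.6, rho_min=0.1, lr_max=3e-3, lr_min=3e-5);
# the total/warmup step counts below are illustrative placeholders.
for step in (0, 10_000, 50_000, 90_000):
    lr = linear_lr(step, total_steps=90_000, warmup_steps=10_000, lr_max=3e-3, lr_min=3e-5)
    print(step, lr, rho_schedule(lr, 3e-3, 3e-5, 0.6, 0.1))
```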
Figure 8: The value of cos θt over training steps, where θt is the angle between ∇f(wt) and ∇fp(wt) as in Fig. 2, shown for α ∈ {0, 0.05, 0.1, 0.15, 0.2}.

Figure 9: The surrogate gap h over training steps under different values of α ∈ {0, 0.05, 0.1, 0.15, 0.2}.

Table 6: Top-1 accuracy of ViT-B/32 on ImageNet. Ablation study on a constant ρ versus a decayed ρt.

| Vanilla | Constant ρ (SAM) | Constant ρ + ascent | Decayed ρt | Decayed ρt + ascent |
|---|---|---|---|---|
| 72.0 | 75.8 | 76.2 | 75.8 | 76.8 |

C.3 VISUALIZING THE TRAINING PROCESS
In the proof of Thm. 5.3, our analysis relies on the assumption that θt is small. We empirically validate this assumption by plotting cos θt in Fig. 8, where θt is the angle between ∇f(wt) and ∇fp(wt). Note that the cosine is computed in a parameter space of dimension 8.8 × 10^7, and in such a high-dimensional space two random vectors are almost surely close to perpendicular. In Fig. 8 the cosine value is always above 0.9, indicating that ∇f(wt) and ∇fp(wt) point in very similar directions given the high dimensionality of the parameters. This empirically validates our assumption that θt is small during training. We also plot the surrogate gap during training in Fig. 9. As α increases, the surrogate gap decreases, validating that the ascent step in GSAM efficiently minimizes the surrogate gap. Furthermore, for any fixed α the surrogate gap increases with training steps, indicating that the training process gradually falls into a local minimum in order to minimize the training loss.
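The two curves in Figs. 8 and 9 require only two extra scalars per training step. Below is a minimal sketch assuming the gradients ∇f(wt), ∇fp(wt) and the losses f(wt), fp(wt) are already available as flat arrays and floats at each step; how they are obtained depends on the training framework, and the toy values at the bottom are placeholders.

```python
import numpy as np

def training_diagnostics(grad_f, grad_fp, loss_f, loss_fp, eps=1e-12):
    """Return (cos_theta, surrogate_gap) as plotted in Figs. 8 and 9.

    grad_f, grad_fp: flattened gradients of the original and perturbed losses.
    loss_f, loss_fp: scalar values of the original and perturbed losses.
    """
    cos_theta = float(np.dot(grad_f, grad_fp) /
                      (np.linalg.norm(grad_f) * np.linalg.norm(grad_fp) + eps))
    surrogate_gap = float(loss_fp - loss_f)  # h(w) = f_p(w) - f(w)
    return cos_theta, surrogate_gap

# Toy usage with placeholder values (not from a real model):
rng = np.random.default_rng(0)
g = rng.normal(size=1000)
print(training_diagnostics(g, g + 0.05 * rng.normal(size=1000), loss_f=0.9, loss_fp=1.0))
```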
D RELATED WORKS
Besides SAM and ASAM, other methods have been proposed in the literature to improve generalization: Lin et al. (2020) proposed gradient extrapolation, Xie et al. (2021) proposed to manipulate the noise in the gradient, Damian et al. (2021) proved that label noise improves generalization, Yue et al. (2020) proposed to adjust the learning rate according to sharpness, and Zheng et al. (2021) proposed model perturbation with an idea similar to SAM. Izmailov et al. (2018) proposed averaging weights to improve generalization, and Heo et al. (2020) restricted the norm of updated weights to improve generalization. Many of the aforementioned methods can be combined with GSAM to further improve generalization.
Besides modified training schemes, two other families of techniques improve generalization: data augmentation and model regularization. Data augmentation typically generates new data from training samples; besides standard augmentations such as flipping or rotating images, recent methods include label smoothing (Müller et al., 2019), mixup (Zhang et al., 2018), which trains on convex combinations of both inputs and labels, automatically learned augmentation (Cubuk et al., 2018), and cutout (DeVries & Taylor, 2017), which randomly masks out parts of an image. Model regularization typically applies auxiliary losses in addition to the training loss, such as weight decay (Loshchilov & Hutter, 2017); other methods randomly modify the model architecture during training, such as dropout (Srivastava et al., 2014) and shake-shake regularization (Gastaldi, 2017).
Note that the data augmentation and model regularization methods mentioned here typically train with standard back-propagation (Rumelhart et al., 1985) and first-order gradient optimizers, and both families of techniques can be combined with GSAM. Besides SGD, Adam, and AdaBelief, GSAM can also be combined with other first-order optimizers, such as AdaBound (Luo et al., 2019), RAdam (Liu et al., 2019), Yogi (Zaheer et al., 2018), AdaGrad (Duchi et al., 2011), AMSGrad (Reddi et al., 2019), and AdaDelta (Zeiler, 2012).
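As an illustration of how GSAM composes with a first-order base optimizer, the following is a simplified NumPy sketch of one update step, with plain SGD standing in for the base optimizer. It is a sketch under stated assumptions, not the released implementation: `loss_grad`, the step sizes, and the toy quadratic objective are placeholders, and details such as the ρ schedule are omitted.

```python
import numpy as np

# Simplified sketch of one GSAM step wrapped around a generic first-order base optimizer.
# loss_grad(w) is a placeholder for the user's gradient function; SGD is used for brevity,
# but the modified gradient could equally be fed to Adam, AdaBelief, etc.

def gsam_step(w, loss_grad, lr, rho, alpha, eps=1e-12):
    g = loss_grad(w)                                 # gradient of the original loss f at w
    w_adv = w + rho * g / (np.linalg.norm(g) + eps)  # SAM-style perturbation within radius rho
    g_p = loss_grad(w_adv)                           # approximate gradient of the perturbed loss f_p

    # Decompose g into components parallel/orthogonal to g_p; ascend along the orthogonal part,
    # which leaves f_p unchanged to first order while shrinking the surrogate gap.
    unit_p = g_p / (np.linalg.norm(g_p) + eps)
    g_perp = g - np.dot(g, unit_p) * unit_p
    update_grad = g_p - alpha * g_perp               # gradient handed to the base optimizer

    return w - lr * update_grad                      # base optimizer: plain SGD in this sketch

# Toy usage on a quadratic loss (illustrative only).
A = np.diag([1.0, 10.0])
loss_grad = lambda w: A @ w
w = np.array([1.0, 1.0])
for _ in range(100):
    w = gsam_step(w, loss_grad, lr=0.05, rho=0.05, alpha=0.2)
print(w)  # hovers near the minimum at the origin (exact convergence is not expected with constant rho)
```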