# AN ADAPTIVE POLICY TO EMPLOY SHARPNESS-AWARE MINIMIZATION

Published as a conference paper at ICLR 2023

Weisen Jiang$^{1,2}$, Hansi Yang$^{2}$, Yu Zhang$^{1,3,*}$, James Kwok$^{2}$

$^{1}$ Guangdong Provincial Key Laboratory of Brain-inspired Intelligent Computation, Department of Computer Science and Engineering, Southern University of Science and Technology

$^{2}$ Department of Computer Science and Engineering, Hong Kong University of Science and Technology

$^{3}$ Peng Cheng Laboratory

{wjiangar, hyangbw, jamesk}@cse.ust.hk, yu.zhang.ust@gmail.com

ABSTRACT

Sharpness-aware minimization (SAM), which searches for flat minima by min-max optimization, has been shown to be useful in improving model generalization. However, since each SAM update requires computing two gradients, its computational cost and training time are both doubled compared to standard empirical risk minimization (ERM). Recent state-of-the-art methods reduce the fraction of SAM updates and thus accelerate SAM by switching between SAM and ERM updates randomly or periodically. In this paper, we design an adaptive policy to employ SAM based on the loss landscape geometry. Two efficient algorithms, AE-SAM and AE-LookSAM, are proposed. We theoretically show that AE-SAM has the same convergence rate as SAM. Experimental results on various datasets and architectures demonstrate the efficiency and effectiveness of the adaptive policy.

1 INTRODUCTION

Despite great success in many applications (He et al., 2016; Zagoruyko & Komodakis, 2016; Han et al., 2017), deep networks are often over-parameterized and capable of memorizing all training data. The training loss landscape is complex and nonconvex with many local minima of different generalization abilities.
Many studies have investigated the relationship between the loss surface's geometry and generalization performance (Hochreiter & Schmidhuber, 1994; McAllester, 1999; Keskar et al., 2017; Neyshabur et al., 2017; Jiang et al., 2020), and found that flatter minima generalize better than sharper minima (Dziugaite & Roy, 2017; Petzka et al., 2021; Chaudhari et al., 2017; Keskar et al., 2017; Jiang et al., 2020).

Sharpness-aware minimization (SAM) (Foret et al., 2021) is the current state-of-the-art method for seeking flat minima by solving a min-max optimization problem. In the SAM algorithm, each update consists of two forward-backward computations: one for computing the perturbation and the other for computing the actual update direction. Since these two computations are not parallelizable, SAM doubles the computational overhead as well as the training time compared to empirical risk minimization (ERM).

Several algorithms (Du et al., 2022a; Zhao et al., 2022b; Liu et al., 2022) have been proposed to improve the efficiency of SAM. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and updates fewer parameters, but each update still requires two gradient computations. Thus, ESAM does not alleviate the bottleneck of training speed. Instead of using the SAM update at every iteration, recent state-of-the-art methods (Zhao et al., 2022b; Liu et al., 2022) propose to use SAM randomly or periodically. Specifically, SS-SAM (Zhao et al., 2022b) selects SAM or ERM according to a Bernoulli trial, while LookSAM (Liu et al., 2022) employs SAM every $k$ steps. Though more efficient, the random or periodic use of SAM is suboptimal as it is not geometry-aware. Intuitively, the SAM update is more useful in sharp regions than in flat regions.

In this paper, we propose an adaptive policy to employ SAM based on the geometry of the loss landscape.
The SAM update is used when the model is in sharp regions, while the ERM update is used in flat regions to reduce the fraction of SAM updates. To measure sharpness, we use the squared stochastic gradient norm and model it by a normal distribution, whose parameters are estimated by exponential moving average. Experimental results on standard benchmark datasets demonstrate the superiority of the proposed policy.

$^{*}$ Correspondence to: Yu Zhang.

Our contributions are summarized as follows: (i) We propose an adaptive policy to use the SAM or ERM update based on the loss landscape geometry. (ii) We propose an efficient algorithm, called AE-SAM (Adaptive policy to Employ SAM), to reduce the fraction of SAM updates. We also theoretically study its convergence rate. (iii) The proposed policy is general and can be combined with any SAM variant. In this paper, we integrate it with LookSAM (Liu et al., 2022) and propose AE-LookSAM. (iv) Experimental results on various network architectures and datasets (with and without label noise) verify the superiority of AE-SAM and AE-LookSAM over existing baselines.

Notations. Vectors (e.g., $\mathbf{x}$) and matrices (e.g., $\mathbf{X}$) are denoted by lowercase and uppercase boldface letters, respectively. For a vector $\mathbf{x}$, its $\ell_2$-norm is $\|\mathbf{x}\|$. $\mathcal{N}(\mu; \sigma^2)$ is the univariate normal distribution with mean $\mu$ and variance $\sigma^2$. $\mathrm{diag}(\mathbf{x})$ constructs a diagonal matrix with $\mathbf{x}$ on the diagonal. Moreover, $\mathbb{I}_A(\mathbf{x})$ denotes the indicator function for a given set $A$, i.e., $\mathbb{I}_A(\mathbf{x}) = 1$ if $\mathbf{x} \in A$, and $0$ otherwise.

2 RELATED WORK

We are given a training set $\mathcal{D}$ with i.i.d. samples $\{(\mathbf{x}_i, y_i) : i = 1, \dots, n\}$. Let $f(\mathbf{x}; \mathbf{w})$ be a model parameterized by $\mathbf{w}$. Its empirical risk on $\mathcal{D}$ is $\mathcal{L}(\mathcal{D}; \mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n} \ell(f(\mathbf{x}_i; \mathbf{w}), y_i)$, where $\ell(\cdot, \cdot)$ is a loss (e.g., cross-entropy loss for classification). Model training aims to learn a model from the training data that generalizes well on the test data.

Generalization and Flat Minima.
The connection between model generalization and loss landscape geometry has been theoretically and empirically studied in (Keskar et al., 2017; Dziugaite & Roy, 2017; Jiang et al., 2020). Recently, Jiang et al. (2020) conducted large-scale experiments and found that sharpness-based measures (flatness) are related to the generalization of minimizers. Although flatness can be characterized by the Hessian's eigenvalues (Keskar et al., 2017; Dinh et al., 2017), handling the Hessian explicitly is computationally prohibitive. To address this issue, practical algorithms seek flat minima by injecting noise into the optimizers (Zhu et al., 2019; Zhou et al., 2019; Orvieto et al., 2022; Bisla et al., 2022), introducing regularization (Chaudhari et al., 2017; Zhao et al., 2022a; Du et al., 2022b), averaging model weights during training (Izmailov et al., 2018; He et al., 2019; Cha et al., 2021), or sharpness-aware minimization (SAM) (Foret et al., 2021; Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022).

SAM. The state-of-the-art SAM (Foret et al., 2021) and its variants (Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022; Zhao et al., 2022a) search for flat minima by solving the following min-max optimization problem:

$$\min_{\mathbf{w}} \max_{\|\boldsymbol{\epsilon}\| \le \rho} \mathcal{L}(\mathcal{D}; \mathbf{w} + \boldsymbol{\epsilon}), \tag{1}$$

where $\rho > 0$ is the radius of perturbation. The above can also be rewritten as $\min_{\mathbf{w}} \mathcal{L}(\mathcal{D}; \mathbf{w}) + \mathcal{R}(\mathcal{D}; \mathbf{w})$, where $\mathcal{R}(\mathcal{D}; \mathbf{w}) \equiv \max_{\|\boldsymbol{\epsilon}\| \le \rho} \mathcal{L}(\mathcal{D}; \mathbf{w} + \boldsymbol{\epsilon}) - \mathcal{L}(\mathcal{D}; \mathbf{w})$ is a regularizer that penalizes sharp minimizers (Foret et al., 2021). As solving the inner maximization in (1) exactly is computationally infeasible for nonconvex losses, SAM approximately solves it by first-order Taylor approximation, leading to the update rule:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t + \rho_t \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)), \tag{2}$$

where $\mathcal{B}_t$ is a mini-batch of data, $\eta$ is the step size, and $\rho_t = \frac{\rho}{\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|}$. Although SAM has been shown to be effective in improving the generalization of deep networks, a major drawback is that each update in (2) requires two forward-backward calculations.
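The update rule (2) can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the authors' implementation: the function name `sam_step`, the small constant `eps` guarding against division by zero, and the toy quadratic loss are assumptions introduced here.

```python
import numpy as np

def sam_step(w, grad_fn, eta=0.1, rho=0.05, eps=1e-12):
    """One SAM update following rule (2); grad_fn(w) returns a mini-batch gradient."""
    g = grad_fn(w)                             # first gradient: defines the perturbation
    rho_t = rho / (np.linalg.norm(g) + eps)    # rho_t = rho / ||g|| (eps is an added safeguard)
    g_sam = grad_fn(w + rho_t * g)             # second gradient: taken at the perturbed point
    return w - eta * g_sam

# Hypothetical toy loss L(w) = 0.5 * w^T A w, so the gradient is A w.
A = np.diag([10.0, 1.0])

def quad_loss(w):
    return 0.5 * w @ A @ w

def quad_grad(w):
    return A @ w

w0 = np.array([1.0, 1.0])
w1 = sam_step(w0, quad_grad)
print(quad_loss(w0), quad_loss(w1))
```

Each call to `sam_step` evaluates `grad_fn` twice, and the second call depends on the result of the first, which is the sequential cost that the adaptive policy aims to avoid in flat regions.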
Specifically, SAM first calculates the gradient of $\mathcal{L}(\mathcal{B}_t; \mathbf{w})$ at $\mathbf{w}_t$ to obtain the perturbation, then calculates the gradient of $\mathcal{L}(\mathcal{B}_t; \mathbf{w})$ at $\mathbf{w}_t + \rho_t \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)$ to obtain the update direction for $\mathbf{w}_t$. As a result, SAM doubles the computational overhead compared to ERM.

Efficient Variants of SAM. Several algorithms have been proposed to accelerate the SAM algorithm. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and only updates part of the model in the second step, but still requires computing most of the gradients. Another direction is to reduce the number of SAM updates during training. SS-SAM (Zhao et al., 2022b) randomly selects the SAM or ERM update according to a Bernoulli trial, while LookSAM (Liu et al., 2022) employs SAM every $k$ iterations. Intuitively, the SAM update is more suitable for sharp regions than flat regions. However, the mixing policies in SS-SAM and LookSAM are not adaptive to the loss landscape. In this paper, we design an adaptive policy to employ SAM based on the loss landscape geometry.

Figure 1: Variance of gradient on CIFAR-100 (x-axis: epoch; y-axis: variance of gradients; curves: ERM, SAM, AE-SAM). Panels: (a) ResNet-18; (b) WRN-28-10; (c) PyramidNet-110. Best viewed in color.

In this section, we propose an adaptive policy to employ SAM. The idea is to use ERM when $\mathbf{w}_t$ is in a flat region, and use SAM only when the loss landscape is locally sharp. We start by introducing a sharpness measure (Section 3.1), then propose an adaptive policy based on this (Section 3.2). Next, we propose two algorithms (AE-SAM and AE-LookSAM) and study the convergence.

3.1 SHARPNESS MEASURE

Though sharpness can be characterized by the Hessian's eigenvalues (Keskar et al., 2017; Dinh et al., 2017), they are expensive to compute.
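A widely-used approximation is based on the gradient magnitude $\mathrm{diag}([\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)]^2)$ (Bottou et al., 2018; Khan et al., 2018), where $[\mathbf{v}]^2$ denotes the elementwise square of a vector $\mathbf{v}$. As $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ equals the trace of $\mathrm{diag}([\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)]^2)$, it is reasonable to choose $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ as a sharpness measure.

$\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ is also related to the gradient variance $\mathrm{Var}(\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t))$, another sharpness measure (Jiang et al., 2020). Specifically,

$$\mathrm{Var}(\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)) \equiv \mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t) - \nabla \mathcal{L}(\mathcal{D}; \mathbf{w}_t)\|^2 = \mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 - \|\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}_t)\|^2. \tag{3}$$

With appropriate smoothness assumptions on $\mathcal{L}$, both SAM and ERM can be shown theoretically to converge to critical points of $\mathcal{L}(\mathcal{D}; \mathbf{w})$ (i.e., $\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}) = \mathbf{0}$) (Reddi et al., 2016; Andriushchenko & Flammarion, 2022). Thus, it follows from (3) that $\mathrm{Var}(\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)) = \mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ when $\mathbf{w}_t$ is a critical point of $\mathcal{L}(\mathcal{D}; \mathbf{w})$. Jiang et al. (2020) conducted extensive experiments and empirically showed that $\mathrm{Var}(\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t))$ is positively correlated with the generalization gap. The smaller $\mathrm{Var}(\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t))$ is, the better the model with parameter $\mathbf{w}_t$ generalizes. This finding also explains why SAM generalizes better than ERM. Figure 1 shows the gradient variance w.r.t. the number of epochs using SAM and ERM on CIFAR-100 with various network architectures (experimental details are in Section 4.1). As can be seen, SAM always has a much smaller variance than ERM. Figure 2 shows the expected squared norm of the stochastic gradient w.r.t. the number of epochs on CIFAR-100. As shown, SAM achieves a much smaller $\mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ than ERM.

3.2 ADAPTIVE POLICY TO EMPLOY SAM

As $\mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ changes with $t$ (Figure 2), the sharpness at $\mathbf{w}_t$ also changes along the optimization trajectory. As a result, we need to estimate $\mathbb{E}_{\mathcal{B}_t} \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ at every iteration. One can sample a large number of mini-batches and compute the mean of the squared stochastic gradient norms. However, this can be computationally expensive.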
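The decomposition in (3), and the brute-force estimate of the expected squared gradient norm described above, can be checked numerically on a toy problem. The sketch below is an illustration under stated assumptions: a hypothetical least-squares loss, and a data set partitioned into equal-size mini-batches that are drawn uniformly, so that the expected mini-batch gradient equals the full gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, b = 64, 5, 8                      # samples, dimension, mini-batch size (assumed values)
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
w = rng.normal(size=d)

def grad(Xb, yb, w):
    """Gradient of the loss 0.5 * mean((Xb w - yb)^2) with respect to w."""
    return Xb.T @ (Xb @ w - yb) / len(yb)

g_full = grad(X, y, w)                  # gradient on the whole training set D
batch_grads = np.stack([grad(X[i:i + b], y[i:i + b], w) for i in range(0, n, b)])

# Left-hand side of (3): E_B ||g_B - g_D||^2 over all mini-batches.
lhs = np.mean(np.sum((batch_grads - g_full) ** 2, axis=1))
# Right-hand side of (3): E_B ||g_B||^2 - ||g_D||^2.
mean_sq_norm = np.mean(np.sum(batch_grads ** 2, axis=1))
rhs = mean_sq_norm - np.sum(g_full ** 2)

print(lhs, rhs)
```

Computing `mean_sq_norm` this way needs one gradient per mini-batch at the same `w`, i.e., a full pass over the data for a single estimate, which is why evaluating it at every iteration would be costly.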
To address this problem, we model $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ with a simple distribution and estimate the distribution parameters in an online manner. Figure 3(a) shows $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ of 400 mini-batches at different training stages (epoch = 60, 120, and 180) on CIFAR-100 using ResNet-18$^{1}$. As can be seen, the distribution follows a bell curve. Figure 3(b) shows the corresponding quantile-quantile (Q-Q) plot (Wilk & Gnanadesikan, 1968). The closer the curve is to a straight line, the closer the distribution is to the normal distribution.

Figure 2: Squared stochastic gradient norms $\mathbb{E}_{\mathcal{B}} \|\nabla \mathcal{L}(\mathcal{B}; \mathbf{w}_t)\|^2$ on CIFAR-100 (x-axis: epoch; curves: ERM, SAM, AE-SAM). Panels: (a) ResNet-18; (b) WRN-28-10; (c) PyramidNet-110. Best viewed in color.

Figure 3 suggests that $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ can be modeled$^{2}$ with a normal distribution $\mathcal{N}(\mu_t, \sigma_t^2)$. We use exponential moving average (EMA), which is popularly used in adaptive gradient methods (e.g., RMSProp (Tieleman & Hinton, 2012), AdaDelta (Zeiler, 2012), Adam (Kingma & Ba, 2015)), to estimate its mean and variance:

$$\mu_t = \delta \mu_{t-1} + (1 - \delta) \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2, \tag{4}$$

$$\sigma_t^2 = \delta \sigma_{t-1}^2 + (1 - \delta) \left(\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 - \mu_t\right)^2, \tag{5}$$

where $\delta \in (0, 1)$ controls the forgetting rate. Empirically, we use $\delta = 0.9$. Since $\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)$ is already available during training, this EMA update does not involve additional gradient calculations (the cost for the norm operator is negligible).

Figure 3: Stochastic gradient norms $\{\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 : \mathcal{B}_t \sim \mathcal{D}\}$ of ResNet-18 on CIFAR-100 are approximately normally distributed. Panels: (a) Distributions; (b) Q-Q plots (x-axis: theoretical quantiles; y-axis: sample quantiles), at epochs 60, 120, and 180. Best viewed in color.

Using $\mu_t$ and $\sigma_t^2$, we employ SAM only at iterations where $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2$ is relatively large (i.e., the loss landscape is locally sharp).
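The EMA updates (4) and (5) can be sketched as a small running estimator. The class name `GradNormEMA` is hypothetical, and the initial values $\mu_{-1} = 0$ and $\sigma_{-1}^2 = e^{-10}$ are taken from the Require line of Algorithm 1 as read here; the synthetic input stream is only for illustration.

```python
import numpy as np

class GradNormEMA:
    """Running mean/variance of the squared stochastic gradient norm, as in (4) and (5)."""

    def __init__(self, delta=0.9, mu_init=0.0, var_init=np.exp(-10.0)):
        self.delta = delta      # forgetting rate, delta in (0, 1)
        self.mu = mu_init       # plays the role of mu_{t-1}
        self.var = var_init     # plays the role of sigma^2_{t-1}

    def update(self, sq_norm):
        """Feed ||grad||^2 for the current mini-batch; return (mu_t, sigma_t^2)."""
        self.mu = self.delta * self.mu + (1.0 - self.delta) * sq_norm
        # (5) uses the freshly updated mean mu_t, not mu_{t-1}.
        self.var = self.delta * self.var + (1.0 - self.delta) * (sq_norm - self.mu) ** 2
        return self.mu, self.var

# Synthetic stream standing in for squared gradient norms (assumed N(20, 2^2)).
rng = np.random.default_rng(0)
ema = GradNormEMA()
for s in rng.normal(loc=20.0, scale=2.0, size=500):
    mu_t, var_t = ema.update(s)
print(mu_t, var_t)
```

The estimator only consumes a scalar per iteration, so it adds no gradient evaluations on top of the one already needed for the update.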
Specifically, when $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 \ge \mu_t + c_t \sigma_t$ (where $c_t$ is a threshold), SAM is used; otherwise, ERM is used. When $c_t \to -\infty$, it reduces to SAM; when $c_t \to \infty$, it becomes ERM. Note that during the early training stage, the model is still underfitting and $\mathbf{w}_t$ is far from the region of final convergence. Thus, minimizing the empirical loss is more important than seeking a locally flat region. Andriushchenko & Flammarion (2022) also empirically observe that the SAM update is more effective in boosting performance towards the end of training. We therefore design a schedule that linearly decreases $c_t$ from $\lambda_2$ to $\lambda_1$ (which are pre-set values):

$$c_t = g_{\lambda_1, \lambda_2}(t) \equiv \frac{t}{T} \lambda_1 + \left(1 - \frac{t}{T}\right) \lambda_2,$$

where $T$ is the total number of iterations. The whole procedure, called Adaptive policy to Employ SAM (AE-SAM), is shown in Algorithm 1.

AE-LookSAM. The proposed adaptive policy can be combined with any SAM variant. Here, we consider integrating it with LookSAM (Liu et al., 2022). When $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 \ge \mu_t + c_t \sigma_t$, SAM is used and the update direction for $\mathbf{w}_t$ is decomposed into two orthogonal directions as in LookSAM: (i) the ERM update direction to reduce training loss, and (ii) the direction that biases the model to a flat region. When $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 < \mu_t + c_t \sigma_t$, ERM is performed and the second direction of the previous SAM update is reused to compose an approximate SAM direction. The procedure, called AE-LookSAM, is also shown in Algorithm 1.

$^{1}$ Results on other architectures and CIFAR-10 are shown in Figures 8 and 9 of Appendix B.1.

$^{2}$ Note that normality is not needed in the theoretical analysis (Section 3.3).

Algorithm 1 AE-SAM and AE-LookSAM.

Require: training set $\mathcal{D}$, stepsize $\eta$, radius $\rho$; $\lambda_1$ and $\lambda_2$ for $g_{\lambda_1, \lambda_2}(t)$; $\mathbf{w}_0$, $\mu_{-1} = 0$, $\sigma_{-1}^2 = e^{-10}$, and $\alpha$ for AE-LookSAM;

1. for $t = 0, \dots, T - 1$ do
2. sample a mini-batch data $\mathcal{B}_t$ from $\mathcal{D}$;
3. compute $\mathbf{g} = \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)$;
4. update $\mu_t$ by (4) and $\sigma_t^2$ by (5);
5. compute $c_t = g_{\lambda_1, \lambda_2}(t)$;
6. if $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 \ge \mu_t + c_t \sigma_t$ then
7. $\mathbf{g}_s = \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t + \rho \nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t))$;
8. if AE-LookSAM: decompose $\mathbf{g}_s$ as $\mathbf{g}_v = \mathbf{g}_s - \frac{\mathbf{g}^\top \mathbf{g}_s}{\|\mathbf{g}\|^2} \mathbf{g}$;
9. else:
10. if AE-SAM: $\mathbf{g}_s = \mathbf{g}$;
11. if AE-LookSAM: $\mathbf{g}_s = \mathbf{g} + \alpha \frac{\|\mathbf{g}\|}{\|\mathbf{g}_v\|} \mathbf{g}_v$;
12. end if
13. $\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \mathbf{g}_s$;
14. end for
15. return $\mathbf{w}_T$.

3.3 CONVERGENCE ANALYSIS

In this section, we study the convergence of any algorithm $\mathcal{A}$ whose update in each iteration can be either SAM or ERM. Due to this mixing of SAM and ERM updates, analyzing its convergence is more challenging compared with that of SAM.

The following assumptions on smoothness and bounded variance of stochastic gradients are standard in the literature on non-convex optimization (Ghadimi & Lan, 2013; Reddi et al., 2016) and SAM (Andriushchenko & Flammarion, 2022; Abbas et al., 2022; Qu et al., 2022).

Assumption 3.1 (Smoothness). $\mathcal{L}(\mathcal{D}; \mathbf{w})$ is $\beta$-smooth in $\mathbf{w}$, i.e., $\|\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}) - \nabla \mathcal{L}(\mathcal{D}; \mathbf{v})\| \le \beta \|\mathbf{w} - \mathbf{v}\|$.

Assumption 3.2 (Bounded variance of stochastic gradients). $\mathbb{E}_{(\mathbf{x}_i, y_i) \sim \mathcal{D}} \|\nabla \ell(f(\mathbf{x}_i; \mathbf{w}), y_i) - \nabla \mathcal{L}(\mathcal{D}; \mathbf{w})\|^2 \le \sigma^2$.

Let $\xi_t$ be an indicator of whether SAM or ERM is used at iteration $t$ (i.e., $\xi_t = 1$ for SAM, and $0$ for ERM). For example, $\xi_t = \mathbb{I}_{\{\mathbf{w} : \|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w})\|^2 \ge \mu_t + c_t \sigma_t\}}(\mathbf{w}_t)$ for the proposed AE-SAM, and $\xi_t$ is sampled from a Bernoulli distribution for SS-SAM (Zhao et al., 2022b).

Theorem 3.3. Let $b$ be the mini-batch size. If stepsize $\eta = \frac{1}{4\beta\sqrt{T}}$ and $\rho = \frac{1}{T^{1/4}}$, algorithm $\mathcal{A}$ satisfies

$$\min_{0 \le t \le T-1} \mathbb{E} \|\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}_t)\|^2 \le \frac{32\beta \left(\mathcal{L}(\mathcal{D}; \mathbf{w}_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; \mathbf{w}_T)\right)}{\sqrt{T}(7 - 6\zeta)} + \frac{(1 + \zeta + 5\beta^2\zeta)\sigma^2}{b\sqrt{T}(7 - 6\zeta)}, \tag{6}$$

where $\zeta = \frac{1}{T}\sum_{t=0}^{T-1} \xi_t \in [0, 1]$ is the fraction of SAM updates, and the expectation is taken over the random training samples.

All proofs are in Appendix A. Note that a larger $\zeta$ leads to a larger upper bound in (6). When $\zeta = 1$, the above reduces to SAM (Corollary A.2 of Appendix A.1).

4 EXPERIMENTS

In this section, we evaluate the proposed AE-SAM and AE-LookSAM on several standard benchmarks.
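As a rough illustration of the AE-SAM branch of Algorithm 1 (Section 3.2), the sketch below runs the adaptive policy on a toy problem and reports the fraction of SAM updates. It is not the authors' code: the function names, the noisy linear-regression task, the hyperparameter values, and the choice to normalize the perturbation as in (2) are all assumptions made here for illustration, and the threshold schedule is written as linear interpolation from $\lambda_2$ at $t = 0$ to $\lambda_1$ at $t = T$.

```python
import numpy as np

def c_schedule(t, T, lam1, lam2):
    """Linear threshold schedule: equals lam2 at t = 0 and lam1 at t = T."""
    return (t / T) * lam1 + (1.0 - t / T) * lam2

def ae_sam(grad_fn, sample_batch, w0, T, eta=0.05, rho=0.05,
           lam1=-1.0, lam2=1.0, delta=0.9):
    """Sketch of the AE-SAM branch of Algorithm 1; returns (w_T, percentage of SAM steps)."""
    w = np.array(w0, dtype=float)
    mu, var = 0.0, np.exp(-10.0)              # mu_{-1} and sigma^2_{-1}
    n_sam = 0
    for t in range(T):
        batch = sample_batch()
        g = grad_fn(w, batch)                 # first gradient, always computed
        sq = float(g @ g)                     # squared stochastic gradient norm
        mu = delta * mu + (1.0 - delta) * sq                   # EMA mean, as in (4)
        var = delta * var + (1.0 - delta) * (sq - mu) ** 2     # EMA variance, as in (5)
        c_t = c_schedule(t, T, lam1, lam2)
        if sq >= mu + c_t * np.sqrt(var):
            # Locally sharp: SAM step. Normalizing by ||g|| follows (2); this is an assumption.
            rho_t = rho / (np.sqrt(sq) + 1e-12)
            g_s = grad_fn(w + rho_t * g, batch)   # second gradient on the same mini-batch
            n_sam += 1
        else:
            g_s = g                            # locally flat: plain ERM step
        w = w - eta * g_s
    return w, 100.0 * n_sam / T

# Hypothetical toy task: noisy linear regression with mini-batches of 32 samples.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10))
y = X @ rng.normal(size=10) + 0.1 * rng.normal(size=256)

def loss(w):
    return 0.5 * np.mean((X @ w - y) ** 2)

def grad_fn(w, idx):
    return X[idx].T @ (X[idx] @ w - y[idx]) / len(idx)

def sample_batch():
    return rng.choice(256, size=32, replace=False)

w_hat, pct_sam = ae_sam(grad_fn, sample_batch, np.zeros(10), T=500)
print(f"final loss {loss(w_hat):.4f}, SAM updates {pct_sam:.1f}%")
```

Setting both thresholds to a very negative value makes every step a SAM step, and a very positive value makes every step an ERM step, which mirrors the two limiting cases of $c_t$ discussed in Section 3.2.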
As the SAM update doubles the computational overhead compared to the ERM update, the training speed is mainly determined by how often the SAM update is used. Hence, we evaluate efficiency by measuring the fraction of SAM updates used: $\%\text{SAM} \equiv 100 \times \#\{\text{iterations using SAM}\}/T$. The total number of iterations, $T$, is the same for all methods.

4.1 CIFAR-10 AND CIFAR-100

Setup. In this section, experiments are performed on the CIFAR-10 and CIFAR-100 datasets (Krizhevsky & Hinton, 2009) using four network architectures: ResNet-18 (He et al., 2016), WideResNet-28-10 (denoted WRN-28-10) (Zagoruyko & Komodakis, 2016), PyramidNet-110 (Han et al., 2017), and ViT-S16 (Dosovitskiy et al., 2021). Following the setup in (Liu et al., 2022; Foret et al., 2021; Zhao et al., 2022a), we use batch size 128, initial learning rate of 0.1, cosine learning rate schedule, SGD optimizer with momentum 0.9 and weight decay 0.0001. The number of training epochs is 300 for PyramidNet-110, 1200 for ViT-S16, and 200 for ResNet-18 and WideResNet-28-10. 10% of the training set is used as the validation set. As in Foret et al. (2021), we perform grid search for the radius $\rho$ over {0.01, 0.02, 0.05, 0.1, 0.2, 0.5} using the validation set. Similarly, $\alpha$ is selected by grid search over {0.1, 0.3, 0.6, 0.9}. For the $c_t$ schedule $g_{\lambda_1, \lambda_2}(t)$, $\lambda_1 = -1$ and $\lambda_2 = 1$ for AE-SAM; $\lambda_1 = 0$ and $\lambda_2 = 2$ for AE-LookSAM.

Baselines. The proposed AE-SAM and AE-LookSAM are compared with the following baselines: (i) ERM; (ii) SAM (Foret et al., 2021); and its more efficient variants including (iii) ESAM (Du et al., 2022a), which uses part of the weights to compute the perturbation and part of the samples to compute the SAM update direction. These two techniques can reduce the computational cost, but may not always accelerate SAM, particularly in parallel training (Li et al., 2020); (iv) SS-SAM (Zhao et al., 2022b), which randomly selects SAM or ERM according to a Bernoulli trial with success probability 0.5.
This is the scheme with the best performance in (Zhao et al., 2022b); (v) LookSAM (Liu et al., 2022), which uses SAM every $k = 5$ steps. The experiment is repeated five times with different random seeds.

Results. Table 1 shows the testing accuracy and fraction of SAM updates (%SAM). Methods are grouped based on %SAM. As can be seen, AE-SAM has higher accuracy than SAM while using only 50% of SAM updates. SS-SAM and AE-SAM have comparable %SAM (about 50%), and AE-SAM achieves higher accuracy than SS-SAM (which is statistically significant based on the pairwise t-test at the 95% significance level). Finally, LookSAM and AE-LookSAM have comparable %SAM (about 20%), and AE-LookSAM also has higher accuracy than LookSAM. These improvements indicate that the adaptive policy is more effective than the random and periodic policies.

4.2 IMAGENET

Setup. In this section, we perform experiments on ImageNet (Russakovsky et al., 2015), which contains 1000 classes and 1.28 million images. ResNet-50 (He et al., 2016) is used. Following the setup in Du et al. (2022a), we train the network for 90 epochs using an SGD optimizer with momentum 0.9, weight decay 0.0001, initial learning rate 0.1, cosine learning rate schedule, and batch size 512. As in (Foret et al., 2021; Du et al., 2022a), $\rho = 0.05$. For the $c_t$ schedule $g_{\lambda_1, \lambda_2}(t)$, $\lambda_1 = -1$ and $\lambda_2 = 1$ for AE-SAM; $\lambda_1 = 0$ and $\lambda_2 = 2$ for AE-LookSAM. $k = 5$ is used for LookSAM. Experiments are repeated with three different random seeds.

Results. Table 2 shows the testing accuracy and fraction of SAM updates. As can be seen, with only half of the iterations using SAM, AE-SAM achieves performance comparable to SAM. Compared with LookSAM, AE-LookSAM has better performance (which is also statistically significant), verifying that the proposed adaptive policy is more effective than LookSAM's periodic policy.

4.3 ROBUSTNESS TO LABEL NOISE

Setup. In this section, we study whether using the more efficient SAM variants affects robustness to training label noise.
Following the setup in Foret et al. (2021), we conduct experiments on a corrupted version of CIFAR-10, with some of its training labels randomly flipped (while its testing set is kept clean). The ResNet-18 and ResNet-32 networks are used. They are trained for 200 epochs using SGD with momentum 0.9, weight decay 0.0001, batch size 128, initial learning rate 0.1, and cosine learning rate schedule. For LookSAM, the SAM update is used every $k = 2$ steps.$^{3}$ For AE-SAM and AE-LookSAM, we set $\lambda_1 = -1$ and $\lambda_2 = 1$ in their $c_t$ schedules $g_{\lambda_1, \lambda_2}(t)$, such that their fractions of SAM updates (approximately 50%) are comparable with SS-SAM and LookSAM. Experiments are repeated with five different random seeds.

Results. Table 3 shows the testing accuracy and fraction of SAM updates. As can be seen, AE-LookSAM achieves comparable performance with SAM but is faster as only half of the iterations use the SAM update. Compared with ESAM, SS-SAM, and LookSAM, AE-LookSAM performs better. The improvement is particularly noticeable at the higher noise levels (e.g., 80%).

$^{3}$ The performance of LookSAM can be sensitive to the value of $k$. Table 4 of Appendix B.2 shows that using $k = 2$ leads to the best performance in this experiment.

Table 1: Means and standard deviations of testing accuracy and fraction of SAM updates (%SAM) on CIFAR-10 and CIFAR-100. Methods are grouped based on %SAM. The highest accuracy in each group is underlined, while the highest accuracy for each network architecture (across all groups) is in bold.

| Architecture | Method | CIFAR-10 Accuracy | CIFAR-10 %SAM | CIFAR-100 Accuracy | CIFAR-100 %SAM |
|---|---|---|---|---|---|
| ResNet-18 | ERM | 95.41 ± 0.03 | 0.0 ± 0.0 | 78.17 ± 0.05 | 0.0 ± 0.0 |
| ResNet-18 | SAM (Foret et al., 2021) | 96.52 ± 0.12 | 100.0 ± 0.0 | 80.17 ± 0.15 | 100.0 ± 0.0 |
| ResNet-18 | ESAM (Du et al., 2022a) | 96.56 ± 0.08 | 100.0 ± 0.0 | 80.41 ± 0.10 | 100.0 ± 0.0 |
| ResNet-18 | SS-SAM (Zhao et al., 2022b) | 96.40 ± 0.16 | 50.0 ± 0.0 | 80.10 ± 0.16 | 50.0 ± 0.0 |
| ResNet-18 | AE-SAM | 96.63 ± 0.04 | 50.1 ± 0.1 | 80.48 ± 0.11 | 49.8 ± 0.0 |
| ResNet-18 | LookSAM (Liu et al., 2022) | 96.32 ± 0.12 | 20.0 ± 0.0 | 79.89 ± 0.29 | 20.0 ± 0.0 |
| ResNet-18 | AE-LookSAM | 96.56 ± 0.21 | 20.0 ± 0.1 | 80.29 ± 0.37 | 20.0 ± 0.0 |
| WRN-28-10 | ERM | 96.34 ± 0.12 | 0.0 ± 0.0 | 81.56 ± 0.14 | 0.0 ± 0.0 |
| WRN-28-10 | SAM (Foret et al., 2021) | 97.27 ± 0.11 | 100.0 ± 0.0 | 83.42 ± 0.05 | 100.0 ± 0.0 |
| WRN-28-10 | ESAM (Du et al., 2022a) | 97.29 ± 0.11 | 100.0 ± 0.0 | 84.51 ± 0.02 | 100.0 ± 0.0 |
| WRN-28-10 | SS-SAM (Zhao et al., 2022b) | 97.09 ± 0.11 | 50.0 ± 0.0 | 82.89 ± 0.02 | 50.0 ± 0.0 |
| WRN-28-10 | AE-SAM | 97.30 ± 0.10 | 49.5 ± 0.1 | 84.51 ± 0.11 | 49.6 ± 0.0 |
| WRN-28-10 | LookSAM (Liu et al., 2022) | 97.02 ± 0.12 | 20.0 ± 0.0 | 83.70 ± 0.12 | 20.0 ± 0.0 |
| WRN-28-10 | AE-LookSAM | 97.15 ± 0.08 | 20.0 ± 0.0 | 83.92 ± 0.07 | 20.2 ± 0.0 |
| PyramidNet-110 | ERM | 96.62 ± 0.10 | 0.0 ± 0.0 | 81.89 ± 0.15 | 0.0 ± 0.0 |
| PyramidNet-110 | SAM (Foret et al., 2021) | 97.30 ± 0.10 | 100.0 ± 0.0 | 84.46 ± 0.05 | 100.0 ± 0.0 |
| PyramidNet-110 | ESAM (Du et al., 2022a) | 97.81 ± 0.01 | 100.0 ± 0.0 | 85.56 ± 0.05 | 100.0 ± 0.0 |
| PyramidNet-110 | SS-SAM (Zhao et al., 2022b) | 97.22 ± 0.10 | 50.0 ± 0.0 | 84.90 ± 0.05 | 50.0 ± 0.0 |
| PyramidNet-110 | AE-SAM | 97.90 ± 0.05 | 50.2 ± 0.1 | 85.58 ± 0.10 | 49.8 ± 0.1 |
| PyramidNet-110 | LookSAM (Liu et al., 2022) | 97.10 ± 0.11 | 20.0 ± 0.0 | 84.01 ± 0.06 | 20.0 ± 0.0 |
| PyramidNet-110 | AE-LookSAM | 97.22 ± 0.11 | 20.3 ± 0.0 | 84.80 ± 0.13 | 20.2 ± 0.1 |
| ViT-S16 | ERM | 86.69 ± 0.11 | 0.0 ± 0.0 | 62.42 ± 0.22 | 0.0 ± 0.0 |
| ViT-S16 | SAM (Foret et al., 2021) | 87.37 ± 0.09 | 100.0 ± 0.0 | 63.23 ± 0.25 | 100.0 ± 0.0 |
| ViT-S16 | ESAM (Du et al., 2022a) | 84.27 ± 0.11 | 100.0 ± 0.0 | 62.11 ± 0.15 | 100.0 ± 0.0 |
| ViT-S16 | SS-SAM (Zhao et al., 2022b) | 87.38 ± 0.14 | 50.0 ± 0.0 | 63.18 ± 0.19 | 50.0 ± 0.0 |
| ViT-S16 | AE-SAM | 87.77 ± 0.13 | 49.7 ± 0.1 | 63.68 ± 0.23 | 49.5 ± 0.2 |
| ViT-S16 | LookSAM (Liu et al., 2022) | 87.12 ± 0.20 | 20.0 ± 0.0 | 63.52 ± 0.19 | 20.0 ± 0.0 |
| ViT-S16 | AE-LookSAM | 87.32 ± 0.11 | 20.2 ± 0.2 | 64.16 ± 0.23 | 20.3 ± 0.2 |

Table 2: Means and standard deviations of testing accuracy and fraction of SAM updates (%SAM) on ImageNet using ResNet-50. Methods are grouped based on %SAM. The highest accuracy in each group is underlined, while the highest across all groups is in bold.

| Method | Accuracy | %SAM |
|---|---|---|
| ERM | 77.11 ± 0.14 | 0.0 ± 0.0 |
| SAM (Foret et al., 2021) | 77.47 ± 0.12 | 100.0 ± 0.0 |
| ESAM (Du et al., 2022a) | 77.25 ± 0.75 | 100.0 ± 0.0 |
| SS-SAM (Zhao et al., 2022b) | 77.38 ± 0.06 | 50.0 ± 0.0 |
| AE-SAM | 77.43 ± 0.06 | 49.4 ± 0.0 |
| LookSAM (Liu et al., 2022) | 77.13 ± 0.09 | 20.0 ± 0.0 |
| AE-LookSAM | 77.29 ± 0.08 | 20.3 ± 0.0 |

Table 3: Testing accuracy and fraction of SAM updates on CIFAR-10 with different levels of label noise. The best accuracy is in bold and the second best is underlined.

| Architecture | Method | 20% noise: accuracy | 20% noise: %SAM | 40% noise: accuracy | 40% noise: %SAM | 60% noise: accuracy | 60% noise: %SAM | 80% noise: accuracy | 80% noise: %SAM |
|---|---|---|---|---|---|---|---|---|---|
| ResNet-18 | ERM | 87.92 | 0.0 | 70.82 | 0.0 | 49.61 | 0.0 | 28.23 | 0.0 |
| ResNet-18 | SAM (Foret et al., 2021) | 94.80 | 100.0 | 91.50 | 100.0 | 88.15 | 100.0 | 77.40 | 100.0 |
| ResNet-18 | ESAM (Du et al., 2022a) | 94.19 | 100.0 | 91.46 | 100.0 | 81.30 | 100.0 | 15.00 | 100.0 |
| ResNet-18 | SS-SAM (Zhao et al., 2022b) | 90.62 | 50.0 | 77.84 | 50.0 | 61.18 | 50.0 | 47.32 | 50.0 |
| ResNet-18 | LookSAM (Liu et al., 2022) | 92.72 | 50.0 | 88.04 | 50.0 | 72.26 | 50.0 | 69.72 | 50.0 |
| ResNet-18 | AE-SAM | 92.84 | 50.0 | 84.17 | 50.0 | 73.54 | 49.9 | 65.00 | 50.0 |
| ResNet-18 | AE-LookSAM | 94.34 | 49.9 | 91.58 | 50.0 | 87.85 | 50.0 | 76.90 | 50.0 |
| ResNet-32 | ERM | 87.43 | 0.0 | 70.82 | 0.0 | 46.26 | 0.0 | 29.00 | 0.0 |
| ResNet-32 | SAM (Foret et al., 2021) | 95.08 | 100.0 | 91.01 | 100.0 | 88.90 | 100.0 | 77.32 | 100.0 |
| ResNet-32 | ESAM (Du et al., 2022a) | 93.42 | 100.0 | 91.63 | 100.0 | 82.73 | 100.0 | 10.09 | 100.0 |
| ResNet-32 | SS-SAM (Zhao et al., 2022b) | 89.63 | 50.0 | 74.17 | 50.0 | 58.40 | 50.0 | 59.53 | 50.0 |
| ResNet-32 | LookSAM (Liu et al., 2022) | 92.49 | 50.0 | 86.56 | 50.0 | 63.35 | 50.0 | 68.01 | 50.0 |
| ResNet-32 | AE-SAM | 92.87 | 50.0 | 82.85 | 50.0 | 71.50 | 50.0 | 65.43 | 50.3 |
| ResNet-32 | AE-LookSAM | 94.70 | 50.0 | 91.80 | 50.0 | 88.22 | 50.0 | 77.03 | 49.8 |

Figure 4: Accuracies with number of training epochs on CIFAR-10 (with 80% noise labels) using ResNet-18 (curves: ERM, SAM, ESAM, SS-SAM, LookSAM, AE-SAM, AE-LookSAM). Panels: (a) Training accuracy; (b) Testing accuracy. Best viewed in color.

Figure 4 shows the training and testing accuracies with number of epochs at a noise level of 80% using ResNet-18$^{4}$. As can be seen, SAM is robust to the label noise, while ERM and SS-SAM heavily suffer from overfitting. AE-SAM and LookSAM can alleviate the overfitting problem to a certain extent. AE-LookSAM, by combining the adaptive policy with LookSAM, achieves the same high level of robustness as SAM.

$^{4}$ Results for other noise levels and ResNet-32 are shown in Figures 10 and 11 of Appendix B.3, respectively.

(a) CIFAR-10. (b) CIFAR-100.
Figure 5: Effects of $\lambda_1$ and $\lambda_2$ on fraction of SAM updates using ResNet-18. Best viewed in color.

Figure 6: Effects of $\lambda_1$ and $\lambda_2$ on testing accuracy using ResNet-18. Panels: (a) CIFAR-10; (b) CIFAR-100. Best viewed in color.

4.4 EFFECTS OF $\lambda_1$ AND $\lambda_2$

In this experiment, we study the effects of $\lambda_1$ and $\lambda_2$ on AE-SAM. We use the same setup as in Section 4.1, where $\lambda_1$ and $\lambda_2$ (with $\lambda_1 \le \lambda_2$) are chosen from $\{0, \pm 1, \pm 2\}$. Results on AE-LookSAM using the label noise setup in Section 4.3 are shown in Appendix B.4.

Figure 5 shows the effect on the fraction of SAM updates. For a fixed $\lambda_2$, increasing $\lambda_1$ increases the threshold $c_t$, and the condition $\|\nabla \mathcal{L}(\mathcal{B}_t; \mathbf{w}_t)\|^2 \ge \mu_t + c_t \sigma_t$ becomes more difficult to satisfy. Thus, as can be seen, the fraction of SAM updates is reduced. The same applies when $\lambda_2$ increases. A similar trend is also observed on the testing accuracy (Figure 6).

4.5 CONVERGENCE

In this experiment, we study whether the $\mathbf{w}_t$'s (where $t$ is the number of epochs) obtained from AE-SAM can reach critical points of $\mathcal{L}(\mathcal{D}; \mathbf{w})$, as suggested in Theorem 3.3. Figure 7 shows $\|\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}_t)\|^2$ w.r.t. $t$ for the experiment in Section 4.1. As can be seen, in all settings, $\|\nabla \mathcal{L}(\mathcal{D}; \mathbf{w}_t)\|^2$ converges to 0. In Appendix B.5, we also verify the convergence of AE-SAM's training loss on CIFAR-10 and CIFAR-100 (Figure 14), and that AE-SAM and SS-SAM have comparable convergence speeds (Figure 15), which agrees with Theorem 3.3 as both have comparable fractions of SAM updates (Table 1).

Figure 7: Squared gradient norms of AE-SAM with number of epochs (curves: CIFAR-10, CIFAR-100). Panels: (a) ResNet-18; (b) WRN-28-10; (c) PyramidNet-110. Best viewed in color.

5 CONCLUSION

In this paper, we proposed an adaptive policy to employ SAM based on the loss landscape geometry. Using the policy, we proposed an efficient algorithm (called AE-SAM) to reduce the fraction of SAM updates during training.
We theoretically and empirically analyzed the convergence of AE-SAM. Experimental results on a number of datasets and network architectures verify the efficiency and effectiveness of the adaptive policy. Moreover, the proposed policy is general and can be combined with other SAM variants, as demonstrated by the success of AE-LookSAM.

ACKNOWLEDGMENTS

This work was supported by NSFC key grant 62136005, NSFC general grant 62076118, and Shenzhen fundamental research program JCYJ20210324105000003. This research was supported in part by the Research Grants Council of the Hong Kong Special Administrative Region (Grant 16200021).

REFERENCES

Momin Abbas, Quan Xiao, Lisha Chen, Pin-Yu Chen, and Tianyi Chen. Sharp-MAML: Sharpness-aware model-agnostic meta learning. In International Conference on Machine Learning, 2022.

Maksym Andriushchenko and Nicolas Flammarion. Towards understanding sharpness-aware minimization. In International Conference on Machine Learning, 2022.

Devansh Bisla, Jing Wang, and Anna Choromanska. Low-pass filtering SGD for recovering flat optima in the deep learning optimization landscape. In International Conference on Artificial Intelligence and Statistics, 2022.

Léon Bottou, Frank E Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. SIAM Review, 2018.

Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, and Sungrae Park. SWAD: Domain generalization by seeking flat minima. In Neural Information Processing Systems, 2021.

Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann LeCun, Carlo Baldassi, Christian Borgs, Jennifer Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-SGD: Biasing gradient descent into wide valleys. In International Conference on Learning Representations, 2017.

Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In International Conference on Machine Learning, 2017.
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.

Jiawei Du, Hanshu Yan, Jiashi Feng, Joey Tianyi Zhou, Liangli Zhen, Rick Siow Mong Goh, and Vincent Tan. Efficient sharpness-aware minimization for improved training of neural networks. In International Conference on Learning Representations, 2022a.

Jiawei Du, Daquan Zhou, Jiashi Feng, Vincent YF Tan, and Joey Tianyi Zhou. Sharpness-aware training for free. In Neural Information Processing Systems, 2022b.

Gintare Karolina Dziugaite and Daniel M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In Uncertainty in Artificial Intelligence, 2017.

Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Saeed Ghadimi and Guanghui Lan. Stochastic first- and zeroth-order methods for nonconvex stochastic programming. SIAM Journal on Optimization, 2013.

Dongyoon Han, Jiwhan Kim, and Junmo Kim. Deep pyramidal residual networks. In IEEE Conference on Computer Vision and Pattern Recognition, 2017.

Haowei He, Gao Huang, and Yang Yuan. Asymmetric valleys: Beyond sharp and flat local minima. In Neural Information Processing Systems, 2019.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition, 2016.

Sepp Hochreiter and Jürgen Schmidhuber. Simplifying neural nets by discovering flat minima. In Neural Information Processing Systems, 1994.
Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence, 2018.

Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. In International Conference on Learning Representations, 2020.

Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. In International Conference on Learning Representations, 2017.

Mohammad Khan, Didrik Nielsen, Voot Tangkaratt, Wu Lin, Yarin Gal, and Akash Srivastava. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. In International Conference on Machine Learning, 2018.

Minyoung Kim, Da Li, Shell X Hu, and Timothy Hospedales. Fisher SAM: Information geometry and sharpness aware minimisation. In International Conference on Machine Learning, 2022.

Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.

Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, 2009.

Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning, 2021.

Shen Li, Yanli Zhao, Rohan Varma, Omkar Salpekar, Pieter Noordhuis, Teng Li, Adam Paszke, Jeff Smith, Brian Vaughan, and Pritam Damania. PyTorch distributed: experiences on accelerating data parallel training. In Proceedings of the VLDB Endowment, 2020.

Yong Liu, Siqi Mai, Xiangning Chen, Cho-Jui Hsieh, and Yang You. Towards efficient and scalable sharpness-aware minimization. In IEEE Conference on Computer Vision and Pattern Recognition, 2022.

David A McAllester.
PAC-Bayesian model averaging. In Annual Conference on Computational Learning Theory, 1999.

Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nati Srebro. Exploring generalization in deep learning. In Neural Information Processing Systems, 2017.

Antonio Orvieto, Hans Kersting, Frank Proske, Francis Bach, and Aurelien Lucchi. Anticorrelated noise injection for improved generalization. In International Conference on Machine Learning, 2022.

Henning Petzka, Michael Kamp, Linara Adilova, Cristian Sminchisescu, and Mario Boley. Relative flatness and generalization. In Neural Information Processing Systems, 2021.

Zhe Qu, Xingyu Li, Rui Duan, Yao Liu, Bo Tang, and Zhuo Lu. Generalized federated learning via sharpness aware minimization. In International Conference on Machine Learning, 2022.

Sashank J Reddi, Ahmed Hefny, Suvrit Sra, Barnabas Poczos, and Alex Smola. Stochastic variance reduction for nonconvex optimization. In International Conference on Machine Learning, 2016.

Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 2015.

Tijmen Tieleman and Geoffrey Hinton. RMSProp: Neural networks for machine learning. Lecture 6.5, 2012.

Martin B Wilk and Ram Gnanadesikan. Probability plotting methods for the analysis of data. Biometrika, 1968.

Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In British Machine Vision Conference, 2016.

Matthew D Zeiler. AdaDelta: an adaptive learning rate method. Preprint arXiv:1212.5701, 2012.

Yang Zhao, Hao Zhang, and Xiuyuan Hu. Penalizing gradient norm for efficiently improving generalization in deep learning. In International Conference on Machine Learning, 2022a.

Yang Zhao, Hao Zhang, and Xiuyuan Hu.
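SS-SAM: Stochastic scheduled sharpness-aware minimization for efficiently training deep neural networks. Preprint arXiv:2203.09962, 2022b.

Mo Zhou, Tianyi Liu, Yan Li, Dachao Lin, Enlu Zhou, and Tuo Zhao. Toward understanding the importance of noise in training neural networks. In International Conference on Machine Learning, 2019.

Zhanxing Zhu, Jingfeng Wu, Bing Yu, Lei Wu, and Jinwen Ma. The anisotropic noise in stochastic gradient descent: Its behavior of escaping from sharp minima and regularization effects. In International Conference on Machine Learning, 2019.

Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha C Dvornek, James S. Duncan, and Ting Liu. Surrogate gap minimization improves sharpness-aware training. In International Conference on Learning Representations, 2022.

A.1 PROOF OF THEOREM 3.3

Theorem 3.3. Let $b$ be the mini-batch size. If $\eta = \frac{1}{4\beta\sqrt{T}}$ and $\rho = \frac{1}{T^{1/4}}$, algorithm A satisfies

$$\min_{0 \le t \le T-1} \mathbb{E}\|\nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 \le \frac{32\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{T}\,(7 - 6\zeta)} + \frac{(1 + \zeta + 5\beta^2\zeta)\,\sigma^2}{b\sqrt{T}\,(7 - 6\zeta)}, \tag{7}$$

where $\zeta = \frac{1}{T}\sum_{t=0}^{T-1} \xi_t \in [0, 1]$.

Lemma A.1 (Andriushchenko & Flammarion (2022)). Under Assumptions 3.1 and 3.2, for all $t$ and $\rho > 0$, we have

$$\mathbb{E}\,\nabla \mathcal{L}(\mathcal{B}_t; w + \rho\nabla \mathcal{L}(\mathcal{B}_t; w))^\top \nabla \mathcal{L}(\mathcal{D}; w) \ge \left(\frac{1}{2} - \rho\beta\right)\|\nabla \mathcal{L}(\mathcal{D}; w)\|^2 - \frac{\beta^2\rho^2\sigma^2}{2b}. \tag{8}$$

Proof. Let $g_t \equiv \frac{1}{b}\sum_{(x_i, y_i) \in \mathcal{B}_t} \nabla \ell(f(x_i; w_t), y_i)$, $h_t \equiv \frac{1}{b}\sum_{(x_i, y_i) \in \mathcal{B}_t} \nabla \ell(f(x_i; w_t + \rho g_t), y_i)$, and $\hat{g}_t \equiv \nabla \mathcal{L}(\mathcal{D}; w_t)$. By Taylor expansion and the $\beta$-smoothness of $\mathcal{L}(\mathcal{D}; w)$, we have

\begin{align}
\mathcal{L}(\mathcal{D}; w_{t+1}) &\le \mathcal{L}(\mathcal{D}; w_t) + \hat{g}_t^\top (w_{t+1} - w_t) + \frac{\beta}{2}\|w_{t+1} - w_t\|^2 \notag \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta\hat{g}_t^\top\big((1 - \xi_t) g_t + \xi_t h_t\big) + \frac{\beta\eta^2}{2}\|(1 - \xi_t) g_t + \xi_t h_t\|^2 \notag \\
&= \mathcal{L}(\mathcal{D}; w_t) - \eta(1 - \xi_t)\hat{g}_t^\top g_t - \eta\xi_t\hat{g}_t^\top h_t + \frac{\beta\eta^2}{2}\Big((1 - \xi_t)\|g_t\|^2 + \xi_t\|h_t\|^2 + \underbrace{2\xi_t(1 - \xi_t) g_t^\top h_t}_{=0}\Big) \tag{9} \\
&= \mathcal{L}(\mathcal{D}; w_t) - \eta(1 - \xi_t)\hat{g}_t^\top g_t - \eta\xi_t\hat{g}_t^\top h_t + \frac{\beta\eta^2}{2}\Big((1 - \xi_t)\|g_t\|^2 + \xi_t\|h_t\|^2\Big), \tag{10}
\end{align}

where we have used $\xi_t(1 - \xi_t) = 0$ as $\xi_t \in \{0, 1\}$, $\xi_t^2 = \xi_t$, and $(1 - \xi_t)^2 = 1 - \xi_t$ to obtain (9). Taking expectation w.r.t.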
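$w_t$ on both sides of (10), we have

$$\mathbb{E}\mathcal{L}(\mathcal{D}; w_{t+1}) \le \mathbb{E}\mathcal{L}(\mathcal{D}; w_t) - \eta(1 - \xi_t)\mathbb{E}\|\hat{g}_t\|^2 - \eta\xi_t\mathbb{E}\hat{g}_t^\top h_t + \frac{\beta\eta^2(1 - \xi_t)}{2}\mathbb{E}\|g_t\|^2 + \frac{\beta\eta^2\xi_t}{2}\mathbb{E}\|h_t\|^2. \tag{11}$$

Claim 1: $\mathbb{E}\|g_t\|^2 = \mathbb{E}\|g_t - \hat{g}_t\|^2 + \mathbb{E}\|\hat{g}_t\|^2 = \frac{\sigma^2}{b} + \mathbb{E}\|\hat{g}_t\|^2$, which follows from Assumption 3.2.

Claim 2: $\mathbb{E}\|h_t\|^2 \le 2(1 + \rho^2\beta^2)\frac{\sigma^2}{b} - (1 - 2\rho^2\beta^2)\mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t$, which is derived as follows:

\begin{align}
\mathbb{E}\|h_t\|^2 &= \mathbb{E}\|h_t - \hat{g}_t\|^2 - \mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t \notag \\
&\le 2\mathbb{E}\|h_t - g_t\|^2 + 2\mathbb{E}\|g_t - \hat{g}_t\|^2 - \mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t \notag \\
&\le 2\rho^2\beta^2\mathbb{E}\|g_t\|^2 + \frac{2\sigma^2}{b} - \mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t \tag{12} \\
&= 2\rho^2\beta^2\Big(\frac{\sigma^2}{b} + \mathbb{E}\|\hat{g}_t\|^2\Big) + \frac{2\sigma^2}{b} - \mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t \tag{13} \\
&= 2(1 + \rho^2\beta^2)\frac{\sigma^2}{b} - (1 - 2\rho^2\beta^2)\mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t, \tag{14}
\end{align}

where the second line uses $\|a + b\|^2 \le 2\|a\|^2 + 2\|b\|^2$, (12) follows from $\|h_t - g_t\| \le \rho\beta\|g_t\|$ and Assumption 3.2, and (13) follows from Claim 1.

Substituting Claims 1 and 2 into (11), we obtain

\begin{align}
\mathbb{E}\mathcal{L}(\mathcal{D}; w_{t+1}) &\le \mathbb{E}\mathcal{L}(\mathcal{D}; w_t) - \eta(1 - \xi_t)\mathbb{E}\|\hat{g}_t\|^2 - \eta\xi_t\mathbb{E}\hat{g}_t^\top h_t + \frac{\beta\eta^2(1 - \xi_t)}{2}\Big(\frac{\sigma^2}{b} + \mathbb{E}\|\hat{g}_t\|^2\Big) \notag \\
&\quad + \frac{\beta\eta^2\xi_t}{2}\Big(2(1 + \rho^2\beta^2)\frac{\sigma^2}{b} - (1 - 2\rho^2\beta^2)\mathbb{E}\|\hat{g}_t\|^2 + 2\mathbb{E}\hat{g}_t^\top h_t\Big) \tag{15} \\
&= \mathbb{E}\mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2} + \frac{\beta\eta\xi_t(1 - 2\rho^2\beta^2)}{2}\Big)\mathbb{E}\|\hat{g}_t\|^2 - \eta\xi_t(1 - \eta\beta)\mathbb{E}\hat{g}_t^\top h_t \notag \\
&\quad + \Big(\frac{\beta\eta^2(1 - \xi_t)}{2} + \beta\eta^2\xi_t(1 + \rho^2\beta^2)\Big)\frac{\sigma^2}{b} \notag \\
&\le \mathbb{E}\mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2} + \frac{\beta\eta\xi_t(1 - 2\rho^2\beta^2)}{2} + \xi_t(1 - \eta\beta)\Big(\frac{1}{2} - \rho\beta\Big)\Big)\mathbb{E}\|\hat{g}_t\|^2 \notag \\
&\quad + \Big(\frac{\beta\eta^2(1 - \xi_t)}{2} + \beta\eta^2\xi_t(1 + \rho^2\beta^2) + \frac{\eta\xi_t(1 - \eta\beta)\beta^2\rho^2}{2}\Big)\frac{\sigma^2}{b} \tag{16} \\
&\le \mathbb{E}\mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \frac{\beta\eta}{2} - \frac{(1 + \beta\eta - 2\rho\beta)\xi_t}{2}\Big)\mathbb{E}\|\hat{g}_t\|^2 + \Big(\eta + \xi_t(\eta + 2\eta\rho^2\beta^2 + \beta\rho^2 - \eta\beta^2\rho^2)\Big)\frac{\eta\beta\sigma^2}{2b}, \tag{17}
\end{align}

where (15) follows from Claims 1 and 2, and (16) follows from Lemma A.1 and $1 - \eta\beta > 0$. As $\eta < \frac{1}{4\beta}$, we have $1 + \beta\eta - 2\rho\beta \le 3/2$ and $\beta\eta < 1/4$; thus, $1 - \frac{\beta\eta}{2} - \frac{(1 + \beta\eta - 2\rho\beta)\xi_t}{2} \ge \frac{7}{8} - \frac{3}{4}\xi_t > 0$.

Summing over $t$ on both sides of (17) and rearranging, we obtain

\begin{align}
\min_{0 \le t \le T-1} \mathbb{E}\|\hat{g}_t\|^2 &\le \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)}{\eta\sum_{t=0}^{T-1}\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\xi_t}{2}\big)} + \frac{\sum_{t=0}^{T-1}\big(\eta + \xi_t(\eta + \eta\rho^2\beta^2 + \beta\rho^2)\big)}{\sum_{t=0}^{T-1}\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\xi_t}{2}\big)} \cdot \frac{\beta\sigma^2}{2b} \notag \\
&= \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} + \frac{T(\eta + \eta\kappa\zeta + \beta\rho^2\zeta)\beta\sigma^2}{2bT\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} \tag{18} \\
&= \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} + \frac{(1 + \kappa\zeta + 4\beta^2\zeta)\eta\beta\sigma^2}{2b\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} \tag{19} \\
&= \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} + \frac{(1 + \kappa\zeta + 4\beta^2\zeta)\sigma^2}{8b\sqrt{T}\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} \notag \\
&\le \frac{32\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{T}\,(7 - 6\zeta)} + \frac{(1 + \zeta + 5\beta^2\zeta)\,\sigma^2}{b\sqrt{T}\,(7 - 6\zeta)}, \tag{20}
\end{align}

where $\gamma = 1 + \beta\eta - 2\rho\beta \le 3/2$, $\kappa = 1 + \rho^2\beta^2$, $\rho^2 = 1/\sqrt{T}$, and $\zeta = \frac{1}{T}\sum_{t=0}^{T-1}\xi_t \in [0, 1]$. We thus finish the proof.

Corollary A.2. Let $b$ be the mini-batch size. If $\eta = \frac{1}{4\beta\sqrt{T}}$ and $\rho = \frac{1}{T^{1/4}}$, SAM (Foret et al., 2021) satisfies

$$\min_{0 \le t \le T-1} \mathbb{E}\|\nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 \le \frac{32\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{T}} + \frac{(2 + 5\beta^2)\,\sigma^2}{b\sqrt{T}}. \tag{21}$$

Corollary A.3.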
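Let $b$ be the mini-batch size. If $\eta = \frac{\sqrt{b}}{4\beta\sqrt{T}}$ and $\rho = \frac{1}{T^{1/4}}$, algorithm A satisfies

$$\min_{0 \le t \le T-1} \mathbb{E}\|\nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 \le \frac{32\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{Tb}\,(7 - 6\zeta)} + \frac{(1 + \zeta + 5\beta^2\zeta)\,\sigma^2}{\sqrt{Tb}\,(7 - 6\zeta)}, \tag{22}$$

where $\zeta = \frac{1}{T}\sum_{t=0}^{T-1}\xi_t \in [0, 1]$.

Proof. It follows from (18) that

\begin{align}
\min_{0 \le t \le T-1} \mathbb{E}\|\hat{g}_t\|^2 &\le \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} + \frac{\eta\beta(1 + \kappa\zeta + 4\beta^2\zeta)\sigma^2}{2b\big(1 - \frac{\beta\eta}{2} - \frac{\gamma\zeta}{2}\big)} \tag{23} \\
&\le \frac{4\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{Tb}\,\big(\frac{7}{8} - \frac{3}{4}\zeta\big)} + \frac{(1 + \zeta + 5\beta^2\zeta)\,\sigma^2}{8\sqrt{Tb}\,\big(\frac{7}{8} - \frac{3}{4}\zeta\big)} \tag{24} \\
&= \frac{32\beta\,(\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T))}{\sqrt{Tb}\,(7 - 6\zeta)} + \frac{(1 + \zeta + 5\beta^2\zeta)\,\sigma^2}{\sqrt{Tb}\,(7 - 6\zeta)}. \tag{25}
\end{align}

A.2 CONVERGENCE OF FULL-BATCH GRADIENT DESCENT FOR AE-SAM

Theorem A.4. Under Assumption 3.1, with full-batch gradient descent, if $\rho < \frac{1}{2\beta}$ and $\eta < \frac{1}{\beta}$, algorithm A satisfies

$$\min_{0 \le t \le T-1} \|\nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 \le \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \beta\rho\zeta\big)}, \tag{26}$$

where $\zeta = \frac{1}{T}\sum_{t=0}^{T-1}\xi_t \in [0, 1]$.

Lemma A.5 (Lemma 7 in Andriushchenko & Flammarion (2022)). Let $\mathcal{L}(\mathcal{D}; w)$ be a $\beta$-smooth function. For any $\rho > 0$, we have

$$\nabla \mathcal{L}(\mathcal{D}; w)^\top \nabla \mathcal{L}(\mathcal{D}; w + \rho\nabla \mathcal{L}(\mathcal{D}; w)) \ge (1 - \rho\beta)\|\nabla \mathcal{L}(\mathcal{D}; w)\|^2. \tag{27}$$

Proof of Theorem A.4. Let $g_t \equiv \nabla \mathcal{L}(\mathcal{D}; w_t)$ and $h_t \equiv \nabla \mathcal{L}(\mathcal{D}; w_t + \rho\nabla \mathcal{L}(\mathcal{D}; w_t))$ be the update directions of ERM and SAM, respectively. By Taylor expansion and the $\beta$-smoothness of $\mathcal{L}(\mathcal{D}; w)$, we have

\begin{align}
\mathcal{L}(\mathcal{D}; w_{t+1}) &\le \mathcal{L}(\mathcal{D}; w_t) + g_t^\top(w_{t+1} - w_t) + \frac{\beta}{2}\|w_{t+1} - w_t\|^2 \notag \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta g_t^\top\big((1 - \xi_t) g_t + \xi_t h_t\big) + \frac{\beta\eta^2}{2}\|(1 - \xi_t) g_t + \xi_t h_t\|^2 \notag \\
&= \mathcal{L}(\mathcal{D}; w_t) - \eta(1 - \xi_t)\|g_t\|^2 - \eta\xi_t g_t^\top h_t + \frac{\beta\eta^2}{2}\Big((1 - \xi_t)\|g_t\|^2 + \xi_t\|h_t\|^2 + \underbrace{2\xi_t(1 - \xi_t) g_t^\top h_t}_{=0}\Big) \tag{28} \\
&= \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2}\Big)\|g_t\|^2 + \frac{\beta\eta^2\xi_t}{2}\|h_t\|^2 - \eta\xi_t g_t^\top h_t, \tag{29}
\end{align}

where we have used $\xi_t(1 - \xi_t) = 0$ as $\xi_t \in \{0, 1\}$, $\xi_t^2 = \xi_t$, and $(1 - \xi_t)^2 = 1 - \xi_t$ to obtain (28). As $\|h_t\|^2 = \|h_t - g_t\|^2 - \|g_t\|^2 + 2g_t^\top h_t$, it follows from (29) that

\begin{align}
\mathcal{L}(\mathcal{D}; w_{t+1}) &\le \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2}\Big)\|g_t\|^2 + \frac{\beta\eta^2\xi_t}{2}\big(\|h_t - g_t\|^2 - \|g_t\|^2 + 2g_t^\top h_t\big) - \eta\xi_t g_t^\top h_t \notag \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2}\Big)\|g_t\|^2 + \frac{\beta\eta^2\xi_t}{2}\|h_t - g_t\|^2 - \eta(1 - \beta\eta)\xi_t g_t^\top h_t \notag \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2}\Big)\|g_t\|^2 + \frac{\beta^3\eta^2\rho^2\xi_t}{2}\|g_t\|^2 - \eta(1 - \beta\eta)\xi_t g_t^\top h_t \tag{30} \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \xi_t - \frac{\beta\eta(1 - \xi_t)}{2} - \frac{\beta^3\eta\rho^2\xi_t}{2} + (1 - \beta\eta)(1 - \beta\rho)\xi_t\Big)\|g_t\|^2 \tag{31} \\
&= \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \frac{\beta\eta(1 - \xi_t)}{2} - \frac{\beta^3\eta\xi_t\rho^2}{2} - \beta\eta\xi_t - \beta\rho\xi_t + \beta^2\eta\rho\xi_t\Big)\|g_t\|^2 \notag \\
&\le \mathcal{L}(\mathcal{D}; w_t) - \eta\Big(1 - \frac{\beta\eta}{2} - \beta\rho\xi_t\Big)\|g_t\|^2, \tag{32}
\end{align}

where we have used $\|h_t - g_t\|^2 = \|\nabla \mathcal{L}(\mathcal{D}; w_t + \rho\nabla \mathcal{L}(\mathcal{D}; w_t)) - \nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 \le \beta^2\rho^2\|\nabla \mathcal{L}(\mathcal{D}; w_t)\|^2 = \beta^2\rho^2\|g_t\|^2$ to obtain (30), and Lemma A.5 to obtain (31).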
Summing over $t$ from $t = 0$ to $T - 1$ on both sides of (32) and rearranging, we have

$$\sum_{t=0}^{T-1} \eta\Big(1 - \frac{\beta\eta}{2} - \beta\rho\xi_t\Big)\|g_t\|^2 \le \mathcal{L}(\mathcal{D}; w_0) - \mathcal{L}(\mathcal{D}; w_T). \tag{33}$$

As $\rho < \frac{1}{2\beta}$ and $\eta < \frac{1}{\beta}$, it follows that $1 - \frac{\beta\eta}{2} - \beta\rho\xi_t > 0$ for all $t$. Thus, (33) implies

$$\min_{0 \le t \le T-1} \|g_t\|^2 \le \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathcal{L}(\mathcal{D}; w_T)}{\sum_{t=0}^{T-1}\eta\big(1 - \frac{\beta\eta}{2} - \xi_t\beta\rho\big)} = \frac{\mathcal{L}(\mathcal{D}; w_0) - \mathcal{L}(\mathcal{D}; w_T)}{T\eta\big(1 - \frac{\beta\eta}{2} - \beta\rho\zeta\big)}, \tag{34}$$

where $\zeta = \frac{1}{T}\sum_{t=0}^{T-1}\xi_t \in [0, 1]$, and we finish the proof.

B ADDITIONAL EXPERIMENTAL RESULTS

B.1 DISTRIBUTION OF STOCHASTIC GRADIENT NORMS

Figure 8 shows the distributions of stochastic gradient norms for ResNet-18, WRN-28-10 and PyramidNet-110 on CIFAR-10 and CIFAR-100. As can be seen, the distributions are approximately bell-shaped in all settings. Figure 9 shows the corresponding Q-Q plots. The sample quantiles lie close to the reference lines, which is consistent with an approximately normal distribution.

[Figure 8 panels. Top row (CIFAR-10): (a) ResNet-18, (b) WRN-28-10, (c) PyramidNet-110. Bottom row (CIFAR-100): (d) ResNet-18, (e) WRN-28-10, (f) PyramidNet-110. x-axis: stochastic gradient norm $\|\nabla \mathcal{L}(\mathcal{B}_t; w_t)\|^2$; one curve per selected epoch.]

Figure 8: Distributions of stochastic gradient norms on CIFAR-10 (top) and CIFAR-100 (bottom). Best viewed in color.

[Figure 9 panels. Top row (CIFAR-10): (a) ResNet-18, (b) WRN-28-10, (c) PyramidNet-110. Bottom row (CIFAR-100): (d) ResNet-18, (e) WRN-28-10, (f) PyramidNet-110. x-axis: theoretical quantiles; y-axis: sample quantiles; curves at epochs 60, 120, 180 (ResNet-18, WRN-28-10) and 80, 160, 240 (PyramidNet-110).]

Figure 9: Q-Q plots of stochastic gradient norms on CIFAR-10 (top) and CIFAR-100 (bottom). Best viewed in color.
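The Q-Q comparison described in B.1 can be sketched in a few lines of standard-library Python. This is only an illustrative sketch, not the authors' plotting code: the helper names `qq_points` and `qq_correlation` are hypothetical, and the two samples are synthetic stand-ins rather than gradient norms from the paper. The correlation between theoretical and sample quantiles is used here as one simple numeric summary of how close the points are to a straight line.

```python
import math
import random
from statistics import NormalDist, mean, pstdev


def qq_points(samples):
    """Return (theoretical, sample) quantiles for a normal Q-Q plot."""
    n = len(samples)
    mu, sd = mean(samples), pstdev(samples)
    # Standardize and sort the observations to get the sample quantiles.
    sample_q = sorted((x - mu) / sd for x in samples)
    # Standard-normal quantiles at the plotting positions (i + 0.5) / n.
    std_normal = NormalDist()
    theo_q = [std_normal.inv_cdf((i + 0.5) / n) for i in range(n)]
    return theo_q, sample_q


def qq_correlation(samples):
    """Pearson correlation between theoretical and sample quantiles."""
    theo_q, sample_q = qq_points(samples)
    mt, ms = mean(theo_q), mean(sample_q)
    cov = sum((a - mt) * (b - ms) for a, b in zip(theo_q, sample_q))
    var_t = sum((a - mt) ** 2 for a in theo_q)
    var_s = sum((b - ms) ** 2 for b in sample_q)
    return cov / math.sqrt(var_t * var_s)


rng = random.Random(0)
# Synthetic stand-ins for stochastic gradient norms (NOT data from the paper).
normal_like = [rng.gauss(20.0, 3.0) for _ in range(2000)]
skewed = [rng.expovariate(1.0) for _ in range(2000)]

r_normal = qq_correlation(normal_like)  # points close to a straight line
r_skewed = qq_correlation(skewed)       # visibly curved Q-Q plot
```

A correlation close to 1 corresponds to points lying near the reference line, as in Figure 9; a clearly skewed sample yields a lower value.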
B.2 EFFECT OF k ON LOOKSAM

In this experiment, we demonstrate that LookSAM is sensitive to the choice of $k$. Table 4 shows the testing accuracy and fraction of SAM updates when using LookSAM on noisy CIFAR-10, with $k \in \{2, 3, 4, 5\}$ and the ResNet-18 model. As can be seen, $k = 2$ yields much better performance than $k \in \{3, 4, 5\}$, particularly at higher noise levels (e.g., 80%).

Table 4: Effects of $k$ in LookSAM on CIFAR-10 with different levels of label noise using ResNet-18.

| k | accuracy (noise = 20%) | %SAM (noise = 20%) | accuracy (noise = 40%) | %SAM (noise = 40%) | accuracy (noise = 60%) | %SAM (noise = 60%) | accuracy (noise = 80%) | %SAM (noise = 80%) |
|---|---|---|---|---|---|---|---|---|
| 2 | 92.72 | 50.0 | 88.04 | 50.0 | 72.26 | 50.0 | 69.72 | 50.0 |
| 3 | 89.07 | 33.3 | 75.38 | 33.3 | 63.79 | 33.3 | 53.87 | 33.3 |
| 4 | 89.00 | 25.0 | 74.12 | 25.0 | 58.17 | 25.0 | 52.28 | 25.0 |
| 5 | 88.57 | 20.0 | 73.90 | 20.0 | 56.80 | 20.0 | 51.82 | 20.0 |

B.3 MORE RESULTS ON ROBUSTNESS TO LABEL NOISE

Figure 10 (resp. 11) shows the curves of accuracies at noise levels of 20%, 40%, 60%, and 80% with ResNet-18 (resp. ResNet-32). As can be seen, in all settings, AE-LookSAM is as robust to label noise as SAM.

[Figure 10 panels: (a) 20% (Training), (b) 20% (Testing), (c) 40% (Training), (d) 40% (Testing), (e) 60% (Training), (f) 60% (Testing), (g) 80% (Training), (h) 80% (Testing). x-axis: epoch (0 to 200); y-axis: training or testing accuracy; curves: ERM, SAM, ESAM, SS-SAM, LookSAM, AE-SAM, AE-LookSAM.]
Figure 10: Accuracies with number of epochs on CIFAR-10 with 20%, 40%, 60%, and 80% noise level using ResNet-18. Best viewed in color.

[Figure 11 panels: (a) 20% (Training), (b) 20% (Testing), (c) 40% (Training), (d) 40% (Testing), (e) 60% (Training), (f) 60% (Testing), (g) 80% (Training), (h) 80% (Testing). x-axis: epoch (0 to 200); y-axis: training or testing accuracy; curves: ERM, SAM, ESAM, SS-SAM, LookSAM, AE-SAM, AE-LookSAM.]

Figure 11: Accuracies with number of epochs on CIFAR-10 with 20%, 40%, 60%, and 80% noise level using ResNet-32. Best viewed in color.

B.4 EFFECTS OF λ1 AND λ2 ON AE-LOOKSAM

In this experiment, we study the effects of $\lambda_1$ and $\lambda_2$ on AE-LookSAM. The experiment is performed on CIFAR-10 with label noise (80% noisy labels), using the same setup as in Section 4.3.

Figure 12 shows the effects of $\lambda_1$ and $\lambda_2$ on the fraction of SAM updates. As in Section 4.4, for a fixed $\lambda_2$, increasing $\lambda_1$ always reduces the fraction of SAM updates. Figure 13 shows the effects of $\lambda_1$ and $\lambda_2$ on the testing accuracy of AE-LookSAM. As can be seen, the observations are similar to those in Section 4.4.

[Figure 12 panels: (a) ResNet-18, (b) ResNet-32.]

Figure 12: Effects of $\lambda_1$ and $\lambda_2$ on fraction of SAM updates on CIFAR-10 (with 80% noisy labels). Best viewed in color.

[Figure 13 panels: (a) ResNet-18, (b) ResNet-32. y-axis: testing accuracy.]
Figure 13: Effects of $\lambda_1$ and $\lambda_2$ on testing accuracy on CIFAR-10 (with 80% noisy labels). Note that the curves for $\lambda_2 \in \{-2, -1\}$ overlap completely with that of $\lambda_2 = 1$. Best viewed in color.

B.5 ADDITIONAL CONVERGENCE RESULTS ON CIFAR-10 AND CIFAR-100

Figure 14 shows the convergence of AE-SAM's training loss on the CIFAR-10 and CIFAR-100 datasets. As can be seen, AE-SAM converges with all the network architectures considered. Figure 15 shows the training losses w.r.t. the number of epochs for AE-SAM and SS-SAM. As can be seen, AE-SAM and SS-SAM converge at comparable speeds, which agrees with Theorem 3.3, as both of them have comparable fractions of SAM updates (Table 1).

[Figure 14 panels: (a) ResNet-18, (b) WRN-28-10, (c) PyramidNet-110. x-axis: epoch (0 to 200, or 0 to 300 for PyramidNet-110); curves: CIFAR-10, CIFAR-100.]

Figure 14: Training loss of AE-SAM with number of epochs on CIFAR-10 and CIFAR-100. Best viewed in color.

[Figure 15 panels: (a) ResNet-18, (b) WRN-28-10, (c) PyramidNet-110. x-axis: epoch (0 to 200, or 0 to 300 for PyramidNet-110); curves: AE-SAM, SS-SAM.]

Figure 15: Training losses of AE-SAM and SS-SAM with number of epochs on CIFAR-10. Note that the two curves almost completely overlap. Best viewed in color.
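The remark that comparable fractions of SAM updates lead to comparable convergence can be made concrete by evaluating the right-hand side of (7) numerically. The sketch below is an illustration under stated assumptions: the function name `theorem_3_3_bound` and all constants ($T$, $\beta$, $\sigma^2$, $b$, and the loss gap $\mathcal{L}(\mathcal{D}; w_0) - \mathbb{E}\mathcal{L}(\mathcal{D}; w_T)$) are hypothetical values chosen for illustration, not quantities reported in the paper.

```python
import math


def theorem_3_3_bound(T, zeta, beta, sigma2, b, loss_gap):
    """Right-hand side of (7), with eta = 1/(4*beta*sqrt(T)), rho = T**(-1/4).

    zeta is the fraction of SAM updates; loss_gap stands for
    L(D; w_0) - E L(D; w_T).
    """
    if not 0.0 <= zeta <= 1.0:
        raise ValueError("zeta must lie in [0, 1]")
    denom = math.sqrt(T) * (7.0 - 6.0 * zeta)
    optimization_term = 32.0 * beta * loss_gap / denom
    noise_term = (1.0 + zeta + 5.0 * beta**2 * zeta) * sigma2 / (b * denom)
    return optimization_term + noise_term


# Hypothetical constants, chosen only for illustration.
params = dict(T=10_000, beta=1.0, sigma2=1.0, b=128, loss_gap=2.0)

bound_erm = theorem_3_3_bound(zeta=0.0, **params)    # no SAM updates
bound_half = theorem_3_3_bound(zeta=0.50, **params)  # SAM on half the steps
bound_near = theorem_3_3_bound(zeta=0.51, **params)  # a similar fraction
bound_sam = theorem_3_3_bound(zeta=1.0, **params)    # SAM at every step

# Setting zeta = 1 recovers the form of Corollary A.2.
corollary_a2 = (32.0 * params["beta"] * params["loss_gap"] / math.sqrt(params["T"])
                + (2.0 + 5.0 * params["beta"] ** 2) * params["sigma2"]
                / (params["b"] * math.sqrt(params["T"])))
```

Under these assumed constants the bound grows monotonically with $\zeta$, and two policies whose fractions of SAM updates differ only slightly obtain nearly identical bounds, which matches the qualitative observation for AE-SAM and SS-SAM in Figure 15.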