# how_sharpnessaware_minimization_minimizes_sharpness__62e02b14.pdf

Published as a conference paper at ICLR 2023

HOW DOES SHARPNESS-AWARE MINIMIZATION MINIMIZE SHARPNESS?

Kaiyue Wen, Institute for Interdisciplinary Information Sciences, Tsinghua University, wenky20@mails.tsinghua.edu.cn

Tengyu Ma, Zhiyuan Li, Computer Science Department, Stanford University, {tengyuma,zhiyuanli}@stanford.edu

Sharpness-Aware Minimization (SAM) is a highly effective regularization technique for improving the generalization of deep neural networks for various settings. However, the underlying working of SAM remains elusive because of various intriguing approximations in the theoretical characterizations. SAM intends to penalize a notion of sharpness of the model but implements a computationally efficient variant; moreover, a third notion of sharpness was used for proving generalization guarantees. The subtle differences in these notions of sharpness can indeed lead to significantly different empirical results. This paper rigorously nails down the exact sharpness notion that SAM regularizes and clarifies the underlying mechanism. We also show that the two steps of approximations in the original motivation of SAM individually lead to inaccurate local conclusions, but their combination accidentally reveals the correct effect, when full-batch gradients are applied. Furthermore, we also prove that the stochastic version of SAM in fact regularizes the third notion of sharpness mentioned above, which is most likely to be the preferred notion for practical performance. The key mechanism behind this intriguing phenomenon is the alignment between the gradient and the top eigenvector of Hessian when SAM is applied.

## 1 INTRODUCTION

Modern deep nets are often overparametrized and have the capacity to fit even randomly labeled data (Zhang et al., 2016). Thus, a small training loss does not necessarily imply good generalization.
Yet, standard gradient-based training algorithms such as SGD are able to find generalizable models. Recent empirical and theoretical studies suggest that generalization is well-correlated with the sharpness of the loss landscape at the learned parameter (Keskar et al., 2016; Dinh et al., 2017; Dziugaite et al., 2017; Neyshabur et al., 2017; Jiang et al., 2019). Partly motivated by these studies, Foret et al. (2021); Wu et al. (2020); Zheng et al. (2021); Norton et al. (2021) propose to penalize the sharpness of the landscape to improve generalization. We refer to this method as Sharpness-Aware Minimization (SAM) and focus on the version of Foret et al. (2021).

Despite its empirical success, the underlying working of SAM remains elusive because of the various intriguing approximations made in its derivation and analysis. There are three different notions of sharpness involved. SAM intends to optimize the first notion, the sharpness along the worst direction, but actually implements a computationally efficient notion, the sharpness along the direction of the gradient. In the analysis of generalization, however, a third notion of sharpness is used to prove generalization guarantees, and the first notion is an upper bound of it. The subtle differences between the three notions can lead to very different biases (see Figure 1 for a demonstration).

More concretely, let $L$ be the training loss, $x$ be the parameter and $\rho$ be the perturbation radius, a hyperparameter requiring tuning. The first notion corresponds to the following optimization problem (1), where we call $R^{\mathrm{Max}}_\rho(x) = L^{\mathrm{Max}}_\rho(x) - L(x)$ the worst-direction sharpness at $x$. SAM intends to minimize the original training loss plus the worst-direction sharpness at $x$.

$$\min_x L^{\mathrm{Max}}_\rho(x), \quad \text{where } L^{\mathrm{Max}}_\rho(x) = \max_{\|v\|_2 \le 1} L(x + \rho v). \tag{1}$$

However, even evaluating $L^{\mathrm{Max}}_\rho(x)$ is computationally expensive, not to mention optimizing it. Thus Foret et al. (2021); Zheng et al.
(2021) have introduced a second notion of sharpness, which approximates the worst-case direction in (1) by the direction of the gradient, as defined below in (2). We call $R^{\mathrm{Asc}}_\rho(x) = L^{\mathrm{Asc}}_\rho(x) - L(x)$ the ascent-direction sharpness at $x$.

$$\min_x L^{\mathrm{Asc}}_\rho(x), \quad \text{where } L^{\mathrm{Asc}}_\rho(x) = L\left(x + \rho\,\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\right). \tag{2}$$

| Type of sharpness-aware loss | Notation | Definition | Biases (among minimizers) |
| --- | --- | --- | --- |
| Worst-direction | $L^{\mathrm{Max}}_\rho$ | $\max_{\lVert v\rVert_2 \le 1} L(x + \rho v)$ | $\min_x \lambda_1(\nabla^2 L(x))$ (Thm G.3) |
| Ascent-direction | $L^{\mathrm{Asc}}_\rho$ | $L\left(x + \rho\,\frac{\nabla L(x)}{\lVert\nabla L(x)\rVert_2}\right)$ | $\min_x \lambda_{\min}(\nabla^2 L(x))$ (Thm G.4) |
| Average-direction | $L^{\mathrm{Avg}}_\rho$ | $\mathbb{E}_{g\sim \mathcal{N}(0,I)}\, L\left(x + \rho\,\frac{g}{\lVert g\rVert_2}\right)$ | $\min_x \mathrm{Tr}(\nabla^2 L(x))$ (Thm G.5) |

Table 1: Definitions and biases of different notions of sharpness-aware loss. The corresponding sharpness is defined as the difference between the sharpness-aware loss and the original loss. Here $\lambda_1$ denotes the largest eigenvalue and $\lambda_{\min}$ denotes the smallest non-zero eigenvalue.

For further acceleration, Foret et al. (2021); Zheng et al. (2021) omit the gradient through the other occurrences of $x$ and approximate the gradient of the ascent-direction sharpness-aware loss by the gradient taken after a one-step ascent, i.e., $\nabla L^{\mathrm{Asc}}_\rho(x) \approx \nabla L\left(x + \rho\,\nabla L(x)/\|\nabla L(x)\|_2\right)$. This yields the update rule of SAM, where $\eta$ is the learning rate.

Sharpness-Aware Minimization (SAM):

$$x(t+1) = x(t) - \eta\,\nabla L\left(x(t) + \rho\,\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\right). \tag{3}$$

Intriguingly, the generalization bound of SAM upper-bounds the generalization error by the third notion of sharpness, called average-direction sharpness, $R^{\mathrm{Avg}}_\rho(x)$, defined formally below.

$$R^{\mathrm{Avg}}_\rho(x) = L^{\mathrm{Avg}}_\rho(x) - L(x), \quad \text{where } L^{\mathrm{Avg}}_\rho(x) = \mathbb{E}_{g\sim \mathcal{N}(0,I)}\, L\left(x + \rho\, g/\|g\|_2\right). \tag{4}$$

The worst-case sharpness is an upper bound of the average-case sharpness, and thus it gives a looser bound on the generalization error. In other words, the generalization theory in Foret et al. (2021); Wu et al. (2020) in fact motivates us to directly minimize the average-case sharpness (as opposed to the worst-case sharpness that SAM intends to optimize).
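To make Equation 3 concrete, here is a minimal pure-Python sketch of one full-batch SAM step. It assumes only a user-supplied gradient oracle; the names `sam_step`, `grad_fn` and `quad_grad` are our own illustrative choices, not from the paper. The inner step moves a distance $\rho$ along the normalized gradient, which is the perturbation of Equation 2. The outer step then applies the gradient taken at that perturbed point to the original iterate, which is the one-step-ascent approximation described above.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(x: Vector, grad_fn: Callable[[Vector], Vector],
             eta: float, rho: float) -> Vector:
    """One full-batch SAM update (Equation 3) for a user-supplied gradient oracle."""
    g = grad_fn(x)
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    if g_norm == 0.0:
        # Equation 3 is undefined at stationary points; staying put is our own convention.
        return list(x)
    # Ascent step: move a distance rho along the normalized gradient.
    x_adv = [xi + rho * gi / g_norm for xi, gi in zip(x, g)]
    # Descent step: gradient at the perturbed point, applied to the original x.
    g_adv = grad_fn(x_adv)
    return [xi - eta * gi for xi, gi in zip(x, g_adv)]


def quad_grad(x: Vector) -> Vector:
    """Gradient of the illustrative loss L(x) = x_0^2 + x_1^2 / 2 (Hessian diag(2, 1))."""
    return [2.0 * x[0], x[1]]


if __name__ == "__main__":
    # Starting at (1, 0), SAM moves x_0 to about 0.79, versus 0.8 for a plain gradient step.
    print(sam_step([1.0, 0.0], quad_grad, eta=0.1, rho=0.05))
```

The extra shrinkage relative to gradient descent comes entirely from evaluating the gradient at $x + \rho\,\nabla L(x)/\|\nabla L(x)\|_2$.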
In this paper, we analyze the biases introduced by penalizing these various notions of sharpness, as well as the bias of SAM (Equation 3). Our analysis of SAM is performed for small perturbation radius $\rho$ and learning rate $\eta$, in the setting where the minimizers of the loss form a manifold, following the setup of Fehrman et al. (2020); Li et al. (2021). In particular, we make the following theoretical contributions.

1. We prove that full-batch SAM indeed minimizes worst-direction sharpness (Theorem 4.5).
2. Surprisingly, when the batch size is 1, SAM minimizes average-direction sharpness (Theorem 5.4).
3. We provide a characterization (Theorems 4.2 and 5.3) of what a few sharpness regularizers bias towards among the minimizers (including all three notions of sharpness in Table 1) when the perturbation radius $\rho$ goes to zero. Surprisingly, both heuristic approximations made for SAM lead to inaccurate conclusions: (1) minimizing worst-direction sharpness and minimizing ascent-direction sharpness induce different biases among minimizers, and (2) SAM doesn't minimize ascent-direction sharpness.

The key mechanism behind this bias of SAM is the alignment between the gradient and the top eigenspace of the Hessian of the original loss in the later phase of training: the angle between them decreases gradually to the level of $O(\rho)$. It turns out that the worst-direction sharpness starts to decrease once such alignment is established (see Section 4.3). Interestingly, such an alignment is not implied by the minimization problem (2); rather, it is an implicit property of the specific update rule of SAM. Moreover, this alignment property holds for SAM with full batch and for SAM with batch size one, but does not necessarily hold in the mini-batch case.

## 2 RELATED WORKS

Sharpness and Generalization. The study of the connection between sharpness and generalization can be traced back to Hochreiter et al. (1997). Keskar et al. (2016) observe a positive correlation between the batch size, the generalization error, and the sharpness of the loss landscape when changing the batch size. Jastrzebski et al. (2017) extend this by finding a correlation between the sharpness and the ratio of learning rate to batch size. Dinh et al. (2017) show that, by reparametrization, one can easily construct networks with good generalization but arbitrarily large sharpness. Dziugaite et al. (2017); Neyshabur et al. (2017); Wei et al. (2019a;b) give theoretical guarantees on the generalization error using sharpness-related measures. Jiang et al. (2019) perform a large-scale empirical study on various generalization measures and show that sharpness-based measures have the highest correlation with generalization.

Background on Sharpness-Aware Minimization. Foret et al. (2021); Zheng et al. (2021) concurrently propose to improve generalization by minimizing the loss at the parameter obtained by perturbing the current parameter towards the worst direction. Wu et al. (2020) propose an almost identical method for a different purpose: robust generalization in adversarial training. Kwon et al. (2021) propose a different metric for SAM to fix the rescaling problem pointed out by Dinh et al. (2017). Liu et al. (2022) propose a more computationally efficient version of SAM. Zhuang et al. (2022) propose a variant of SAM that improves generalization by simultaneously optimizing the surrogate gap and the sharpness-aware loss. Zhao et al. (2022) propose to improve generalization by penalizing the gradient norm; their algorithm can be viewed as a generalization of SAM. Andriushchenko et al. (2022) study a variant of SAM where the step size of the ascent step is $\rho$ instead of $\rho/\|\nabla L(x)\|_2$. They show that, for a simple model, this variant of SAM has a stronger regularization effect with batch size 1 than in the full-batch case, and argue that this might explain why SAM generalizes better with small batch sizes. More related works are discussed in Appendix A.

Figure 1: Visualization of the different biases of different sharpness notions on a 4D toy example. Let $F_1, F_2: \mathbb{R}^2 \to \mathbb{R}^+$ be two positive functions satisfying $F_1 > F_2$ on $[0,1]^2$. For $x \in \mathbb{R}^4$, consider the loss $L(x) = F_1(x_1,x_2)\,x_3^2 + F_2(x_1,x_2)\,x_4^2$. The loss $L$ has a zero-loss manifold $\{x_3 = x_4 = 0\}$ of codimension $M = 2$, and the two non-zero eigenvalues of $\nabla^2 L$ at any point $x$ on the manifold are $\lambda_1(\nabla^2 L(x)) = F_1(x_1,x_2)$ and $\lambda_2(\nabla^2 L(x)) = F_2(x_1,x_2)$. We test three optimization algorithms on this 4D toy model with small learning rates. They all quickly converge to zero loss, i.e., $x_3(t), x_4(t) \to 0$. After that, $x_1(t), x_2(t)$ still change slowly, i.e., the iterates move along the zero-loss manifold. We visualize the loss restricted to $(x_3, x_4)$ as a 3D surface at various points $(x_1, x_2)$, where $x_1 = x_1(t)$, $x_2 = x_2(t)$ follow the trajectories of the three algorithms. In other words, each 3D surface visualizes the function $g(x_3, x_4) = L(x_1(t), x_2(t), x_3, x_4)$. As our theory predicts, (1) full-batch SAM (Equation 3) finds the minimizer with the smallest top eigenvalue, $F_1(x_1,x_2)$; (2) GD on the ascent-direction loss $L^{\mathrm{Asc}}_\rho$ (Equation 2) finds the minimizer with the smallest bottom eigenvalue, $F_2(x_1,x_2)$; (3) 1-SAM (Equation 13), with $L_0(x) = F_1(x_1,x_2)\,x_3^2$ and $L_1(x) = F_2(x_1,x_2)\,x_4^2$, finds the minimizer with the smallest trace of the Hessian, $F_1(x_1,x_2) + F_2(x_1,x_2)$. See more details in Appendix B.

## 3 NOTATIONS AND ASSUMPTIONS

For any natural number $k$, we say a function is $C^k$ if it is $k$ times continuously differentiable, and is $\bar{C}^k$ if its $k$th-order derivatives are locally Lipschitz. We say a subset of $\mathbb{R}^D$ is compact if each of its open covers has a finite subcover. It is well known that a subset of $\mathbb{R}^D$ is compact if and only if it is closed and bounded.
For any positive definite symmetric matrix $A \in \mathbb{R}^{D\times D}$, define $\{\lambda_i(A), v_i(A)\}_{i\in[D]}$ as all its eigenvalues and eigenvectors satisfying $\lambda_1(A) \ge \lambda_2(A) \ge \dots \ge \lambda_D(A)$ and $\|v_i(A)\|_2 = 1$. For any mapping $F$, we define $\partial F(x)$ as the Jacobian, where $[\partial F(x)]_{ij} = \partial_j F_i(x)$. Thus the directional derivative of $F$ along the vector $u$ at $x$ can be written as $\partial F(x)\,u$. We further define the second-order directional derivative of $F$ along the vectors $u$ and $v$ at $x$ as $\partial^2 F(x)[u, v] \triangleq \partial(\partial F\,u)(x)\,v$, that is, the directional derivative of $\partial F\,u$ along the vector $v$ at $x$.

Given a $C^1$ submanifold (Definition C.1) $\Gamma$ of $\mathbb{R}^D$ and a point $x \in \Gamma$, define $P_{x,\Gamma}$ as the projection operator onto the normal space of $\Gamma$ at $x$, and $P^\perp_{x,\Gamma} = I_D - P_{x,\Gamma}$.

We fix our initialization as $x_{\mathrm{init}}$ and our loss function as $L: \mathbb{R}^D \to \mathbb{R}$. Given the loss function, its gradient flow is denoted by the mapping $\phi: \mathbb{R}^D \times [0, \infty) \to \mathbb{R}^D$. Here, $\phi(x, \tau)$ denotes the iterate at time $\tau$ of a gradient flow starting at $x$, and is defined as the unique solution of $\phi(x,\tau) = x - \int_0^\tau \nabla L(\phi(x,t))\,dt$, $\forall x \in \mathbb{R}^D$. We further define the limiting map $\Phi$ as $\Phi(x) = \lim_{\tau\to\infty}\phi(x,\tau)$, that is, $\Phi(x)$ denotes the point to which the gradient flow starting from $x$ converges. When $L(x)$ is small, $\Phi(x)$ and $x$ are close. Hence in our analysis, we regularly use $\Phi(x(t))$ as a surrogate to analyze the dynamics of $x(t)$. Lemma 3.1 is an important property of $\Phi$ from Li et al. (2021) (Lemma C.2), which is used repeatedly in our analysis. The proof is shown in Appendix F.

Lemma 3.1. For any $x$ at which $\Phi$ is defined and differentiable, we have that $\partial\Phi(x)\,\nabla L(x) = 0$.

Recent empirical studies have shown that there are essentially no barriers in the loss landscape between different minimizers, that is, the set of minimizers is path-connected (Draxler et al., 2018; Garipov et al., 2018). Motivated by this empirical discovery, we make the assumption below following Fehrman et al. (2020); Li et al. (2021); Arora et al. (2022), which is theoretically justified by Cooper (2018) under a generic setting.

Assumption 3.2.
Assume the loss $L: \mathbb{R}^D \to \mathbb{R}$ is $C^4$, and there exists a $C^2$ submanifold $\Gamma$ of $\mathbb{R}^D$ that is $(D-M)$-dimensional for some integer $1 \le M \le D$, where for all $x \in \Gamma$, $x$ is a local minimizer of $L$ and $\mathrm{rank}(\nabla^2 L(x)) = M$.

Though our analysis for the full-batch setting is performed under the general and abstract setting of Assumption 3.2, our analysis for the stochastic setting uses a more concrete one, Setting 5.1, in which we can prove that Assumption 3.2 holds (see Theorem 5.2).

Definition 3.3 (Attraction Set). Let $U$ be the attraction set of $\Gamma$ under gradient flow, that is, a neighborhood of $\Gamma$ containing all points from which gradient flow w.r.t. the loss $L$ converges to some point in $\Gamma$, or mathematically, $U \triangleq \{x \in \mathbb{R}^D \mid \Phi(x) \text{ exists and } \Phi(x) \in \Gamma\}$.

It can be shown that, for a manifold of minimizers, the rank of the Hessian plus the dimension of the manifold is at most the ambient dimension $D$; thus our assumption about the Hessian rank essentially says that the rank is maximal. Assumption 3.2 implies that $U$ is open and $\Phi$ is $\bar{C}^2$ on $U$ (Arora et al., 2022, Lemma B.15).

## 4 EXPLICIT AND IMPLICIT BIAS IN THE FULL-BATCH SETTING

In this section, we present our main results in the full-batch setting. Section 4.1 characterizes the explicit biases of worst-direction, ascent-direction, and average-direction sharpness. In particular, we show that ascent-direction sharpness and worst-direction sharpness have different explicit biases. However, it turns out that the explicit bias of ascent-direction sharpness is not the effective bias of SAM (which approximately optimizes the ascent-direction sharpness), because the particular implementation of SAM imposes additional, different biases. These biases are the main focus of Section 4.2.
We provide our main theorem in the full-batch setting: SAM implicitly minimizes the worst-direction sharpness. We prove it by characterizing the limiting dynamics of SAM, as the perturbation radius $\rho$ and the learning rate $\eta$ go to 0, as a Riemannian gradient flow with respect to the top eigenvalue of the Hessian of the loss on the manifold of local minimizers. In Section 4.3, we sketch the proof of this implicit bias and identify the key property behind it, which we call the implicit alignment between the gradient and the top eigenvector of the Hessian.

### 4.1 WORST- AND ASCENT-DIRECTION SHARPNESS HAVE DIFFERENT EXPLICIT BIASES

In this subsection, we show that the explicit biases of the three notions of sharpness are all different under Assumption 3.2. We first recap the heuristic derivation of the ascent-direction sharpness $R^{\mathrm{Asc}}_\rho$. The intuition for approximating $R^{\mathrm{Max}}_\rho$ by $R^{\mathrm{Asc}}_\rho$ comes from the following Taylor expansions (Foret et al., 2021; Wu et al., 2020). Consider any compact set. For sufficiently small $\rho$, the following holds uniformly for all $x$ in the compact set:

$$R^{\mathrm{Max}}_\rho(x) = \sup_{\|v\|_2\le 1} L(x + \rho v) - L(x) = \sup_{\|v\|_2\le 1}\left(\rho\, v^\top\nabla L(x) + \frac{\rho^2}{2}\, v^\top\nabla^2 L(x)\, v\right) + O(\rho^3), \tag{5}$$

$$R^{\mathrm{Asc}}_\rho(x) = L\left(x + \rho\,\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\right) - L(x) = \rho\,\|\nabla L(x)\|_2 + \frac{\rho^2}{2}\,\frac{\nabla L(x)^\top\nabla^2 L(x)\,\nabla L(x)}{\|\nabla L(x)\|_2^2} + O(\rho^3). \tag{6}$$

Here, we are mainly concerned with the preference among the local or global minima. Since $\sup_{\|v\|_2\le1} v^\top\nabla L(x) = \|\nabla L(x)\|_2$ when $\|\nabla L(x)\|_2 > 0$, the leading terms in Equations 5 and 6 are both the first-order term, $\rho\,\|\nabla L(x)\|_2$, and are the same. However, it is erroneous to think that the first-order term decides the explicit bias. The first-order term $\|\nabla L(x)\|_2$ vanishes at the local minimizers of the loss $L$, and there the second-order term becomes the leading term. Any global minimizer $x$ of the original loss $L$ is an $O(\rho^2)$-approximate minimizer of the sharpness-aware loss because $\nabla L(x) = 0$. Therefore, the sharpness-aware loss must be minimized to within an error of order $\rho^2$ before the second-order terms in Equation 5 and/or Equation 6 are guaranteed to be non-trivially small.
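As a sanity check on Equations 5 and 6, consider a quadratic loss $L(x) = \frac{1}{2}x^\top A x$ with $A$ symmetric positive semidefinite. This is an illustrative choice of ours, not one of the paper's settings. For it, both expansions hold with no $O(\rho^3)$ remainder, and at a minimizer $x^*$ (where $Ax^* = 0$) only the second-order term survives:

$$
\begin{aligned}
R^{\mathrm{Max}}_\rho(x) &= \sup_{\|v\|_2\le 1}\Big(\rho\, v^\top A x + \tfrac{\rho^2}{2}\, v^\top A v\Big),\\
R^{\mathrm{Asc}}_\rho(x) &= \rho\,\|Ax\|_2 + \frac{\rho^2}{2}\,\frac{(Ax)^\top A\,(Ax)}{\|Ax\|_2^2} \qquad (Ax \neq 0),\\
R^{\mathrm{Max}}_\rho(x^*) &= \frac{\rho^2}{2}\,\lambda_1(A) \qquad \text{whenever } Ax^* = 0.
\end{aligned}
$$

$Ax$ always lies in the column span of $A$. The Rayleigh quotient in the second line is therefore squeezed between the smallest non-zero eigenvalue of $A$ and $\lambda_1(A)$. Approaching $x^*$ along different eigenvectors lets the second-order coefficient of $R^{\mathrm{Asc}}_\rho$ take any value in that range, while $R^{\mathrm{Max}}_\rho$ is pinned to $\lambda_1(A)/2$.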
Our main result in this subsection (Theorem 4.2) gives an explicit characterization of this phenomenon. The corresponding explicit bias for each type of sharpness is given below in Definition 4.1. As we will see later, these biases can be derived from a general notion of limiting regularizer (Definition 4.3).

Definition 4.1. For $x \in \mathbb{R}^D$, we define $S^{\mathrm{Max}}(x) = \lambda_1(\nabla^2 L(x))/2$, $S^{\mathrm{Asc}}(x) = \lambda_M(\nabla^2 L(x))/2$ and $S^{\mathrm{Avg}}(x) = \mathrm{Tr}(\nabla^2 L(x))/(2D)$.

Theorem 4.2. Under Assumption 3.2, let $U'$ be any bounded open set such that its closure satisfies $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma \subseteq U' \cap \Gamma$. For any $\mathrm{type} \in \{\mathrm{Max}, \mathrm{Asc}, \mathrm{Avg}\}$ and any optimality gap $\Delta > 0$, there is a function $\epsilon: \mathbb{R}^+ \to \mathbb{R}^+$ with $\lim_{\rho\to0}\epsilon(\rho) = 0$ with the following property. For all sufficiently small $\rho > 0$ and all $u \in U'$ satisfying $L(u) + R^{\mathrm{type}}_\rho(u) - \inf_{x\in U'}\left(L(x) + R^{\mathrm{type}}_\rho(x)\right) \le \Delta\rho^2$,¹ it holds that $L(u) - \inf_{x\in U'} L(x) \le (\Delta + \epsilon(\rho))\,\rho^2$ and that $S^{\mathrm{type}}(u) - \inf_{x\in U'\cap\Gamma} S^{\mathrm{type}}(x) \in [-\epsilon(\rho), \Delta + \epsilon(\rho)]$.

¹We note that $R^{\mathrm{Asc}}_\rho(x)$ is undefined when $\|\nabla L(x)\|_2 = 0$. In such cases, we set $R^{\mathrm{Asc}}_\rho(x) = \infty$.

Theorem 4.2 suggests a sharp phase transition in the properties of the solution of $\min_x L(x) + R_\rho(x)$ as the optimization error drops from $\omega(\rho^2)$ to $O(\rho^2)$. When the optimization error is $\omega(\rho^2)$, no regularization effect happens and any minimizer satisfies the requirement. When the error becomes $O(\rho^2)$, there is a non-trivial restriction on the coefficients of the second-order term.

Next we give a heuristic derivation of $S^{\mathrm{type}}$ as defined above. For worst- and average-direction sharpness, the calculations are fairly straightforward and well known in the literature (Keskar et al., 2016; Kaur et al., 2022; Zhuang et al., 2022; Orvieto et al., 2022), and we sketch them here. In the limit of perturbation radius $\rho \to 0$, we know that the minimizer of the sharpness-aware loss also converges to $\Gamma$, the manifold of minimizers of the original loss $L$.
Thus, to decide which $x \in \Gamma$ the minimizers converge to as $\rho \to 0$, it suffices to take the Taylor expansion of $L^{\mathrm{Max}}_\rho$ or $L^{\mathrm{Avg}}_\rho$ at each $x \in \Gamma$ and compare the second-order coefficients. For example, $R^{\mathrm{Avg}}_\rho(x) = \frac{\rho^2}{2D}\,\mathrm{Tr}(\nabla^2 L(x)) + O(\rho^3)$, and by Equation 5, $R^{\mathrm{Max}}_\rho(x) = \frac{\rho^2}{2}\,\lambda_1(\nabla^2 L(x)) + O(\rho^3)$. The analysis for ascent-direction sharpness is trickier, because $R^{\mathrm{Asc}}_\rho(x) = \infty$ for any $x \in \Gamma$, so it is not continuous around such $x$. We therefore have to aggregate information from a neighborhood to capture the explicit bias of $R_\rho$ around the manifold $\Gamma$. This motivates the following definition of the limiting regularizer, which lets us compare the regularization strength of $R_\rho$ around each point on the manifold $\Gamma$ as $\rho \to 0$.

Definition 4.3 (Limiting Regularizer). We define the limiting regularizer of $\{R_\rho\}$ as the function² $S: \Gamma \to \mathbb{R}$,

$$S(x) = \lim_{\rho\to0}\,\lim_{r\to0}\,\inf_{\|x'-x\|_2\le r} R_\rho(x')/\rho^2.$$

To minimize $R^{\mathrm{Asc}}_\rho$ around $x$, we can pick $x' \to x$ such that $\|\nabla L(x')\|_2 \to 0$ while remaining strictly non-zero. By Equation 6, we have

$$R^{\mathrm{Asc}}_\rho(x') \approx \frac{\rho^2}{2}\,\frac{\nabla L(x')^\top\nabla^2 L(x)\,\nabla L(x')}{\|\nabla L(x')\|_2^2}.$$

The crucial step of the proof is the following consequence of Assumption 3.2: $\nabla L(x')/\|\nabla L(x')\|_2$ must almost lie in the column span of $\nabla^2 L(x)$. This implies that $\inf_{x'}\nabla L(x')^\top\nabla^2 L(x)\,\nabla L(x')/\|\nabla L(x')\|_2^2 \to \lambda_M(\nabla^2 L(x))$ as $\rho\to0$, where $\mathrm{rank}(\nabla^2 L(x)) = M$ by Assumption 3.2. This alignment between the gradient and the column space of the Hessian can be checked directly for any non-negative quadratic function. The maximal Hessian rank condition in Assumption 3.2 ensures that the property extends to general losses. We defer the proof of Theorem 4.2 to Appendix G.1, where we develop a sufficient condition under which the limiting regularizer characterizes the explicit bias of $R_\rho$ as $\rho \to 0$.
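The pure-Python sketch below checks Definitions 4.1 and 4.3 numerically. It uses an illustrative degenerate quadratic of our own choosing, $L(x) = \frac{1}{2}(3x_1^2 + x_2^2)$ on $\mathbb{R}^3$, so that $D = 3$, $M = 2$ and $\Gamma$ is the $x_3$-axis. The sketch makes three hedged choices:

- Worst-direction sharpness is approximated by projected gradient ascent over unit vectors. Restricting to the sphere loses nothing here because this $L$ is convex.
- Ascent-direction sharpness is evaluated at points approaching $\Gamma$ along each eigenvector.
- Average-direction sharpness is estimated by Monte Carlo.

Divided by $\rho^2$, the three quantities should approach $S^{\mathrm{Max}} = \lambda_1/2 = 3/2$, $S^{\mathrm{Asc}} = \lambda_M/2 = 1/2$ (the infimum over approach directions) and $S^{\mathrm{Avg}} = \mathrm{Tr}/(2D) = 2/3$.

```python
import math
import random
from typing import List

Vector = List[float]

# Illustrative degenerate quadratic L(x) = 0.5 * sum_i a_i * x_i^2 with a = (3, 1, 0):
# D = 3, Hessian rank M = 2, and the minimizer manifold Gamma is the x_3-axis.
A_DIAG = [3.0, 1.0, 0.0]


def loss(x: Vector) -> float:
    return 0.5 * sum(a * xi * xi for a, xi in zip(A_DIAG, x))


def grad(x: Vector) -> Vector:
    return [a * xi for a, xi in zip(A_DIAG, x)]


def norm(v: Vector) -> float:
    return math.sqrt(sum(vi * vi for vi in v))


def r_max(x: Vector, rho: float, steps: int = 500, lr: float = 10.0) -> float:
    """Worst-direction sharpness via projected gradient ascent over unit vectors v.

    A local heuristic in general; for this convex quadratic at a point of Gamma it is
    a power iteration and converges to the top eigenvector.
    """
    v = [1.0 / math.sqrt(len(x))] * len(x)
    for _ in range(steps):
        g = grad([xi + rho * vi for xi, vi in zip(x, v)])
        v = [vi + lr * rho * gi for vi, gi in zip(v, g)]
        n = norm(v)
        v = [vi / n for vi in v]
    return loss([xi + rho * vi for xi, vi in zip(x, v)]) - loss(x)


def r_asc(x: Vector, rho: float) -> float:
    """Ascent-direction sharpness; +inf where the gradient vanishes (footnote 1)."""
    g = grad(x)
    gn = norm(g)
    if gn == 0.0:
        return math.inf
    return loss([xi + rho * gi / gn for xi, gi in zip(x, g)]) - loss(x)


def r_avg(x: Vector, rho: float, n_samples: int, rng: random.Random) -> float:
    """Monte Carlo estimate of average-direction sharpness (Equation 4 minus L)."""
    total = 0.0
    for _ in range(n_samples):
        g = [rng.gauss(0.0, 1.0) for _ in x]
        gn = norm(g)
        total += loss([xi + rho * gi / gn for xi, gi in zip(x, g)]) - loss(x)
    return total / n_samples


if __name__ == "__main__":
    rho = 0.1
    x_star = [0.0, 0.0, 0.7]  # a point on Gamma
    print("R_max / rho^2:", r_max(x_star, rho) / rho**2)  # close to 1.5
    print("R_asc at x*:", r_asc(x_star, rho))             # inf
    for eps in (1e-2, 1e-4, 1e-8):
        near_e2 = [0.0, eps, 0.7]  # approach Gamma along the bottom non-zero eigenvector
        near_e1 = [eps, 0.0, 0.7]  # approach Gamma along the top eigenvector
        print(eps, r_asc(near_e2, rho) / rho**2, r_asc(near_e1, rho) / rho**2)
    print("R_avg / rho^2:", r_avg(x_star, rho, 20000, random.Random(0)) / rho**2)
```

The ascent-direction values behave like $\epsilon/\rho + 1/2$ and $3\epsilon/\rho + 3/2$ along the two eigenvectors. Their infimum, $1/2$, is strictly below the worst-direction value $3/2$, which is the gap Theorem 4.2 formalizes.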
### 4.2 SAM PROVABLY DECREASES WORST-DIRECTION SHARPNESS

Ascent-direction sharpness has a different explicit bias from worst-direction sharpness. Nevertheless, in this subsection we show that, surprisingly, SAM (Equation 3), a heuristic method designed to minimize ascent-direction sharpness, provably decreases worst-direction sharpness. The main result here is an exact characterization of the trajectory of SAM (Equation 3) via the following ordinary differential equation (ODE) (Equation 7). It holds when the learning rate $\eta$ and perturbation radius $\rho$ are small and the initialization $x(0) = x_{\mathrm{init}}$ is in $U$, the attraction set of the manifold $\Gamma$.

$$X(\tau) = X(0) - \frac{1}{2}\int_{s=0}^{\tau} P^\perp_{X(s),\Gamma}\,\nabla\lambda_1(\nabla^2 L(X(s)))\,ds, \qquad X(0) = \Phi(x_{\mathrm{init}}). \tag{7}$$

We assume the ODE (Equation 7) has a solution up to time $T_3$, that is, Equation 7 holds for all $\tau \le T_3$. We call the solution of Equation 7 the limiting flow of SAM. It is exactly the Riemannian gradient flow on the manifold $\Gamma$ with respect to the loss $\lambda_1(\nabla^2 L(\cdot))$. In other words, the ODE (Equation 7) is essentially projected gradient descent with loss $\lambda_1(\nabla^2 L(\cdot))$ on the constraint set $\Gamma$ with an infinitesimal learning rate. Note that $\lambda_1(\nabla^2 L(x))$ may not be differentiable at $x$ if $\lambda_1(\nabla^2 L(x)) = \lambda_2(\nabla^2 L(x))$. To ensure that Equation 7 is well-defined, we therefore assume a positive eigengap for $L$ on $\Gamma$.³

Assumption 4.4. For all $x \in \Gamma$, there exists a positive eigengap, i.e., $\lambda_1(\nabla^2 L(x)) > \lambda_2(\nabla^2 L(x))$.

Theorem 4.5 is the main result of this section and is a direct combination of Theorems I.1 and I.3. The proof is deferred to Appendix I.3.

²Here we implicitly assume that the zeroth- and first-order terms vanish, which holds for all three sharpness notions.

³In fact, we only need to assume a positive eigengap along the solution of the ODE. If $\Gamma$ doesn't satisfy Assumption 4.4, we can simply perform the same analysis on its submanifold $\{x \in \Gamma \mid \text{the eigengap is positive at } x\}$.

Theorem 4.5 (Main). Let $\{x(t)\}$ be the iterates of full-batch SAM (Equation 3) with $x(0) = x_{\mathrm{init}} \in U$.
Under Assumptions 3.2 and 4.4, for all $\eta, \rho$ such that $\eta\ln(1/\rho)$ and $\rho/\eta$ are sufficiently small, the dynamics of SAM can be characterized in the following two phases.

Phase I (Theorem I.1): Full-batch SAM (Equation 3) follows gradient flow with respect to $L$ until entering an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$ in $O(\ln(1/\rho)/\eta)$ steps.

Phase II (Theorem I.3): Under a mild non-degeneracy assumption (Assumption I.2) on the initial point of Phase II, full-batch SAM (Equation 3) tracks the solution $X$ of Equation 7, the Riemannian gradient flow with respect to the loss $\lambda_1(\nabla^2 L(\cdot))$, in an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$. Quantitatively, the approximation error between the iterates $x$ and the corresponding limiting flow $X$ is $O(\eta\ln(1/\rho))$, that is,

$$\left\|x\left(\lceil T_3/(\eta\rho^2)\rceil\right) - X(T_3)\right\|_2 = O(\eta\ln(1/\rho)).$$

Moreover, the angle between $\nabla L\big(x(\lceil T_3/(\eta\rho^2)\rceil)\big)$ and the top eigenspace of $\nabla^2 L\big(x(\lceil T_3/(\eta\rho^2)\rceil)\big)$ is $O(\rho)$.

Theorem 4.5 shows that SAM decreases the largest eigenvalue of the Hessian of the loss locally around the manifold of local minimizers. Phase I uses standard approximation analysis as in Hairer et al. (2008). In Phase II, since $T_3$ is arbitrary, the approximation and alignment properties hold simultaneously for all $X(t)$ along the trajectory, provided that $\eta\ln(1/\rho)$ and $\rho/\eta$ are sufficiently small. The subtlety is that the threshold for "sufficiently small" on $\eta\ln(1/\rho)$ and $\rho/\eta$ depends on $T_3$: it decreases as $T_3 \to 0$ or $T_3 \to \infty$. We defer the proof of Theorem 4.5 to Appendix I.

As a corollary of Theorem 4.5, we can also show that the largest eigenvalue along the limiting flow closely tracks the worst-direction sharpness.

Corollary 4.6. In the setting of Theorem 4.5, the difference between the worst-direction sharpness of the iterates and the corresponding scaled largest eigenvalues along the limiting flow is at most $O(\eta\rho^2\ln(1/\rho))$. That is,

$$\left|R^{\mathrm{Max}}_\rho\big(x(\lceil T_3/(\eta\rho^2)\rceil)\big) - \rho^2\,\lambda_1(\nabla^2 L(X(T_3)))/2\right| = O(\eta\rho^2\ln(1/\rho)). \tag{8}$$

Since $\eta\ln(1/\rho)$ is assumed to be sufficiently small, the error $O(\eta\ln(1/\rho)\,\rho^2)$ is only $o(\rho^2)$. Penalizing the top eigenvalue on the manifold therefore does lead to a non-trivial reduction of worst-direction sharpness, in the sense of Section 4.1. Hence full-batch SAM (Equation 3) provably minimizes worst-direction sharpness around the manifold, provided we additionally assume that the limiting flow converges to a minimizer of the top eigenvalue of the Hessian. This is stated in the following Corollary 4.7.

Corollary 4.7. Under Assumptions 3.2 and 4.4, define $U'$ as in Theorem 4.2 and suppose that $X(\infty) = \lim_{t\to\infty} X(t)$ exists and is a minimizer of $\lambda_1(\nabla^2 L(x))$ in $U'\cap\Gamma$. Then for all $\epsilon > 0$, there exists $T_\epsilon > 0$ such that, for all $\rho, \eta$ with $\eta\ln(1/\rho)$ and $\rho/\eta$ sufficiently small,

$$L^{\mathrm{Max}}_\rho\big(x(\lceil T_\epsilon/(\eta\rho^2)\rceil)\big) \le \epsilon\rho^2 + \inf_{x\in U'} L^{\mathrm{Max}}_\rho(x).$$

We defer the proofs of Corollaries 4.6 and 4.7 to Appendix I.4.

### 4.3 ANALYSIS OVERVIEW FOR SHARPNESS REDUCTION IN PHASE II OF THEOREM 4.5

We now give an overview of the analysis of the trajectory of full-batch SAM (Equation 3) in Phase II of Theorem 4.5. The framework is similar to Arora et al. (2022); Lyu et al. (2022); Damian et al. (2021). The high-level idea is to use $\Phi(x(t))$ as a proxy for $x(t)$ and to study the dynamics of $\Phi(x(t))$ via Taylor expansion. Following the analysis in Arora et al. (2022), Taylor expansion gives Equation 9. Starting from it, we discuss the key innovation of this paper: implicit Hessian-gradient alignment. We defer the intuitive derivation of Equation 9 to Appendix I.5.

$$\Phi(x(t+1)) - \Phi(x(t)) = -\frac{\eta\rho^2}{2}\,\partial\Phi(x(t))\,\partial^2(\nabla L)(x(t))\left[\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}, \frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\right] + O(\eta^2\rho^2 + \eta\rho^3). \tag{9}$$

To understand how $\Phi(x(t))$ moves over time, we need to identify the direction of the right-hand side of Equation 9. We will prove that it corresponds to the Riemannian gradient of the loss function $\lambda_1(\nabla^2 L(x))$ at $x = \Phi(x(t))$. The key to this is the direction $\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}$.
We will prove that $\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}$ is close to the top eigenvector of the Hessian up to a sign flip, that is, $\left\|\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2} - s\,v_1(\nabla^2 L(x))\right\|_2 \le O(\rho)$ for some $s \in \{-1, 1\}$. We call this phenomenon Hessian-gradient alignment and discuss it in more detail at the end of this subsection. Using this property, we can continue the derivation (detailed in Appendix I.5):

$$\Phi(x(t+1)) - \Phi(x(t)) = -\frac{\eta\rho^2}{2}\,\partial\Phi(\Phi(x(t)))\,\nabla\lambda_1\big(\nabla^2 L(\Phi(x(t)))\big) + O(\eta^2\rho^2 + \eta\rho^3). \tag{10}$$

Implicit Hessian-gradient Alignment. It remains to explain why the gradient implicitly aligns with the top eigenvector of the Hessian, which is the key component of the Phase II analysis. Our proof strategy is to first show alignment for a quadratic loss function, and then generalize the proof to general loss functions satisfying Assumption 3.2. Below we give the formal statement of implicit alignment for quadratic loss, Theorem 4.8, and defer the result for the general case (Lemma I.19) to the appendix. This alignment is an implicit property of the SAM algorithm: the objective that SAM is intended to minimize, $L^{\mathrm{Asc}}_\rho$, does not explicitly enforce it. Indeed, optimizing $L^{\mathrm{Asc}}_\rho$ would instead explicitly align the gradient with the eigenvector of the smallest non-zero eigenvalue (see the proofs of Theorem G.5)!

Theorem 4.8. Suppose $A$ is a positive definite symmetric matrix with a unique top eigenvalue. Consider running full-batch SAM (Equation 3) on the loss $L(x) := \frac{1}{2}x^\top A x$, as in Equation 11 below:

$$x(t+1) = x(t) - \eta A\left(x(t) + \rho\,\frac{Ax(t)}{\|Ax(t)\|_2}\right). \tag{11}$$

Then, provided $\eta\lambda_1(A) < 1$, for almost every $x(0)$ the iterate $x(t)$ converges in direction to $v_1(A)$ up to a sign flip, and $\lim_{t\to\infty}\|x(t)\|_2 = \frac{\eta\rho\lambda_1(A)}{2 - \eta\lambda_1(A)}$.

The proof of Theorem 4.8 relies on a two-phase analysis of Equation 11. In the first phase, we show that $x(t)$ enters an invariant set from any initialization. In the second phase, we construct a potential function to show alignment. The proof is deferred to Appendix H.
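Theorem 4.8 is easy to probe numerically. The pure-Python sketch below iterates Equation 11 for an illustrative diagonal $A = \mathrm{diag}(2, 1)$ with $\eta = 0.1$ and $\rho = 0.05$, so that $\eta\lambda_1(A) = 0.2 < 1$. The starting point $(0.3, 1.0)$ is an arbitrary choice of ours. The iterates should end up flipping sign along $v_1(A) = e_1$ with norm $\eta\rho\lambda_1/(2 - \eta\lambda_1) = 0.01/1.8 \approx 0.00556$. Restricting to diagonal $A$ loses no generality, since Equation 11 is equivariant under orthogonal changes of basis.

```python
import math
from typing import List, Tuple


def sam_quadratic(a_diag: List[float], x0: List[float], eta: float, rho: float,
                  steps: int) -> List[float]:
    """Iterate Equation 11, x <- x - eta * A (x + rho * A x / ||A x||), for diagonal A."""
    x = list(x0)
    for _ in range(steps):
        ax = [a * xi for a, xi in zip(a_diag, x)]
        ax_norm = math.sqrt(sum(v * v for v in ax))
        # Gradient of L(x) = x^T A x / 2 evaluated at the ascent point x + rho * Ax / ||Ax||.
        x = [xi - eta * a * (xi + rho * axi / ax_norm)
             for a, xi, axi in zip(a_diag, x, ax)]
    return x


def alignment_and_norm(x: List[float]) -> Tuple[float, float]:
    """Return (|cos of the angle between x and e_1|, ||x||_2)."""
    n = math.sqrt(sum(v * v for v in x))
    return abs(x[0]) / n, n


if __name__ == "__main__":
    eta, rho, a = 0.1, 0.05, [2.0, 1.0]
    predicted = eta * rho * a[0] / (2.0 - eta * a[0])
    for steps in (10, 50, 200, 5000):
        cos_e1, n = alignment_and_norm(sam_quadratic(a, [0.3, 1.0], eta, rho, steps))
        print(f"t={steps:5d}  |cos(x, e1)|={cos_e1:.6f}  ||x||={n:.6f}  predicted={predicted:.6f}")
```

Early on, plain gradient descent dominates and shrinks the $e_1$ component fastest, so the iterate first points along $e_2$. Only once $\|Ax\|_2$ becomes comparable to $\rho$ does the normalized perturbation make the $e_2$-oscillation unstable and steer $x$ onto $v_1$. This is why the alignment is a property of the update rule rather than of the objective.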
Below we briefly discuss why the case of a general loss is closely related to the quadratic case. We claim that, for a general loss function, the analog of Equation 11 is the following update rule for the gradient:

$$\nabla L(x(t+1)) = \nabla L(x(t)) - \eta\,\nabla^2 L(x(t))\left(\nabla L(x(t)) + \rho\,\frac{\nabla^2 L(x(t))\,\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\right) + O(\eta\rho^2). \tag{12}$$

In the quadratic case, where $\nabla L(x) = Ax$ and $\nabla^2 L(x) = A$, Equation 12 is equivalent to Equation 11, because the two differ only by a multiplicative factor $A$ on both sides. We defer the intuitive derivation of Equation 12 to Appendix I.5.

## 5 EXPLICIT AND IMPLICIT BIASES IN THE STOCHASTIC SETTING

In practice, SAM is usually used in the stochastic mini-batch setting, and test accuracy improves as the batch size decreases (Foret et al., 2021). To explain this phenomenon, Foret et al. (2021) argue intuitively that stochastic SAM minimizes stochastic worst-direction sharpness. Given our results in Section 4, it is natural to ask whether this intuition can be justified by showing Hessian-gradient alignment in the stochastic setting. Unfortunately, such alignment is not possible in the most general setting. When the batch size is 1, however, we prove rigorously in Section 5.2 that stochastic SAM minimizes stochastic worst-direction sharpness, which is the main result of this section. Stochastic worst-direction sharpness is the expectation, over data points, of the worst-direction sharpness of each per-example loss (defined in Section 5.1). We stress that stochastic worst-direction sharpness has a different explicit bias from worst-direction sharpness, which full-batch SAM implicitly penalizes. As the perturbation radius $\rho \to 0$, the former corresponds to $\mathrm{Tr}(\nabla^2 L(\cdot))$, the same as average-direction sharpness, while the latter corresponds to $\lambda_1(\nabla^2 L(\cdot))$.

Below we start by introducing our setting for SAM with batch size 1, or 1-SAM. We still need Assumption 3.2 in this section.
We first analyze the explicit biases of the stochastic ascent- and worst-direction sharpness in Section 5.1, using the tools developed in Section 4.1. It turns out that they are all proportional to the trace of the Hessian as $\rho \to 0$. In Section 5.2, we show that 1-SAM penalizes the trace of the Hessian. Below we formally state our setting for the stochastic loss with batch size one (Setting 5.1).

Setting 5.1. Let the total number of data points be $M$. For $k = 1, \dots, M$, let $f_k(x)$ be the model output on the $k$-th data point, where $f_k$ is a $C^4$-smooth function, and let $y_k$ be the $k$-th label. We define the loss on the $k$-th data point as $L_k(x) = \ell(f_k(x), y_k)$ and the total loss as $L = \sum_{k=1}^M L_k/M$, where the function $\ell(y', y)$ is $C^4$-smooth in $y'$. We also assume that, for any $y \in \mathbb{R}$, $\arg\min_{y'\in\mathbb{R}}\ell(y', y) = y$ and $\frac{\partial^2\ell(y', y)}{(\partial y')^2}\Big|_{y'=y} > 0$. Finally, we denote the set of global minimizers of $L$ with full-rank Jacobian by $\Gamma$ and assume that it is non-empty, that is,

$$\Gamma \triangleq \left\{x \in \mathbb{R}^D \;\middle|\; f_k(x) = y_k,\ \forall k \in [M], \text{ and } \{\nabla f_k(x)\}_{k=1}^M \text{ are linearly independent}\right\} \neq \emptyset.$$

We remark that, given the training data (i.e., $\{f_k\}_{k=1}^M$), $\Gamma$ as defined above equals the set of global minimizers, $\{x \in \mathbb{R}^D \mid f_k(x) = y_k,\ \forall k \in [M]\}$, except for a measure-zero set of labels $(y_k)_{k=1}^M$ when the $f_k$ are $C^\infty$-smooth, by Sard's Theorem. Thus Cooper (2018) argued that the global minimizers generically form a differentiable manifold if perturbation of the labels is allowed. In this work we do not make such an assumption on the labels. Instead, we consider $\Gamma$, the subset of global minimizers with full-rank Jacobian. A standard application of the implicit function theorem implies that $\Gamma$ as defined in Setting 5.1 is indeed a manifold (see Theorem 5.2, whose proof is deferred to Appendix E.1).

Theorem 5.2. The loss $L$, set $\Gamma$ and integer $M$ defined in Setting 5.1 satisfy Assumption 3.2.

1-SAM: We use 1-SAM as a shorthand for SAM on a stochastic loss with batch size 1, as in Equation 13 below, where $k_t$ is sampled i.i.d. from the uniform distribution on $[M]$.
$$\text{1-SAM:} \quad x(t+1) = x(t) - \eta \nabla L_{k_t}\left( x(t) + \rho \frac{\nabla L_{k_t}(x(t))}{\|\nabla L_{k_t}(x(t))\|_2} \right). \tag{13}$$

5.1 STOCHASTIC WORST-, ASCENT- AND AVERAGE-DIRECTION SHARPNESS HAVE THE SAME EXPLICIT BIASES AS AVERAGE-DIRECTION SHARPNESS

As in the full-batch case, we use $L^{\mathrm{Max}}_{k,\rho}$, $L^{\mathrm{Asc}}_{k,\rho}$, $L^{\mathrm{Avg}}_{k,\rho}$ to denote the corresponding sharpness-aware losses for $L_k$. We use $R^{\mathrm{Max}}_{k,\rho}$, $R^{\mathrm{Asc}}_{k,\rho}$, $R^{\mathrm{Avg}}_{k,\rho}$ to denote the corresponding sharpness of $L_k$ (defined as in Equations 1, 2 and 4 with $L$ replaced by $L_k$). We further use stochastic worst-, ascent- and average-direction sharpness to denote $\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}]$, $\mathbb{E}_k[R^{\mathrm{Asc}}_{k,\rho}]$ and $\mathbb{E}_k[R^{\mathrm{Avg}}_{k,\rho}]$, respectively.

Unlike in the full-batch setting, these three sharpness notions have the same explicit bias. More precisely, they have the same limiting regularizers (up to a scaling factor).

Theorem 5.3. The limiting regularizers of the three notions of stochastic sharpness, denoted by $\widetilde{S}^{\mathrm{Max}}$, $\widetilde{S}^{\mathrm{Asc}}$, $\widetilde{S}^{\mathrm{Avg}}$, satisfy $\widetilde{S}^{\mathrm{Max}}(x) = \widetilde{S}^{\mathrm{Asc}}(x) = D\,\widetilde{S}^{\mathrm{Avg}}(x) = \mathrm{Tr}(\nabla^2 L(x))/2$. Furthermore, define $U'$ in the same way as in Theorem 4.2. For any $\mathrm{type} \in \{\mathrm{Max}, \mathrm{Asc}, \mathrm{Avg}\}$, suppose some $u \in U'$ satisfies⁴

$$L(u) + \mathbb{E}_k[R^{\mathrm{type}}_{k,\rho}(u)] \le \inf_{x \in U'} \left( L(x) + \mathbb{E}_k[R^{\mathrm{type}}_{k,\rho}(x)] \right) + \epsilon\rho^2.$$

Then $L(u) - \inf_{x \in U'} L(x) \le \epsilon\rho^2 + o(\rho^2)$ and $\widetilde{S}^{\mathrm{type}}(u) - \inf_{x \in U' \cap \Gamma} \widetilde{S}^{\mathrm{type}}(x) \le \epsilon + o(1)$.

We defer the proof of Theorem 5.3 to Appendix G.6. In the full-batch setting, ascent-direction sharpness and worst-direction sharpness have different explicit biases. Here they coincide, for two reasons. First, the Hessian of each individual loss $L_k$ on $\Gamma$ is rank-1, so its largest and smallest non-zero eigenvalues are the same. Second, by definition, the average of the limiting regularizers equals the limiting regularizer of the averaged sharpness.
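Two identities drive Theorem 5.3, and both can be checked numerically. Near a minimizer, a second-order Taylor expansion gives $R^{\mathrm{Max}}_{k,\rho} \approx \rho^2\lambda_1(\nabla^2 L_k)/2$. Assuming the average is taken over unit directions (consistent with $D\,\widetilde{S}^{\mathrm{Avg}} = \mathrm{Tr}(\nabla^2 L)/2$), it also gives $R^{\mathrm{Avg}}_{k,\rho} \approx \rho^2\mathrm{Tr}(\nabla^2 L_k)/(2D)$.

The minimal sketch below uses hypothetical rank-1 per-sample Hessians of the form implied by Setting 5.1, $c_k g_k g_k^\top$. It checks two things:

1. The average of the per-sample top eigenvalues equals the trace of the averaged Hessian.
2. A Monte Carlo average of $v^\top H v$ over random unit vectors $v$ approaches $\mathrm{Tr}(H)/D$.

The data, seed, and sample size are illustrative.

```python
import math
import random


def rank_one(g, c):
    """c * g g^T: the rank-1 per-sample Hessian shape from Setting 5.1."""
    return [[c * gi * gj for gj in g] for gi in g]


def trace(h):
    return sum(h[i][i] for i in range(len(h)))


def quad_form(h, v):
    return sum(v[i] * h[i][j] * v[j] for i in range(len(v)) for j in range(len(v)))


# Hypothetical data: gradients of f_k at a minimizer and curvatures of ell.
grads = [[1.0, 2.0, 0.0], [0.5, -1.0, 3.0]]
curvs = [1.0, 2.0]
D, M = 3, len(grads)
hessians = [rank_one(g, c) for g, c in zip(grads, curvs)]

# (i) The only non-zero eigenvalue of c g g^T is c ||g||^2, so for each L_k the
# largest and smallest non-zero eigenvalues coincide (worst = ascent direction).
top_eigs = [c * sum(gi * gi for gi in g) for g, c in zip(grads, curvs)]
avg_hessian = [[sum(h[i][j] for h in hessians) / M for j in range(D)] for i in range(D)]
mean_top_eig = sum(top_eigs) / M
print(mean_top_eig, trace(avg_hessian))      # 12.75 12.75

# (ii) Average direction: E_v[v^T H v] over uniform unit vectors equals Tr(H) / D.
rng = random.Random(0)


def random_unit(d):
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]   # Gaussian vector, then normalize
    n = math.sqrt(sum(vi * vi for vi in v))
    return [vi / n for vi in v]


n_samples = 50_000
mc = sum(quad_form(avg_hessian, random_unit(D)) for _ in range(n_samples)) / n_samples
print(mc, trace(avg_hessian) / D)
```

Identity (i) is exact for any rank-1 positive semidefinite matrices, since each one's trace equals its single non-zero eigenvalue. Identity (ii) holds only up to Monte Carlo error.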
5.2 STOCHASTIC SAM MINIMIZES AVERAGE-DIRECTION SHARPNESS

This subsection shows that, for small perturbation radius $\rho$ and learning rate $\eta$, the implicit bias of 1-SAM (Equation 13) is to minimize average-direction sharpness. By Theorem 5.3, average-direction sharpness has the same explicit bias as all three notions of stochastic sharpness.

As an analog of Section 4.3, which shows that full-batch SAM minimizes worst-direction sharpness, the analysis in this section conceptually shows that 1-SAM minimizes stochastic worst-direction sharpness. Mathematically, we prove that for sufficiently small $\eta$ and $\rho$, the trajectory of 1-SAM tracks the Riemannian gradient flow on the manifold given by Equation 14. This flow is taken with respect to the common limiting regularizer $\mathrm{Tr}(\nabla^2 L(\cdot))$, so 1-SAM penalizes stochastic worst-direction sharpness (with batch size 1). We assume the ODE (Equation 14) has a solution up to time $T_3$.

$$X(\tau) = X(0) - \frac{1}{2}\int_{s=0}^{\tau} P_{X(s),\Gamma}\, \nabla \mathrm{Tr}\left(\nabla^2 L(X(s))\right) ds, \qquad X(0) = \Phi(x_{\mathrm{init}}). \tag{14}$$

Theorem 5.4. Let $\{x(t)\}$ be the iterates of 1-SAM (Equation 13) with $x(0) = x_{\mathrm{init}} \in U$. Work under Setting 5.1, and consider almost every $x_{\mathrm{init}}$ and all $\eta$ and $\rho$ such that $(\eta + \rho)\ln(1/(\eta\rho))$ is sufficiently small. Then, with probability at least $1 - O(\rho)$ over the randomness of the algorithm, the dynamics of 1-SAM (Equation 13) can be split into two phases:

Phase I (Theorem J.1): 1-SAM follows gradient flow with respect to $L$ until entering an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$ in $O(\ln(1/(\rho\eta))/\eta)$ steps;

Phase II (Theorem J.2): 1-SAM tracks the solution $X$ of Equation 14, the Riemannian gradient flow with respect to $\mathrm{Tr}(\nabla^2 L(\cdot))$, within an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$. Quantitatively, the approximation error between the iterates $x$ and the corresponding limiting flow $X$ is $O(\eta^{1/2} + \rho)$, that is,

$$\left\| x\left(\lceil T_3/(\eta\rho^2) \rceil\right) - X(T_3) \right\|_2 = O(\eta^{1/2} + \rho).$$

⁴ We note that $R^{\mathrm{Asc}}_\rho(x)$ is undefined when $\|\nabla L(x)\|_2 = 0$. In such cases, we set $R^{\mathrm{Asc}}_\rho(x) = \infty$.
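To make Equation 14 concrete, the sketch below integrates the limiting flow with forward Euler on the toy problem of Appendix B (Figure 1). There, $\Gamma = \{x_3 = x_4 = 0\}$, and the two non-zero Hessian eigenvalues on $\Gamma$ are:

- $F_1(x_1, x_2) = x_1^2 + 6x_2^2 + 8$
- $F_2(x_1, x_2) = 4(1 - x_1)^2 + (1 - x_2)^2 + 1$

Hence $\mathrm{Tr}(\nabla^2 L) = F_1 + F_2$ on $\Gamma$. Because this $\Gamma$ is flat, projecting onto its tangent space simply keeps the $(x_1, x_2)$ components.

The step size and horizon are illustrative. The sketch only integrates the ODE and does not simulate 1-SAM itself. Starting from $(0, 0)$, the minimizer that full-batch SAM prefers, the flow should approach the trace minimizer $(4/5, 1/7)$ reported in Appendix B.

```python
def grad_trace_on_gamma(x1, x2):
    """(x1, x2)-gradient of Tr(Hessian) = F1 + F2 restricted to Gamma = {x3 = x4 = 0}.

    F1 = x1**2 + 6*x2**2 + 8 and F2 = 4*(1 - x1)**2 + (1 - x2)**2 + 1 (Appendix B).
    """
    d1 = 2 * x1 - 8 * (1 - x1)
    d2 = 12 * x2 - 2 * (1 - x2)
    return d1, d2


def limiting_flow(x1, x2, dtau=0.01, tau_end=20.0):
    """Forward-Euler discretization of dX/dtau = -1/2 * P_{X,Gamma} grad Tr(Hessian(X)).

    Gamma is flat here, so the tangent projection keeps the (x1, x2) components
    and x3 = x4 = 0 throughout.
    """
    for _ in range(round(tau_end / dtau)):
        d1, d2 = grad_trace_on_gamma(x1, x2)
        x1 -= 0.5 * dtau * d1
        x2 -= 0.5 * dtau * d2
    return x1, x2


X_end = limiting_flow(0.0, 0.0)   # start at the minimizer preferred by full-batch SAM
print(X_end)                      # close to (4/5, 1/7)
```

In the time scale of Theorem 5.4, one unit of $\tau$ corresponds to about $1/(\eta\rho^2)$ steps of 1-SAM.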
The high-level intuition behind the Phase II result of Theorem 5.4 is that Hessian-gradient alignment holds for every stochastic loss $L_k$ along the trajectory of 1-SAM. Therefore, by Taylor expansion (the same argument as in Section 4.3), at each step $\Phi(x(t))$ moves along the negative (Riemannian) gradient of $\lambda_1(\nabla^2 L_{k_t})$, where $k_t$ is the index of the randomly sampled data point. Up to a factor of 2, $\lambda_1(\nabla^2 L_{k_t})$ is the limiting regularizer of the worst-direction sharpness of $L_{k_t}$.

Averaged over a long time, the moving direction becomes the negative (Riemannian) gradient of $\mathbb{E}_{k_t}[\lambda_1(\nabla^2 L_{k_t})]$. Up to the same factor of 2, this is the limiting regularizer of stochastic worst-direction sharpness, and by Theorem 5.3 it equals $\mathrm{Tr}(\nabla^2 L)$.

Hessian-gradient alignment holds under Setting 5.1 because the Hessian of each stochastic loss $L_k$ at a minimizer $p \in \Gamma$,

$$\nabla^2 L_k(p) = \frac{\partial^2 \ell(y', y_k)}{(\partial y')^2}\Big|_{y' = f_k(p)} \nabla f_k(p) \left(\nabla f_k(p)\right)^\top \quad \text{(Lemma J.15)},$$

is exactly rank-1. This forces the gradient $\nabla L_k(x) \approx \nabla^2 L_k(\Phi(x))(x - \Phi(x))$ to (almost) lie in the top eigenspace of $\nabla^2 L_k(\Phi(x))$, which is also its unique non-zero eigenspace. Lemma 5.5 formally states this property.

Lemma 5.5. Under Setting 5.1, for any $p \in \Gamma$ and $k \in [M]$, it holds that $\nabla f_k(p) \neq 0$. Moreover, there is an open set $V$ containing $p$ such that for all $x \in V$,

$$\nabla L_k(x) \neq 0 \implies \exists s \in \{-1, 1\}, \quad \frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2} = s \frac{\nabla f_k(p)}{\|\nabla f_k(p)\|_2} + O(\|x - p\|_2).$$

Corollaries 5.6 and 5.7 below are the stochastic counterparts of Corollaries 4.6 and 4.7. They state that the trace of the Hessian is close to the stochastic worst-direction sharpness along the limiting flow (Equation 14). Therefore, when the limiting flow converges to a local minimizer of the trace of the Hessian, 1-SAM (Equation 13) minimizes average-direction sharpness. We defer the proofs of Corollaries 5.6 and 5.7 to Appendix J.4.

Corollary 5.6.
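Lemma 5.5 can also be read off from the chain rule. Since $L_k(x) = \ell(f_k(x), y_k)$, we have $\nabla L_k(x) = \frac{\partial \ell}{\partial y'}(f_k(x), y_k)\,\nabla f_k(x)$. So whenever the gradient is non-zero, its direction is exactly $\pm\nabla f_k(x)/\|\nabla f_k(x)\|_2$. By smoothness of $f_k$, this differs from $\pm\nabla f_k(p)/\|\nabla f_k(p)\|_2$ by $O(\|x - p\|_2)$.

The sketch below checks this numerically under these illustrative choices:

- the hypothetical model is $f_k(x) = x_1 x_2$;
- the label is $y_k = 1$;
- the loss is the squared loss $\ell(y', y) = (y' - y)^2/2$;
- the minimizer is $p = (1, 1)$.

```python
import math


def f(x):
    """Hypothetical model output f_k(x) = x[0] * x[1]."""
    return x[0] * x[1]


def grad_f(x):
    return [x[1], x[0]]


def grad_loss(x, y=1.0):
    """Gradient of L_k(x) = 0.5 * (f_k(x) - y)**2 via the chain rule."""
    residual = f(x) - y
    return [residual * gi for gi in grad_f(x)]


def unit(v):
    n = math.sqrt(sum(vi * vi for vi in v))
    return [vi / n for vi in v]


p = [1.0, 1.0]                 # f(p) = y, so p is a global minimizer of L_k
direction_at_p = unit(grad_f(p))


def misalignment(delta):
    """1 - |cos| between the normalized grad L_k(x) and grad f_k(p)/||grad f_k(p)||."""
    x = [p[0] + delta, p[1] - 2 * delta]
    u = unit(grad_loss(x))
    cos = sum(a * b for a, b in zip(u, direction_at_p))
    return 1.0 - abs(cos)


errors = [misalignment(d) for d in (1e-1, 1e-2, 1e-3)]
print(errors)   # shrinks roughly 100x per decade of delta
```

The angle between the two directions is $O(\delta)$, so $1 - |\cos|$ scales like $\delta^2$. This is consistent with the $O(\|x - p\|_2)$ term in Lemma 5.5.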
Under the conditions of Theorem 5.4, with probability $1 - O(\sqrt{\eta} + \sqrt{\rho})$, the difference between the stochastic worst-direction sharpness of the iterates and the corresponding scaled trace of the Hessian along the limiting flow is at most $O\left((\eta^{1/4} + \rho^{1/4})\rho^2\right)$, that is,

$$\mathbb{E}_k\left[R^{\mathrm{Max}}_{k,\rho}\left(x\left(\lceil T_3/(\eta\rho^2) \rceil\right)\right)\right] - \rho^2\,\mathrm{Tr}\left(\nabla^2 L(X(T_3))\right)/2 = O\left((\eta^{1/4} + \rho^{1/4})\rho^2\right).$$

Corollary 5.7. Define $U'$ as in Theorem 4.2, and suppose that $X(\infty) = \lim_{t \to \infty} X(t)$ exists and is a minimizer of $\mathrm{Tr}(\nabla^2 L(x))$ in $U' \cap \Gamma$. Then for all $\epsilon > 0$, there exists a constant $T_\epsilon > 0$ with the following property. For all $\rho, \eta$ such that $(\eta + \rho)\ln(1/(\eta\rho))$ is sufficiently small, with probability $1 - O(\sqrt{\eta} + \sqrt{\rho})$,

$$\mathbb{E}_k\left[L^{\mathrm{Max}}_{k,\rho}\left(x\left(\lceil T_\epsilon/(\eta\rho^2) \rceil\right)\right)\right] \le \epsilon\rho^2 + \inf_{x \in U'} \mathbb{E}_k\left[L^{\mathrm{Max}}_{k,\rho}(x)\right].$$

6 CONCLUSION

In this work, we have performed a rigorous mathematical analysis of two things: the explicit bias of various notions of sharpness when used as regularizers, and the implicit bias of the SAM algorithm. In particular, around the manifold of minimizers, the explicit biases of worst-, ascent- and average-direction sharpness are to minimize, respectively, the largest eigenvalue, the smallest non-zero eigenvalue, and the trace of the Hessian of the loss function. We show that in the full-batch setting SAM provably decreases the largest eigenvalue of the Hessian, while in the stochastic setting with batch size 1, SAM provably decreases the trace of the Hessian.

The most interesting future direction is to generalize the current analysis of stochastic SAM to arbitrary batch sizes. This is challenging because the alignment property holds automatically only with batch size 1. Without it, such an analysis essentially requires understanding the stationary distribution of the gradient direction along the SAM trajectory.

It would also be interesting to incorporate other features of modern deep learning, such as normalization layers, momentum, and weight decay, into the current analysis. Another interesting open question is how to further bridge the gap between generalization bounds and the implicit bias of optimizers.
Currently, the generalization bounds in Wu et al. (2020) and Foret et al. (2020) only apply to the randomly perturbed model. Moreover, these bounds depend on the average sharpness at a finite $\rho$, whereas the analysis in this paper only applies to infinitesimal $\rho$. It is an interesting open question whether the generalization error of the model (without perturbation) can be bounded from above by some function of the training loss, the norm of the parameters, and the trace of the Hessian.

ACKNOWLEDGEMENTS

We thank Jingzhao Zhang for helpful discussions. The authors gratefully acknowledge support from NSF IIS 2045685.

REFERENCES

Maksym Andriushchenko and Nicolas Flammarion. Towards understanding sharpness-aware minimization. In International Conference on Machine Learning, pp. 639–668. PMLR, 2022.

Sanjeev Arora, Zhiyuan Li, and Abhishek Panigrahi. Understanding gradient descent on edge of stability in deep learning. arXiv preprint arXiv:2205.09745, 2022.

Peter L Bartlett, Philip M Long, and Olivier Bousquet. The dynamics of sharpness-aware minimization: Bouncing across ravines and drifting towards wide minima. arXiv preprint arXiv:2210.01513, 2022.

Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regularization for deep neural networks driven by an Ornstein-Uhlenbeck like process. arXiv preprint arXiv:1904.09080, 2019.

Vivek S Borkar. Stochastic approximation: a dynamical systems viewpoint, volume 48. Springer, 2009.

Vivek S Borkar, Jervis Pinto, and Tarun Prabhu. A new learning algorithm for optimal stopping. Discrete Event Dynamic Systems, 19(1):91–113, 2009.

Jeremy M. Cohen, Simran Kaur, Yuanzhi Li, J. Zico Kolter, and Ameet Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability, 2021.
Jeremy M Cohen, Behrooz Ghorbani, Shankar Krishnan, Naman Agarwal, Sourabh Medapati, Michal Badura, Daniel Suo, David Cardoze, Zachary Nado, George E Dahl, et al. Adaptive gradient methods at the edge of stability. arXiv preprint arXiv:2207.14484, 2022.

Yaim Cooper. The loss landscape of overparameterized neural networks. arXiv preprint arXiv:1804.10200, 2018.

Alex Damian, Tengyu Ma, and Jason Lee. Label noise SGD provably prefers flat global minimizers, 2021.

Alex Damian, Eshaan Nichani, and Jason D Lee. Self-stabilization: The implicit bias of gradient descent at the edge of stability. arXiv preprint arXiv:2209.15594, 2022.

Chandler Davis and William Morton Kahan. The rotation of eigenvectors by a perturbation. III. SIAM Journal on Numerical Analysis, 7(1):1–46, 1970.

Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1019–1028. JMLR.org, 2017.

Felix Draxler, Kambis Veschgini, Manfred Salmhofer, and Fred Hamprecht. Essentially no barriers in neural network energy landscape. In International Conference on Machine Learning, pp. 1309–1318. PMLR, 2018.

John C Duchi and Feng Ruan. Stochastic methods for composite and weakly convex optimization problems. SIAM Journal on Optimization, 28(4):3229–3259, 2018.

Gintare Karolina Dziugaite and Daniel M Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Benjamin Fehrman, Benjamin Gess, and Arnulf Jentzen. Convergence rates for the stochastic gradient descent method for non-convex objective functions. Journal of Machine Learning Research, 21:136, 2020.

Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur.
Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412, 2020.

Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry P Vetrov, and Andrew G Wilson. Loss surfaces, mode connectivity, and fast ensembling of DNNs. Advances in Neural Information Processing Systems, 31, 2018.

E. Hairer, S.P. Nørsett, and G. Wanner. Solving Ordinary Differential Equations I: Nonstiff Problems. Springer Series in Computational Mathematics. Springer Berlin Heidelberg, 2008. ISBN 9783540566700. URL https://books.google.com/books?id=F93u7VcSRyYC.

Thomas P Hayes. A large-deviation inequality for vector-valued martingales. Combinatorics, Probability and Computing, 2003.

Sepp Hochreiter and Jürgen Schmidhuber. Flat minima. Neural Computation, 9(1):1–42, 1997.

Roger A. Horn and Charles R. Johnson. Matrix Analysis. Cambridge University Press, 2012.

Stanisław Jastrzebski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in SGD. arXiv preprint arXiv:1711.04623, 2017.

Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Simran Kaur, Jeremy Cohen, and Zachary C Lipton. On the maximum Hessian eigenvalue and generalization. arXiv preprint arXiv:2206.10654, 2022.

Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Harold Kushner and G George Yin. Stochastic approximation and recursive algorithms and applications, volume 35.
Springer Science & Business Media, 2003.

Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning, pp. 5905–5914. PMLR, 2021.

Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and adaptive stochastic gradient algorithms. In International Conference on Machine Learning, pp. 2101–2110. PMLR, 2017.

Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and dynamics of stochastic gradient algorithms I: Mathematical foundations. The Journal of Machine Learning Research, 20(1):1474–1520, 2019.

Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches zero loss? A mathematical framework. In International Conference on Learning Representations, 2021.

Zhouzi Li, Zixuan Wang, and Jian Li. Analyzing sharpness along GD trajectory: Progressive sharpening and edge of stability. arXiv preprint arXiv:2207.12678, 2022.

Yong Liu, Siqi Mai, Xiangning Chen, Cho-Jui Hsieh, and Yang You. Towards efficient and scalable sharpness-aware minimization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12360–12370, 2022.

Kaifeng Lyu, Zhiyuan Li, and Sanjeev Arora. Understanding the generalization benefit of normalization layers: Sharpness reduction. arXiv preprint arXiv:2206.07085, 2022.

Chao Ma and Lexing Ying. On linear stability of SGD and input-smoothness of neural networks. Advances in Neural Information Processing Systems, 34:16805–16817, 2021.

Chao Ma, Lei Wu, and Lexing Ying. The multiscale structure of neural network loss functions: The effect on optimization and origin. arXiv preprint arXiv:2204.11326, 2022.

Jan R Magnus. On differentiating eigenvalues and eigenvectors. Econometric Theory, 1(2):179–191, 1985.
Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate Bayesian inference. Journal of Machine Learning Research, 18:1–35, 2017.

Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nati Srebro. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Matthew D Norton and Johannes O Royset. Diametrical risk minimization: Theory and computations. Machine Learning, pp. 1–19, 2021.

Antonio Orvieto, Anant Raj, Hans Kersting, and Francis Bach. Explicit regularization in overparametrized models via noise injection. arXiv preprint arXiv:2206.04613, 2022.

Weijie Su, Stephen Boyd, and Emmanuel Candes. A differential equation for modeling Nesterov's accelerated gradient method: Theory and insights. In Advances in Neural Information Processing Systems, pp. 2510–2518, 2014.

Colin Wei and Tengyu Ma. Data-dependent sample complexity of deep neural networks via Lipschitz augmentation. In Advances in Neural Information Processing Systems, pp. 9722–9733, 2019a.

Colin Wei and Tengyu Ma. Improved sample complexities for deep networks and robust classification via an all-layer margin. arXiv preprint arXiv:1910.04284, 2019b.

Dongxian Wu, Shu-Tao Xia, and Yisen Wang. Adversarial weight perturbation helps robust generalization. Advances in Neural Information Processing Systems, 33:2958–2969, 2020.

Lei Wu, Chao Ma, and Weinan E. How SGD selects the global minima in over-parameterized learning: A dynamical stability perspective. Advances in Neural Information Processing Systems, 31, 2018.

Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Yang Zhao, Hao Zhang, and Xiuyuan Hu. Penalizing gradient norm for efficiently improving generalization in deep learning. arXiv preprint arXiv:2202.03599, 2022.
Yaowei Zheng, Richong Zhang, and Yongyi Mao. Regularizing neural networks via adversarial model perturbation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8156–8165, 2021.

Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan, and Ting Liu. Surrogate gap minimization improves sharpness-aware training. arXiv preprint arXiv:2203.08065, 2022.

1 Introduction
2 Related Works
3 Notations and Assumptions
4 Explicit and Implicit Bias in the Full-Batch Setting
4.1 Worst- and Ascent-direction Sharpness Have Different Explicit Biases
4.2 SAM Provably Decreases Worst-direction Sharpness
4.3 Analysis Overview For Sharpness Reduction in Phase II of Theorem 4.5
5 Explicit and Implicit Biases in the Stochastic Setting
5.1 Stochastic Worst-, Ascent- and Average-direction Sharpness Have the Same Explicit Biases as Average-direction Sharpness
5.2 Stochastic SAM Minimizes Average-direction Sharpness
6 Conclusion
A Additional Related Works
B Experimental Details for Figure 1
C Additional Preliminary
D Well-definedness of SAM
E Proof Setups
E.1 Proofs of Theorems 5.2 and E.2
F Properties of Limiting Map of Gradient Flow, Φ
G Analysis for Explicit Bias
G.1 A General Theorem for Explicit Bias in the Limit Case
G.2 Bad Limiting Regularizers May Not Capture Explicit Bias
G.3 Proof of Theorem G.6
G.4 Proofs of Corollary G.7
G.5 Limiting Regularizers For Different Notions of Sharpness
G.6 Proof of Theorems 4.2 and 5.3
H Analysis Full-batch SAM on Quadratic Loss (Proof of Theorem 4.8)
H.1 Entering Invariant Set
H.2 Alignment to Top Eigenvector
I Analysis for Full-batch SAM on General Loss (Proof of Theorem 4.5)
I.1 Phase I (Proof of Theorem I.1)
I.1.1 Tracking Gradient Flow
I.1.2 Decreasing Loss
I.1.3 Entering Invariant Set
I.2 Phase II (Proof of Theorem I.3)
I.2.1 Alignment to Top Eigenvector
I.2.2 Tracking Riemannian Gradient Flow
I.3 Proof of Theorem 4.5
I.4 Proofs of Corollaries 4.6 and 4.7
I.5 Derivations for Section 4.3
J Analysis for 1-SAM (Proof of Theorem 5.4)
J.1 Phase I (Proof of Theorem J.1)
J.1.1 Tracking Gradient Flow
J.1.2 Decreasing Loss
J.2 Phase II (Proof of Theorem J.2)
J.2.1 Convergence Near Manifold
J.2.2 Tracking Riemannian Gradient Flow
J.3 Proof of Theorem 5.4
J.4 Proofs of Corollaries 5.6 and 5.7
J.5 Other Omitted Proofs for 1-SAM
K Technical Lemmas
L Omitted Proofs on Continuous Approximation

A ADDITIONAL RELATED WORKS

Implicit Bias of Sharpness Minimization. Recent theoretical works (Blanc et al., 2019; Damian et al., 2021; Li et al., 2021) show that SGD with label noise is implicitly biased toward local minimizers with a smaller trace of the Hessian, under the assumption that the minimizers locally form a manifold. Arora et al. (2022) show that normalized GD implicitly penalizes the largest eigenvalue of the Hessian. Ma et al. (2022) argue that such sharpness-reduction phenomena can also be caused by a multi-scale loss landscape. Lyu et al. (2022) show that GD with weight decay on a scale-invariant loss function implicitly decreases the spherical sharpness, i.e., the largest eigenvalue of the Hessian evaluated at the normalized parameter.

Another line of work studies the sharpness-minimization effect of large learning rates, assuming that (stochastic) gradient descent converges at the end of training. There, the analysis is mainly based on linear stability (Wu et al., 2018; Cohen et al., 2021; Ma et al., 2021; Cohen et al., 2022).
Recent theoretical analyses (Damian et al., 2022; Li et al., 2022) give a four-phase characterization of the dynamics in the so-called Edge of Stability regime (Cohen et al., 2021). Using it, they show that the sharpness-minimization effect of a large learning rate in gradient descent does not necessarily rely on the convergence assumption and linear stability.

Comparison with concurrent work Bartlett et al. (2022). Bartlett et al. (2022) prove that on quadratic loss, the iterate of SAM (Equation 11) and its gradient align with the top eigenvector of the Hessian, which is almost the same as our Theorem 4.8. Assuming such alignment for a general loss, Bartlett et al. (2022) show that the largest eigenvalue of the Hessian decreases in the next step.

This paper additionally proves two results not shown in Bartlett et al. (2022). The first is such Hessian-gradient alignment for general loss functions (Lemma I.19). The second is an end-to-end theorem showing that the largest eigenvalue of the Hessian and worst-direction sharpness decrease along the trajectory of SAM (Theorem 4.5). Moreover, this paper also characterizes the implicit bias of stochastic SAM with batch size 1, namely minimizing average-direction sharpness, while Bartlett et al. (2022) only consider the deterministic case.

Comparison with Arora et al. (2022). Our proof uses a framework similar to that of Arora et al. (2022). However, our analysis has its own difficulties, for the following reasons. First, Arora et al. (2022) only deal with the deterministic case, while our analysis extends to stochastic SAM as well (Section 5). Second, our analysis of the deterministic case differs from that of Arora et al. (2022) in the following two aspects. First, the alignment analysis is more complicated because we have two hyperparameters, the learning rate $\eta$ and the perturbation radius $\rho$, while Arora et al. (2022) only need to deal with one hyperparameter, the learning rate $\eta$.
Second, the mechanism for penalizing worst-direction sharpness is different, as can be seen from how the sharpness-reduction rate depends on the learning rate $\eta$. In Arora et al. (2022), normalized GD reduces sharpness via a second-order effect of GD, so the sharpness is reduced by $O(\eta^2)$ per step. In our analysis, for a fixed small perturbation radius $\rho$, the sharpness is reduced by $O(\eta\rho^2)$ per step, which is linear in $\eta$.

Analyzing Discrete-time Dynamics via Continuous-time Approaches. A long line of research shows that the trajectory of stochastic discrete iterations with decaying step size eventually tracks the solution of some ODE (see Kushner et al. (2003); Borkar et al. (2009); Duchi et al. (2018) and the references therein). However, those results mainly focus on convergence properties of the stochastic iterates (e.g., convergence to stationary points). We are instead interested in characterizing the trajectory when the process runs for a long time, even after the iterate reaches a neighborhood of the manifold of stationary points.

Recently there has been an effort to model the discrete-time trajectory of (stochastic) gradient methods by continuous-time approximations (Su et al., 2014; Mandt et al., 2017; Li et al., 2017; 2019). Notably, Li et al. (2019) present a general and rigorous mathematical framework for proving such continuous-time approximations. More specifically, Li et al. (2019) prove that for various stochastic gradient-based methods, the discrete-time trajectory converges weakly to its continuous-time counterpart as the learning rate $\eta \to 0$, over $\Theta(1/\eta)$ steps.

The main difference between our results and these results (e.g., Theorem 9 in Li et al. (2019)) is that we focus on a much longer training regime, namely $T = \Theta(\eta^{-1}\rho^{-2})$ steps. Over such a horizon, the previous continuous-time approximation results no longer hold throughout training.
As a result, their continuous approximation only corresponds to the Phase I dynamics in our Theorems 4.5 and 5.4. It cannot capture the dynamics of SAM in Phase II, where the sharpness-reduction implicit bias takes place. Phase II requires a more fine-grained analysis that captures the effects of higher-order terms in $\eta$ and $\rho$ in the SAM update (Equation 3).

B EXPERIMENTAL DETAILS FOR FIGURE 1

In Figure 1, we choose $F_1(x) = x_1^2 + 6x_2^2 + 8$ and $F_2(x) = 4(1 - x_1)^2 + (1 - x_2)^2 + 1$, so that $F_1(x) \ge 8 > 6 \ge F_2(x)$ on $[0, 1]^2$. The loss $L$ has a zero-loss manifold $\{x_3 = x_4 = 0\}$ of codimension $M = 2$. The two non-zero eigenvalues of $\nabla^2 L$ at any point $x$ on the manifold are $\lambda_1(\nabla^2 L(x)) = F_1(x_1, x_2)$ and $\lambda_2(\nabla^2 L(x)) = F_2(x_1, x_2)$. As our theory predicts:

1. Full-batch SAM (Equation 3) finds the minimizer with the smallest top eigenvalue $F_1(x)$, which is $x_1 = 0, x_2 = 0, x_3 = 0, x_4 = 0$;
2. GD on the ascent-direction loss $L^{\mathrm{Asc}}_\rho$ (Equation 2) finds the minimizer with the smallest bottom eigenvalue $F_2(x)$, which is $x_1 = 1, x_2 = 1, x_3 = 0, x_4 = 0$;
3. Stochastic SAM (Equation 13), with $L_0(x, y) = F_1(x) y_0^2$ and $L_1(x, y) = F_2(x) y_1^2$, finds the minimizer with the smallest trace of the Hessian, which is $x_1 = 4/5, x_2 = 1/7, x_3 = 0, x_4 = 0$.

C ADDITIONAL PRELIMINARY

In this section, we introduce some additional notation and clarifications before the proofs. We first give the detailed definition of a differentiable submanifold.

Definition C.1 (Differentiable Submanifold of $\mathbb{R}^D$). We call a subset $\Gamma \subseteq \mathbb{R}^D$ a $C^k$ submanifold of $\mathbb{R}^D$ if and only if, for every $x \in \Gamma$, there exist an open neighborhood $U$ of $x$ and an invertible $C^k$ map $\psi : U \to \mathbb{R}^D$ such that $\psi(\Gamma \cap U) = (\mathbb{R}^n \times \{0\}) \cap \psi(U)$.

Necessity of Manifold Assumption.
The manifold assumption above implies that the set of local minimizers is connected. This lets us take the limit $\rho \to 0$ of the perturbation radius while still obtaining interesting and insightful implicit bias results in the end-to-end analysis.

So far, almost all analyses of implicit bias for general model parameterizations rely on Taylor expansion, e.g., Blanc et al. (2019); Damian et al. (2021); Li et al. (2021); Arora et al. (2022). So does the derivation of the SAM algorithm (Foret et al., 2020; Wu et al., 2020). Thus it is crucial to consider a small perturbation radius $\rho$. In contrast, if the global minimizers form a set of discrete points, then with a small perturbation radius $\rho$, the implicit bias of the optimizer is not strong enough to drive the iterate from one global minimum to another.

Implicit versus Explicit Bias. Suppose an algorithm or optimizer is biased towards certain types of global/local minima of the loss over other minima, and this bias is not encoded in the loss function. We then call such a bias an implicit bias. On the other hand, a bias may emerge solely as a consequence of successfully minimizing a certain regularized loss, regardless of the optimizer (as long as the optimizer minimizes the loss). We then call it an explicit bias of the regularized loss (or the regularizer).

As a concrete example, we will prove that full-batch SAM (Equation 3) prefers local minima with a certain sharpness property. This bias stems from the particular update rule of full-batch SAM (Equation 3), and not all optimizers for the intended target loss function $L^{\mathrm{Asc}}_\rho$ (Equation 2) have it. It is therefore an implicit bias. As an example of explicit bias, all optimizers that minimize a loss combined with $\ell_2$ regularization prefer models with smaller parameter norm. This is an explicit bias of $\ell_2$ regularization.
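The following toy illustration of the distinction uses a hypothetical underdetermined problem of our own choosing: minimize $\frac{1}{2}(x_1 + x_2 - 2)^2$, whose global minimizers form the line $x_1 + x_2 = 2$.

Plain GD ends at different minimizers depending on where it starts, because it keeps $x_1 - x_2$ fixed. Any preference among minimizers therefore comes from the optimizer, not the loss.

Adding $\frac{\lambda}{2}\|x\|^2$ makes the minimizer unique: $x = \frac{2}{2+\lambda}(1, 1)$. Every run that successfully minimizes the regularized loss lands there. In this sketch, varying the initialization stands in for varying the optimizer, and $\lambda$, the step size, and the step count are illustrative.

```python
def gd(x, lam, eta=0.1, steps=5000):
    """GD on the hypothetical loss 0.5 * (x1 + x2 - 2)**2 + 0.5 * lam * (x1**2 + x2**2)."""
    x1, x2 = x
    for _ in range(steps):
        r = x1 + x2 - 2.0                      # residual of the data-fitting term
        x1, x2 = x1 - eta * (r + lam * x1), x2 - eta * (r + lam * x2)
    return x1, x2


lam = 0.1
inits = [(0.0, 0.0), (3.0, 0.0)]

# Unregularized: the minimizer reached depends on the run, here (1, 1) vs (2.5, -0.5).
plain = [gd(x0, lam=0.0) for x0 in inits]

# l2-regularized: every run reaches the unique minimizer 2 / (2 + lam) * (1, 1).
regularized = [gd(x0, lam=lam) for x0 in inits]
target = 2.0 / (2.0 + lam)

print(plain)
print(regularized, target)
```

As $\lambda \to 0$, the regularized minimizer tends to $(1, 1)$, the minimum-norm point on the line. This is the "smaller parameter norm" preference described above.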
Usage of $O(\cdot)$ Notation: Our analysis assumes small $\eta$ and $\rho$ while treating all other problem-dependent parameters as constants. These include the dimension of the parameter space and the maximum possible values of derivatives (of different orders) of the loss function $L$ and the limit map $\Phi$.

In $O(\cdot)$, $\Omega(\cdot)$, $o(\cdot)$, $\omega(\cdot)$, $\Theta(\cdot)$, we hide all dependence on the problem and keep only the dependence on $\rho$ and $\eta$. The hidden quantities include, e.g., the (unique) initialization $x_{\mathrm{init}}$, the manifold $\Gamma$, the compact set $U$ in Theorem 4.2, and the continuous time $T_3$ in Theorems 4.5 and 5.4. For example, $O(f(\rho))$ is a placeholder for some function $g(\rho)$ such that there exists a problem-dependent constant $C > 0$ with $|g(\rho)| \le C|f(\rho)|$ for all $\rho > 0$.

In informal equations, such as Equation 31 in the proof sketch section, we are a bit more sloppy and also hide the dependence on $x(t)$ in the $O(\cdot)$ notation. These cases are dealt with formally in the proofs.

Ill-definedness of SAM with Zero Gradient. The update rule of SAM (Equations 3 and 13) is ill-defined when the gradient is zero. However, our analysis in Appendix D covers the case where the set of stationary points of the loss $L$, $\{x \mid \nabla L(x) = 0\}$, has measure zero. In that case, for any perturbation radius $\rho$ and all but countably many learning rates, full-batch SAM is well-defined for almost all initializations and all steps (Theorem D.1).

A similar result holds for stochastic SAM if the stationary points of each stochastic loss form a zero-measure set (Theorem D.2). Thus SAM is generically well-defined. For rigor, when SAM encounters a zero gradient, we modify the algorithm by replacing the ill-defined normalized gradient with an arbitrary unit vector. Our analysis of the implicit bias of SAM still holds.

D WELL-DEFINEDNESS OF SAM

In this section, we discuss the well-definedness of SAM. When $\nabla L(x) = 0$, SAM (Equation 3) is not well-defined, because the normalized gradient $\nabla L(x)/\|\nabla L(x)\|_2$ is not well-defined.
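The zero-gradient convention is easy to state in code. The following minimal sketch implements one full-batch SAM step (Equation 3) and one 1-SAM step (Equation 13). It uses plain Python lists, and the helper names are hypothetical. When the gradient is exactly zero, the sketch substitutes a caller-supplied unit vector for the normalized gradient; which unit vector is used is arbitrary, as in the text. The per-sample quadratic losses in the usage lines are illustrative.

```python
import math
import random


def normalized_or_fallback(g, fallback):
    """g / ||g||_2, or the supplied unit vector when g == 0 (the convention in the text)."""
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return list(fallback)
    return [gi / norm for gi in g]


def sam_step(x, grad, eta, rho, fallback):
    """One SAM step: x - eta * grad(x + rho * grad(x) / ||grad(x)||_2)."""
    d = normalized_or_fallback(grad(x), fallback)
    y = [xi + rho * di for xi, di in zip(x, d)]          # perturbed (ascent) point
    return [xi - eta * gi for xi, gi in zip(x, grad(y))]


def one_sam_step(x, per_sample_grads, eta, rho, fallback, rng):
    """One 1-SAM step: the same update on L_k with k drawn uniformly from [M]."""
    k = rng.randrange(len(per_sample_grads))
    return sam_step(x, per_sample_grads[k], eta, rho, fallback)


# Illustrative losses L_0(x) = x0**2 and L_1(x) = 2 * x1**2, so that
# L = (L_0 + L_1) / 2 = 0.5 * (x0**2 + 2 * x1**2).
def grad_L0(x):
    return [2.0 * x[0], 0.0]


def grad_L1(x):
    return [0.0, 4.0 * x[1]]


def grad_L(x):
    return [x[0], 2.0 * x[1]]


e0 = [1.0, 0.0]                                   # arbitrary unit vector

at_zero = sam_step([0.0, 0.0], grad_L, eta=0.1, rho=0.05, fallback=e0)
regular = sam_step([1.0, 0.0], grad_L, eta=0.1, rho=0.05, fallback=e0)
stochastic = one_sam_step([1.0, 1.0], [grad_L0, grad_L1], 0.1, 0.05, e0, random.Random(0))
print(at_zero, regular, stochastic)
```

At the stationary point $(0, 0)$, the fallback direction $e_0$ moves the iterate to about $(-0.005, 0)$ instead of raising a division-by-zero error.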
The main results of this section are Theorems D.1 and D.2, which say that (stochastic) SAM starting from a random initialization reaches points where SAM is undefined (i.e., points with zero gradient) only with probability zero, for all except countably many learning rates. These results follow from Theorem D.3, a more general theorem that also applies to other discrete update rules, such as SGD. Note that the results in this section do not rely on the manifold assumption, i.e., Assumption 3.2. We end this section with a concrete example where SAM is undefined with constant probability, suggesting that the exclusion of countably many learning rates is necessary in Theorems D.1 and D.2.

Theorem D.1. Consider any $C^2$ loss $L$ with zero-measure stationary set $\{x \mid \nabla L(x) = 0\}$. For every $\rho > 0$, except for countably many learning rates, for almost all initializations and all $t$, the iterate $x(t)$ of full-batch SAM (Equation 3) has non-zero gradient and is thus well-defined.

Theorem D.2. Consider any $C^2$ losses $\{L_k\}_{k=1}^M$ with zero-measure stationary sets $\{x \mid \nabla L_k(x) = 0\}$ for each $k \in [M]$. For every $\rho > 0$, except for countably many learning rates $\eta$, for almost all initializations and all $t$, with probability one over the randomness of the algorithm, the iterate $x(t)$ of stochastic SAM (Equation 13) has non-zero gradient and is thus well-defined.$^5$

Before presenting our main theorem (Theorem D.3), we need to introduce some notation. For a map $F: \mathbb{R}^D \setminus Z \to \mathbb{R}^D$, we define $F_\eta: \mathbb{R}^D \setminus Z \to \mathbb{R}^D$ as $F_\eta(x) \triangleq x - \eta F(x)$ for any $\eta \in \mathbb{R}^+$. Given a sequence of functions $\{F^n\}_{n=1}^\infty$, we define $F^n_\eta(x) \triangleq x - \eta F^n(x)$ for any $x \in \mathbb{R}^D$. We further define the composition $\mathbf{F}^n_\eta(x) \triangleq F^n_\eta(\mathbf{F}^{n-1}_\eta(x))$ for any $n \ge 1$, with $\mathbf{F}^0_\eta(x) = x$.

Theorem D.3. Let $Z$ be a closed subset of $\mathbb{R}^D$ with zero Lebesgue measure and $\mu$ be any probability measure on $\mathbb{R}^D$ that is absolutely continuous with respect to the Lebesgue measure.
For any sequence of $C^1$ functions $F^n: \mathbb{R}^D \setminus Z \to \mathbb{R}^D$, $n \in \mathbb{N}^+$, the following claim holds for all except countably many $\eta \in \mathbb{R}^+$:
$$\mu\left(\{x \in \mathbb{R}^D \mid \exists n \in \mathbb{N},\ \mathbf{F}^n_\eta(x) \in Z \text{ and } \forall\, 0 \le i \le n-1,\ \mathbf{F}^i_\eta(x) \notin Z\}\right) = 0.$$
In other words, for almost all $\eta$ (all except countably many positive numbers), the iteration $x(t+1) = x(t) - \eta F^{t+1}(x(t))$, i.e., $x(t) = \mathbf{F}^t_\eta(x(0))$, almost surely never enters $Z$, provided that $x(0)$ is sampled from $\mu$.

Theorems D.1 and D.2 follow immediately from Theorem D.3.

Proof of Theorem D.1. Let $F(x) = \nabla L\big(x + \rho \frac{\nabla L(x)}{\|\nabla L(x)\|_2}\big)$ and $Z = \{x \in \mathbb{R}^D \mid \nabla L(x) = 0\}$. One can easily check that $F$ is $C^1$ on $\mathbb{R}^D \setminus Z$, and by assumption $Z$ is a zero-measure set. Applying Theorem D.3 with $F^n \equiv F$ for all $n \in \mathbb{N}^+$, we get the desired result.

Proof of Theorem D.2. Let $G_k(x) = \nabla L_k\big(x + \rho \frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\big)$ and $Z = \bigcup_{k=1}^M \{x \in \mathbb{R}^D \mid \nabla L_k(x) = 0\}$. One can easily check that each $G_k$ is $C^1$ on $\mathbb{R}^D \setminus Z$, and by assumption $Z$ is a zero-measure set. Applying Theorem D.3 with $F^n = G_{k_n}$ for all $n \in \mathbb{N}^+$, where $k_n$ is the $n$-th data point/batch sampled by the algorithm, we get the desired result.

Now we turn to the proof of Theorem D.3, which is based on the following two lemmas.

Lemma D.4. Let $Z$ be a closed subset of $\mathbb{R}^D$ with zero Lebesgue measure and $F: \mathbb{R}^D \setminus Z \to \mathbb{R}^D$ be a continuously differentiable function. Then, for all except countably many $\eta \in \mathbb{R}^+$, $\{x \in \mathbb{R}^D \setminus Z \mid \det(\partial F_\eta(x)) = 0\}$ is a zero-measure set under the Lebesgue measure.

Lemma D.5. Let $Z$ be a closed subset of $\mathbb{R}^D$ with zero Lebesgue measure and $H: \mathbb{R}^D \setminus Z \to \mathbb{R}^D$ be a continuously differentiable function. If $\{x \in \mathbb{R}^D \setminus Z \mid \det(\partial H(x)) = 0\}$ is a zero-measure set, then for any zero-measure set $Z'$, $H^{-1}(Z')$ is a zero-measure set.

Proof of Theorem D.3. It suffices to prove that for every $N \in \mathbb{N}^+$, for all except countably many $\eta$,
$$\mu\left(\{x \in \mathbb{R}^D \mid \mathbf{F}^N_\eta(x) \in Z \text{ and } \forall\, 0 \le i \le N-1,\ \mathbf{F}^i_\eta(x) \notin Z\}\right) = 0. \quad (15)$$
The desired result is immediately implied by this claim, because a countable union of countable sets is still countable, and a countable union of zero-measure sets still has measure zero. To prove Equation 15, we first introduce some notation.
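For any $\eta > 0$, $0 \le n \le N-1$, and $S \subseteq \mathbb{R}^D$, we define the iterated preimages $\mathbf{F}^{-(n+1)}_\eta(S) \triangleq (F^{N-n}_\eta)^{-1}(\mathbf{F}^{-n}_\eta(S))$, where $\mathbf{F}^{-0}_\eta(S) = S$; for a single point we write $\mathbf{F}^{-n}_\eta(x) \triangleq \mathbf{F}^{-n}_\eta(\{x\})$, so that $\mathbf{F}^{-n}_\eta(S) = \bigcup_{x \in S} \mathbf{F}^{-n}_\eta(x)$. Under this notation, we have
$$\mu(\mathbf{F}^{-N}_\eta(Z)) = \mu\left(\{x \in \mathbb{R}^D \mid \mathbf{F}^N_\eta(x) \in Z \text{ and } \forall\, 0 \le i \le N-1,\ \mathbf{F}^i_\eta(x) \notin Z\}\right).$$
We prove Equation 15 by induction. We claim that for each $0 \le n \le N$, except for countably many $\eta \in \mathbb{R}^+$, $\mathbf{F}^{-n}_\eta(Z)$ has zero Lebesgue measure. The base case $n = 0$ is trivial, as $Z$ is assumed to have measure zero. Suppose the claim holds for $n$. By Lemma D.4, except for countably many $\eta \in \mathbb{R}^+$, $\{x \in \mathbb{R}^D \setminus Z \mid \det(\partial F^{N-n}_\eta(x)) = 0\}$ is a zero-measure set. Next, by Lemma D.5 (with $H = F^{N-n}_\eta$ and $Z' = \mathbf{F}^{-n}_\eta(Z)$), if for some $\eta \in \mathbb{R}^+$ this set has measure zero, then $\mathbf{F}^{-(n+1)}_\eta(Z) = (F^{N-n}_\eta)^{-1}(\mathbf{F}^{-n}_\eta(Z))$ is a zero-measure set. Then by induction, except for countably many $\eta \in \mathbb{R}^+$, $\mathbf{F}^{-n}_\eta(Z)$ has measure zero for all integers $0 \le n \le N$. Since $\mu$ is absolutely continuous with respect to the Lebesgue measure, $\mu(\mathbf{F}^{-N}_\eta(Z)) = 0$.

We end this section with the proofs of Lemmas D.4 and D.5.

Proof of Lemma D.4. We use $\lambda_i(x)$ to denote the real part of the $i$-th eigenvalue of the matrix $\partial F(x)$, in descending order. Since $\partial F(x)$ is continuous in $x$, $\lambda_i(x)$ is continuous in $x$ as well for any $i \in [D]$, and thus $\{x \in \mathbb{R}^D \setminus Z \mid \lambda_i(x) = 1/\eta\}$ is a measurable set. Fix $i \in [D]$. For each positive integer $n$, let $I_n$ be the set of $\eta$ such that $\mu(\{x \in \mathbb{R}^D \setminus Z \mid \lambda_i(x) = 1/\eta\}) > 1/n$. Then $|I_n| \le n$, because these sets are disjoint for distinct $\eta$ and hence
$$\frac{|I_n|}{n} \le \sum_{\eta \in I_n} \mu(\{x \in \mathbb{R}^D \setminus Z \mid \lambda_i(x) = 1/\eta\}) \le \mu(\{x \in \mathbb{R}^D \setminus Z \mid 1/\lambda_i(x) \in I_n\}) \le 1.$$
Therefore, there are at most countably many $\eta \in \mathbb{R}^+$ such that $\mu(\{x \in \mathbb{R}^D \setminus Z \mid \lambda_i(x) = 1/\eta\}) > 0$. Further note that $\det(\partial F_\eta(x)) = \det(I_D - \eta\,\partial F(x)) = 0$ implies $\exists i \in [D]$, $\lambda_i(x) = 1/\eta$; hence there are at most countably many $\eta \in \mathbb{R}^+$ such that $\mu(\{x \in \mathbb{R}^D \setminus Z \mid \det(\partial F_\eta(x)) = 0\}) \neq 0$.

Footnote 5: Though we call Equation 13 1-SAM, our result here applies to any batch size, where $L_k$ can be regarded as the loss of the $k$-th possible batch and $M$ is the total number of batches.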
This completes the proof.

Proof of Lemma D.5. Denote $\{x \in \mathbb{R}^D \setminus Z \mid \det(\partial H(x)) = 0\}$ by $\tilde{Z}$. Since $\det(\partial H(x))$ is continuous in $x$ (as $H$ is $C^1$), $\tilde{Z}$ is relatively closed in $\mathbb{R}^D \setminus Z$. Since $Z$ is a closed set, $\mathbb{R}^D \setminus (Z \cup \tilde{Z})$ is open. Thus for every $x \in \mathbb{R}^D \setminus (Z \cup \tilde{Z})$, where $\det(\partial H(x)) \neq 0$, there exists an open neighborhood $U$ of $x$ such that $\det(\partial H(x')) \neq 0$ for all $x' \in U$, by continuity of $\det(\partial H(\cdot))$. By the inverse function theorem, after possibly shrinking $U$, $H$ is invertible on $U$ and its inverse $(H|_U)^{-1}$ is differentiable on $H(U)$. Therefore, $(H|_U)^{-1}$ maps any zero-measure set to a zero-measure set. In particular, $(H|_U)^{-1}(Z' \cap H(U))$ has measure zero, and so does $H^{-1}(Z') \cap U = (H|_U)^{-1}(Z' \cap H(U))$. Now for every $x \in \mathbb{R}^D \setminus (Z \cup \tilde{Z})$ we take such an open neighborhood $U_x \subseteq \mathbb{R}^D \setminus (Z \cup \tilde{Z})$. Since $\mathbb{R}^D$ is a separable metric space, the open cover $\{U_x\}_{x \in \mathbb{R}^D \setminus (Z \cup \tilde{Z})}$ of $\mathbb{R}^D \setminus (Z \cup \tilde{Z})$ has a countable subcover $\{U_x\}_{x \in I}$, where $I$ is a countable subset of $\mathbb{R}^D \setminus (Z \cup \tilde{Z})$. Therefore,
$$H^{-1}(Z') \setminus (Z \cup \tilde{Z}) = H^{-1}(Z') \cap \big(\mathbb{R}^D \setminus (Z \cup \tilde{Z})\big) = \bigcup_{x \in I}\big(H^{-1}(Z') \cap U_x\big)$$
is a zero-measure set. Thus $H^{-1}(Z')$ also has measure zero, since $Z$ and $\tilde{Z}$ both have measure zero. This completes the proof.

We end this section with an example where SAM is undefined with constant probability.

Theorem D.6. For any $\eta, \rho > 0$, there is a $C^2$ loss function $L: \mathbb{R} \to \mathbb{R}$ such that (1) $L$ has a unique stationary point and (2) the set of initializations from which SAM with learning rate $\eta$ and perturbation radius $\rho$ reaches the unique stationary point has positive Lebesgue measure.

Proof of Theorem D.6. We first consider the case $\rho = \eta = 1$ with
$$L(x) = \begin{cases} x^2/2 + x + 1/2, & x \in (-\infty, -2); \\ x^4/64 + x^2/8 - 1/4, & x \in [-2, 2]; \\ x^2/2 - x + 1/2, & x \in (2, \infty). \end{cases} \quad (16)$$
We first check that $L$ is indeed $C^2$: $L(2) = L(-2) = 1/2$, $L'(2) = -L'(-2) = 1$, and $L''(2) = L''(-2) = 1$. Now we claim that for all $|x(0)| > 2$, $x(1) = 0$, which is a stationary point. Note that $L$ is even and monotonically increasing on $[0, \infty)$, so $L'(x)/|L'(x)| = \mathrm{sign}(x)$ for $x \neq 0$.
Thus for $|x(t)| > 1$, it holds that $|x(t) + \mathrm{sign}(x(t))| > 2$, and since $L'(y) = y - \mathrm{sign}(y)$ for $|y| > 2$, we have
$$\begin{aligned} x(t+1) &= x(t) - \eta L'\big(x(t) + \rho L'(x(t))/|L'(x(t))|\big) \\ &= x(t) - L'\big(x(t) + \mathrm{sign}(x(t))\big) \\ &= x(t) - \big(x(t) + \mathrm{sign}(x(t)) - \mathrm{sign}(x(t) + \mathrm{sign}(x(t)))\big) \\ &= x(t) - x(t) = 0. \end{aligned} \quad (17)$$
Now we turn to the case of arbitrary positive $\eta, \rho$. It suffices to consider $L_{\eta,\rho}(x) \triangleq \frac{\rho^2}{\eta} L(x/\rho)$, whose derivative is $L'_{\eta,\rho}(x) = \frac{\rho}{\eta} L'(x/\rho)$. We can use the calculation for $\rho = \eta = 1$ to verify that for any $|x| > 2\rho$,
$$L'_{\eta,\rho}\big(x + \rho\,\mathrm{sign}(L'_{\eta,\rho}(x))\big) = L'_{\eta,\rho}(x + \rho\,\mathrm{sign}(x)) = \frac{\rho}{\eta} L'\big(x/\rho + \mathrm{sign}(x)\big) = \frac{x}{\eta},$$
namely $x - \eta L'_{\eta,\rho}\big(x + \rho\,\mathrm{sign}(L'_{\eta,\rho}(x))\big) = 0$. This completes the proof.

A common (but wrong) intuition here is that, for a continuously differentiable update rule, as long as the set of points where the update rule is ill-defined (here, the points with zero gradient) has measure zero, then for almost all initializations, gradient-based optimization algorithms like SAM will never land exactly on a stationary point. The above example refutes this intuition. The issue is that although a differentiable map (like the SAM update $x \mapsto x - \eta \nabla L(x + \rho \nabla L(x)/\|\nabla L(x)\|_2)$) always maps a zero-measure set to a zero-measure set, the preimage of a zero-measure set is not necessarily of measure zero, because the map $x \mapsto x - \eta \nabla L(x + \rho \nabla L(x)/\|\nabla L(x)\|_2)$ is not necessarily invertible. The fact that the SAM update is not invertible at $0$ is exactly why the preimage of $0$ has positive measure.

E PROOF SETUPS

In this section we provide details of our proof setups, including notation and assumptions/settings. We first introduce some additional notation that will be used in the proofs. For any subset $S \subseteq \mathbb{R}^D$, we define $\mathrm{dist}(x, S) \triangleq \inf_{y \in S}\|x - y\|_2$. For any $d > 0$ and any subset $S \subseteq \mathbb{R}^D$, we define $S^d \triangleq \{x \in \mathbb{R}^D \mid \mathrm{dist}(x, S) \le d\}$. Our convention is to use $K$ to denote a compact set and $U$ to denote an open set. Below we restate our main assumption in the full-batch case and the related notation from Section 3. Throughout the analysis, we fix our initialization as $x_{\mathrm{init}}$ and our loss function as $L: \mathbb{R}^D \to \mathbb{R}$.

Assumption 3.2.
Assume the loss $L: \mathbb{R}^D \to \mathbb{R}$ is $C^4$, and there exists a $C^2$ submanifold $\Gamma$ of $\mathbb{R}^D$ that is $(D - M)$-dimensional for some integer $1 \le M \le D$, where for all $x \in \Gamma$, $x$ is a local minimizer of $L$ and $\mathrm{rank}(\nabla^2 L(x)) = M$.

Notations for Full-Batch Setting: Given any point $x \in \Gamma$, define $P_{x,\Gamma}$ as the projection operator onto the normal space of $\Gamma$ at $x$, and $P^\perp_{x,\Gamma} = I_D - P_{x,\Gamma}$. Given the loss function $L$, its gradient flow is denoted by the mapping $\phi: \mathbb{R}^D \times [0, \infty) \to \mathbb{R}^D$. Here, $\phi(x, \tau)$ denotes the iterate at time $\tau$ of the gradient flow starting at $x$ and is defined as the unique solution of $\phi(x, \tau) = x - \int_0^\tau \nabla L(\phi(x, t))\,dt$, $\forall x \in \mathbb{R}^D$. We further define the limiting map of $\phi(x, \cdot)$ as $\Phi(x) = \lim_{\tau \to \infty}\phi(x, \tau)$; that is, $\Phi(x)$ denotes the convergent point of the gradient flow starting from $x$. For convenience, we define $\lambda_i(x), v_i(x)$ as $\lambda_i(\nabla^2 L(\Phi(x))), v_i(\nabla^2 L(\Phi(x)))$ whenever the latter are well-defined. When $x(t)$ and $\Gamma$ are clear from context, we also use $\lambda_i(t) := \lambda_i(x(t))$, $v_i(t) := v_i(x(t))$, $P^\perp_{t,\Gamma} := P^\perp_{\Phi(x(t)),\Gamma}$, $P_{t,\Gamma} := P_{\Phi(x(t)),\Gamma}$.

Definition 3.3 (Attraction Set). Let $U$ be the attraction set of $\Gamma$ under gradient flow, that is, a neighborhood of $\Gamma$ containing all points from which gradient flow w.r.t. the loss $L$ converges to some point in $\Gamma$; mathematically, $U \triangleq \{x \in \mathbb{R}^D \mid \Phi(x) \text{ exists and } \Phi(x) \in \Gamma\}$.

Below we restate the setting for the stochastic loss with batch size one from Section 5.

Setting 5.1. Let the total number of data points be $M$. Let $f_k(x)$ be the model output on the $k$-th data point, where $f_k$ is a $C^4$-smooth function, and let $y_k$ be the $k$-th label, for $k = 1, \dots, M$. We define the loss on the $k$-th data point as $L_k(x) = \ell(f_k(x), y_k)$ and the total loss $L = \sum_{k=1}^M L_k/M$, where the function $\ell(y', y)$ is $C^4$-smooth in $y'$. We also assume that for any $y \in \mathbb{R}$, it holds that $\arg\min_{y' \in \mathbb{R}}\ell(y', y) = y$ and that $\frac{\partial^2 \ell(y', y)}{(\partial y')^2}\big|_{y' = y} > 0$. Finally, we denote the set of global minimizers of $L$ with full-rank Jacobian by $\Gamma$ and assume that it is non-empty, that is,
$$\Gamma \triangleq \left\{x \in \mathbb{R}^D \;\middle|\; f_k(x) = y_k,\ \forall k \in [M], \text{ and } \{\nabla f_k(x)\}_{k=1}^M \text{ are linearly independent}\right\} \neq \emptyset.$$

Theorem 5.2.
Loss $L$, set $\Gamma$ and integer $M$ defined in Setting 5.1 satisfy Assumption 3.2.

In our analysis, we prove our main theorems in the stochastic setting under a more general condition than Setting 5.1, namely Condition E.1 (on top of Assumption 3.2). The only uses of Setting 5.1 in the proofs are Theorems 5.2 and E.2.

Condition E.1. The total loss is $L = \frac{1}{M}\sum_{k=1}^M L_k$. For each $k \in [M]$, $L_k$ is $C^4$, and there exists a $(D-1)$-dimensional $C^2$-submanifold $\Gamma_k$ of $\mathbb{R}^D$ such that for all $x \in \Gamma_k$, $x$ is a global minimizer of $L_k$, $L_k(x) = 0$ and $\mathrm{rank}(\nabla^2 L_k(x)) = 1$. Moreover, $\Gamma = \bigcap_{k=1}^M \Gamma_k$ for $\Gamma$ defined in Assumption 3.2.

Theorem E.2. Setting 5.1 implies Condition E.1.

Notations for Stochastic Setting: Since $\nabla^2 L_k$ is rank-1 on $\Gamma_k$ for each $k \in [M]$, we can write $\nabla^2 L_k(x) = \Lambda_k(x)w_k(x)w_k(x)^\top$ for any $x \in \Gamma$, where $w_k$ is a continuous function on $\Gamma$ with pointwise unit norm. Given the loss function $L_k$, its gradient flow is denoted by the mapping $\phi_k: \mathbb{R}^D \times [0, \infty) \to \mathbb{R}^D$. Here, $\phi_k(x, \tau)$ denotes the iterate at time $\tau$ of the gradient flow starting at $x$ and is defined as the unique solution of $\phi_k(x, \tau) = x - \int_0^\tau \nabla L_k(\phi_k(x, t))\,dt$, $\forall x \in \mathbb{R}^D$. We further define the limiting map $\Phi_k$ as $\Phi_k(x) = \lim_{\tau \to \infty}\phi_k(x, \tau)$; that is, $\Phi_k(x)$ denotes the convergent point of the gradient flow starting from $x$. Similar to Definition 3.3, we define $U_k = \{x \in \mathbb{R}^D \mid \Phi_k(x) \text{ exists and } \Phi_k(x) \in \Gamma_k\}$ as the attraction set of $\Gamma_k$. Each $U_k$ is open and $\Phi_k$ is $C^2$ on $U_k$ by Lemma B.15 in Arora et al. (2022).

Definition E.3. A function $L$ is $\mu$-PL in a set $U$ iff $\forall x \in U$, $\|\nabla L(x)\|_2^2 \ge 2\mu\big(L(x) - \inf_{x \in U} L(x)\big)$.

Definition E.4. The spectral 2-norm of a $k$-order tensor $X_{i_1, \dots, i_k} \in \mathbb{R}^{d_1 \times \dots \times d_k}$ is defined as $\|X\|_2 = \max_{x_i \in \mathbb{R}^{d_i},\, \|x_i\|_2 = 1} X[x_1, \dots, x_k]$.

Lemma E.5 (Arora et al. (2022) Lemma B.2). Given any compact set $K \subseteq \Gamma$, there exist $r(K), \mu(K), \Delta(K) \in \mathbb{R}^+$ such that
1. $K^{r(K)} \cap \Gamma$ is compact.
2. $K^{r(K)} \subseteq U \cap \big(\bigcap_{k \in [M]} U_k\big)$.
3. $L$ is $\mu(K)$-PL on $K^{r(K)}$.
4. $\inf_{x \in K^{r(K)}}\big(\lambda_1(\nabla^2 L(x)) - \lambda_2(\nabla^2 L(x))\big) \ge \Delta(K) > 0$.
5. $\inf_{x \in K^{r(K)}}\lambda_M(\nabla^2 L(x)) \ge \mu(K) > 0$.
6. $\inf_{x \in K^{r(K)}}\lambda_1(\nabla^2 L_k(x)) \ge \mu(K) > 0$.
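For intuition about Definition E.3 and the limiting map $\Phi$, consider the simplest instance of Assumption 3.2 (our own toy example, not from the paper): a quadratic $L(x) = \frac{1}{2}x^\top H x$ with $H \succeq 0$ of rank $M$, so that $\Gamma = \ker H$. Gradient flow then has the closed form $\phi(x, \tau) = e^{-\tau H}x$, hence $\Phi$ is the orthogonal projection onto $\Gamma$ (consistent with Lemma F.5, $\partial\Phi = P^\perp_{x,\Gamma}$, and with Lemma 3.1, $\partial\Phi(x)\nabla L(x) = 0$), and $L$ is $\mu$-PL on all of $\mathbb{R}^D$ with $\mu = \lambda_M(H)$, the smallest non-zero eigenvalue, mirroring items 3 and 5 of Lemma E.5. The sketch below checks these facts numerically; the dimensions, the spectrum, and the finite time $\tau = 200$ standing in for $\tau \to \infty$ are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
D, M = 5, 2

# Hypothetical PSD Hessian H = Q diag(3, 0.5, 0, 0, 0) Q^T of rank M = 2.
Q, _ = np.linalg.qr(rng.standard_normal((D, D)))
eigs = np.array([3.0, 0.5, 0.0, 0.0, 0.0])
H = Q @ np.diag(eigs) @ Q.T

loss = lambda x: 0.5 * x @ H @ x
grad = lambda x: H @ x


def gradient_flow(x, tau):
    """Exact gradient flow phi(x, tau) = expm(-tau * H) x via the eigendecomposition."""
    return Q @ (np.exp(-tau * eigs) * (Q.T @ x))


P_tangent = Q[:, M:] @ Q[:, M:].T  # orthogonal projection onto Gamma = ker(H)
mu = eigs[M - 1]                   # lambda_M(H), the smallest non-zero eigenvalue

for _ in range(1000):
    x = rng.standard_normal(D)
    # mu-PL inequality of Definition E.3 (here inf L = 0).
    assert grad(x) @ grad(x) >= 2 * mu * loss(x) - 1e-9
    # The limit of gradient flow is the projection onto Gamma ...
    assert np.allclose(gradient_flow(x, 200.0), P_tangent @ x)
    # ... and its Jacobian annihilates the gradient.
    assert np.allclose(P_tangent @ grad(x), 0.0)
print("PL constant:", mu)
```

In this linear case $\Phi$ is exactly linear, so all higher derivatives of $\Phi$ vanish; the constants $\xi$ and $\chi$ defined next are non-trivial only for curved $\Gamma$.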
Given a compact set $K \subseteq \Gamma$, we further define
$$\zeta(K) = \sup_{x \in K^{r(K)}}\|\nabla^2 L(x)\|_2,\quad \nu(K) = \sup_{x \in K^{r(K)}}\|\nabla^3 L(x)\|_2,\quad \Upsilon(K) = \sup_{x \in K^{r(K)}}\|\nabla^4 L(x)\|_2,$$
$$\xi(K) = \sup_{x \in K^{r(K)}}\|\nabla^2\Phi(x)\|_2,\quad \chi(K) = \sup_{x, y \in K^{r(K)}}\|\nabla^2\Phi(x) - \nabla^2\Phi(y)\|_2.$$
Similarly, we use notation like $\zeta_k(K), \nu_k(K), \Upsilon_k(K), \xi_k(K), \chi_k(K)$ to denote the counterparts of the above quantities defined for the stochastic loss $L_k$ and its limiting map $\Phi_k$, for $k \in [M]$.

Lemma E.6 (Arora et al. (2022), Lemmas B.5 and B.7). Given any compact subset $K \subseteq \Gamma$, let $r(K)$ be defined as in Lemma E.5. There exists $0 < h(K) < r(K)$ such that
1. $\sup_{x \in K^{h(K)}} L(x) - \inf_{x \in K^{h(K)}} L(x) \le \mu(K)\rho^2(K)$.
2. $\forall x \in K^{h(K)}$, $\Phi(x) \in K^{r(K)/2}$.
3. $\forall x \in K^{h(K)}$, $\|x - \Phi(x)\|_2 \le \ldots/(8\mu(K)^2)$.
4. The whole segment $\overline{x\Phi(x)}$ lies in $K^{r(K)}$, and so does $\overline{x\Phi_k(x)}$ for any $k \in [M]$.

The proofs of the lemmas above can be found in Arora et al. (2022). Readers should note that although Arora et al. (2022) only prove these lemmas when $K$ is a special compact set (the trajectory of an ODE), their proofs do not use any property of $K$ other than its being a compact subset of $\Gamma$, and thus our Lemmas E.5 and E.6 hold for general compact subsets of $\Gamma$. In the rest of the appendix, for convenience we drop the dependency on $K$ in various constants when there is no ambiguity.

E.1 PROOFS OF THEOREMS 5.2 AND E.2

Proof of Theorem 5.2. Define $F: \mathbb{R}^D \to \mathbb{R}^M$ as $[F(x)]_k = f_k(x)$, $\forall k \in [M]$. Let $T_x \triangleq \mathrm{span}(\{\nabla f_k(x)\}_{k=1}^M)$ and let $T^\perp_x$ be the orthogonal complement of $T_x$ in $\mathbb{R}^D$. Now we apply the implicit function theorem to $F$ at each $x \in \Gamma$. Without loss of generality (e.g., by rotating the coordinate system), we can assume that $x = 0$, $T^\perp_x = \mathbb{R}^{D-M} \times \{0\}$, and $T_x = \{0\} \times \mathbb{R}^M$. The implicit function theorem ensures that there are two open sets $0 \in U \subseteq \mathbb{R}^{D-M}$ and $0 \in V \subseteq \mathbb{R}^M$ and a $C^4$ map $g: U \to V$ such that $F^{-1}(Y) \cap (U \times V) = \{(u, g(u)) \mid u \in U\}$, where $Y \triangleq [y_1, \dots, y_M]^\top \in \mathbb{R}^M$. Moreover, $\{\nabla f_k(x)\}_{k=1}^M$ is linearly independent for every $x \in U \times V$. Thus by the definition of $\Gamma$, it holds that $\Gamma \cap (U \times V) = F^{-1}(Y) \cap (U \times V) = \{(u, g(u)) \mid u \in U\}$.
Now for $x = (u, v) \in U \times V$, we define $\psi: U \times V \to \mathbb{R}^D$ by $\psi(u, v) \triangleq (u, v - g(u))$. We can check that $\psi$ is $C^4$ and
$$\psi(\Gamma \cap (U \times V)) = \{(u, v - g(u)) \mid v = g(u), u \in U\} = \{(u, 0) \mid u \in U\} = U \times \{0\} = (\mathbb{R}^{D-M} \times \{0\}) \cap \psi(U \times V).$$
This proves that $\Gamma$ is a $C^4$ submanifold of $\mathbb{R}^D$ of dimension $D - M$ (c.f. Definition C.1). Since $\arg\min_{y' \in \mathbb{R}}\ell(y', y) = y$ for any $y \in \mathbb{R}$, it is clear that every $x \in \Gamma$ is a global minimizer of $L$. Finally, we check the rank of the Hessian of the loss $L$. Note that for any $x \in \Gamma$, $\nabla^2 L_k(x) = \frac{\partial^2\ell(y', y_k)}{(\partial y')^2}\big|_{y' = y_k}\nabla f_k(x)(\nabla f_k(x))^\top$ and that $\frac{\partial^2\ell(y', y_k)}{(\partial y')^2}\big|_{y' = y_k} > 0$; hence $\mathrm{rank}(\nabla^2 L(x)) = \mathrm{rank}(\partial F(x)) = M$. This completes the proof.

Proof of Theorem E.2.
1. $L = \frac{1}{M}\sum_{k=1}^M L_k$ by definition.
2. For all $k \in [M]$, $L_k(x) = \ell(f_k(x), y_k)$ is $C^4$, as $\ell$ and $f_k$ are both $C^4$.
3. For any $x \in \Gamma$, by Lemma 5.5, we have $\nabla f_k(x) \neq 0$. Then there exists an open neighborhood $V_k$ such that $\Gamma \subseteq V_k$ and $\nabla f_k(x) \neq 0$ for any $x \in V_k$, $k \in [M]$. Then, applying the implicit function theorem as in the proof of Theorem 5.2, for any $k \in [M]$ there exists a $(D-1)$-dimensional $C^4$-manifold $\Gamma'_k \subseteq V_k$ such that for any $x' \in V_k$, $f_k(x') = y_k$ if and only if $x' \in \Gamma'_k$. As $f_k(x') = y_k$ for any $x' \in \Gamma \cap V_k$, we can infer that $\Gamma \subseteq \Gamma'_k$. Setting $\Gamma_k \triangleq \Gamma'_k$, we obtain $\Gamma \subseteq \bigcap_{k=1}^M \Gamma_k$.
4. For any $x \in \Gamma_k$, we have $f_k(x) = y_k$, which implies $L_k(x) = 0$. Also, as $x \in V_k$, $\nabla f_k(x) \neq 0$. By Lemma J.15, we have $\mathrm{rank}(\nabla^2 L_k(x)) = 1$.

F PROPERTIES OF LIMITING MAP OF GRADIENT FLOW, Φ

In our analysis, the properties of $\Phi$ will be heavily used. In this section, we recap some related lemmas from Arora et al. (2022) and then introduce some new lemmas for the stochastic setting with batch size one.

Lemma F.1 (Arora et al. (2022) Lemma B.6). Given any compact set $K \subseteq \Gamma$, for any $x \in K^h$,
$$\|x - \Phi(x)\|_2 \le \sqrt{\frac{2(L(x) - L(\Phi(x)))}{\mu}}.$$

Lemma F.2. Given any compact set $K \subseteq \Gamma$, for any $x \in K^h$,
$$\|\nabla L(x)\|_2 \le \zeta\|x - \Phi(x)\|_2 \le \zeta\sqrt{\frac{2(L(x) - L(\Phi(x)))}{\mu}}.$$

Proof of Lemma F.2. The first inequality follows from Lemma E.5 and Taylor expansion. The second inequality follows from Lemma F.1.

Lemma F.3 (Arora et al. (2022) Lemmas B.16 and B.22).
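$$\partial\Phi(x)\nabla L(x) = 0,\quad \forall x \in U;$$
$$\partial\Phi(x)\nabla^2 L(x)\nabla L(x) = -\partial^2\Phi(x)[\nabla L(x), \nabla L(x)],\quad \forall x \in U;$$
$$\partial\Phi(x)\,\partial^2(\nabla L)(x)[v_1, v_1] = P^\perp_{x,\Gamma}\big(\nabla\lambda_1(\nabla^2 L(x))\big),\quad \forall x \in \Gamma.$$

Lemma F.4 (Arora et al. (2022) Lemmas B.8 and B.9). Given any compact set $K \subseteq \Gamma$, for any $x \in K^h$,
$$\|P^\perp_{\Phi(x),\Gamma}(x - \Phi(x))\|_2 \le \frac{\zeta\nu}{4\mu^2}\|x - \Phi(x)\|_2^2;$$
$$\|\nabla L(x) - \nabla^2 L(\Phi(x))(x - \Phi(x))\|_2 \le \frac{\nu}{2}\|x - \Phi(x)\|_2^2;$$
$$\frac{\|\nabla L(x)\|_2}{\|\nabla^2 L(\Phi(x))(x - \Phi(x))\|_2} \ge 1 - \frac{2\nu}{\mu}\|x - \Phi(x)\|_2;$$
$$\frac{\nabla L(x)}{\|\nabla L(x)\|_2} = \frac{\nabla^2 L(\Phi(x))(x - \Phi(x))}{\|\nabla^2 L(\Phi(x))(x - \Phi(x))\|_2} + O\Big(\frac{\nu}{\mu}\|x - \Phi(x)\|_2\Big).$$

Lemma F.5 (Li et al. (2021), Lemma 4.3). For $x \in \Gamma$, $\partial\Phi(x) = P^\perp_{x,\Gamma}$, the orthogonal projection matrix onto the tangent space of $\Gamma$ at $x$. In particular, $\partial\Phi(x)\nabla^2 L(x) = 0$ for $x \in \Gamma$.

The proofs of the above lemmas can be found in Arora et al. (2022); Li et al. (2021). In the following, we first show the proof of Lemma 3.1.

Proof of Lemma 3.1. Since $\Phi$ is defined as the limit map of gradient flow, it holds that $\Phi(\phi(x, t)) = \Phi(x)$ for any $t \ge 0$. Differentiating both sides at $t = 0$, we have $\partial\Phi(\phi(x, 0))\frac{\partial\phi(x, t)}{\partial t}\big|_{t=0} = 0$. The proof is completed by noting that $\frac{\partial\phi(x, t)}{\partial t} = -\nabla L(\phi(x, t))$ by the definition of $\phi$.

Lemma F.6. Given any compact set $K \subseteq \Gamma$, for any $x \in K^h$,
$$\|\partial\Phi(x)\nabla L_k(x)\|_2 \le (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2^2,\qquad \left\|\partial\Phi(x)\nabla^2 L_k(x)\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 \le (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2.$$

Proof of Lemma F.6. By Lemma E.6 and Taylor expansion (using $\nabla L_k(\Phi(x)) = 0$),
$$\begin{aligned} \|\partial\Phi(x)\nabla L_k(x)\|_2 &\le \|\partial\Phi(x)\nabla^2 L_k(\Phi(x))(x - \Phi(x))\|_2 + \nu_k\|x - \Phi(x)\|_2^2 \\ &\le \|\partial\Phi(\Phi(x))\nabla^2 L_k(\Phi(x))(x - \Phi(x))\|_2 + \nu_k\|x - \Phi(x)\|_2^2 + \zeta_k\xi\|x - \Phi(x)\|_2^2 \\ &= \|P^\perp_{\Phi(x),\Gamma}\nabla^2 L_k(\Phi(x))(x - \Phi(x))\|_2 + (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2^2 \\ &= (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2^2, \end{aligned}$$
where the last step holds because the range of $\nabla^2 L_k(\Phi(x))$ is orthogonal to the tangent space of $\Gamma \subseteq \Gamma_k$. This proves the first claim. Again by Lemma E.5 and Taylor expansion,
$$\begin{aligned} \left\|\partial\Phi(x)\nabla^2 L_k(x)\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 &\le \left\|\partial\Phi(x)\nabla^2 L_k(\Phi(x))\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 + \nu_k\|x - \Phi(x)\|_2 \\ &\le \left\|\partial\Phi(\Phi(x))\nabla^2 L_k(\Phi(x))\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 + (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2 \\ &= (\nu_k + \zeta_k\xi)\|x - \Phi(x)\|_2, \end{aligned}$$
which proves the second claim.

Lemma F.7. Suppose $x \in K^h$ and $y = x - \eta\nabla L\big(x + \rho\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\big)$. Then
$$\|y - x\|_2 \le \eta\|\nabla L(x)\|_2 + \eta\zeta\rho,$$
$$\begin{aligned} \|\Phi(x) - \Phi(y)\|_2 &\le \xi\eta\rho\|\nabla L(x)\|_2 + \nu\eta\rho^2 + \xi\eta^2\|\nabla L(x)\|_2^2 + \xi\zeta^2\eta^2\rho^2 \\ &\le \zeta\xi\eta\rho\|x - \Phi(x)\|_2 + \zeta^2\xi\eta^2\|x - \Phi(x)\|_2^2 + \nu\eta\rho^2 + \xi\zeta^2\eta^2\rho^2. \end{aligned}$$

Proof of Lemma F.7. For sufficiently small $\rho$, $x + \rho\frac{\nabla L(x)}{\|\nabla L(x)\|_2} \in K^r$.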
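By Taylor expansion,
$$\|y - x\|_2 = \eta\left\|\nabla L\Big(x + \rho\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\Big)\right\|_2 \le \eta\|\nabla L(x)\|_2 + \eta\zeta\rho.$$
This further implies that for sufficiently small $\eta$ and $\rho$, $\overline{xy} \subseteq K^r$. Again by Taylor expansion,
$$\|\partial\Phi(x)(y - x)\|_2 \le \eta\left\|\partial\Phi(x)\nabla L(x) + \rho\,\partial\Phi(x)\nabla^2 L(x)\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\right\|_2 + \eta\rho^2\nu/2.$$
By Lemma F.3, $\partial\Phi(x)\nabla L(x) = 0$ and $\partial\Phi(x)\nabla^2 L(x)\nabla L(x) = -\partial^2\Phi(x)[\nabla L(x), \nabla L(x)]$. Hence,
$$\|\partial\Phi(x)(y - x)\|_2 \le \eta\rho\|\nabla L(x)\|_2\left\|\partial^2\Phi(x)\Big[\frac{\nabla L(x)}{\|\nabla L(x)\|_2}, \frac{\nabla L(x)}{\|\nabla L(x)\|_2}\Big]\right\|_2 + \eta\rho^2\nu/2 \le \xi\eta\rho\|\nabla L(x)\|_2 + \eta\rho^2\nu/2.$$
As $\overline{xy} \subseteq K^r$, by Taylor expansion,
$$\|\Phi(y) - \Phi(x)\|_2 \le \|\partial\Phi(x)(y - x)\|_2 + \xi\|y - x\|_2^2/2.$$
Putting these together, we have
$$\|\Phi(x) - \Phi(y)\|_2 \le \xi\eta\rho\|\nabla L(x)\|_2 + \eta\rho^2\nu + \xi\eta^2\|\nabla L(x)\|_2^2 + \xi\zeta^2\eta^2\rho^2.$$
Finally, by Lemma F.2, we have
$$\|\Phi(x) - \Phi(y)\|_2 \le \zeta\xi\eta\rho\|x - \Phi(x)\|_2 + \zeta^2\xi\eta^2\|x - \Phi(x)\|_2^2 + \nu\eta\rho^2 + \xi\zeta^2\eta^2\rho^2.$$
This completes the proof.

Lemma F.8. Suppose $x \in K^h$ and $y = x - \eta\nabla L_k\big(x + \rho\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\big)$. Then
$$\|y - x\|_2 \le \eta\|\nabla L_k(x)\|_2 + \eta\zeta\rho,\qquad \|\Phi(x) - \Phi(y)\|_2 \le O\big(\eta\|\nabla L(x)\|_2^2 + \eta\rho\|\nabla L(x)\|_2 + \eta\rho^2\big).$$

Proof of Lemma F.8. For sufficiently small $\rho$, $x + \rho\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2} \in K^r$. By Taylor expansion,
$$\|y - x\|_2 = \eta\left\|\nabla L_k\Big(x + \rho\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\Big)\right\|_2 \le \eta\|\nabla L_k(x)\|_2 + \eta\zeta\rho.$$
This further implies that for sufficiently small $\eta$ and $\rho$, $\overline{xy} \subseteq K^r$. Again by Taylor expansion,
$$\|\partial\Phi(x)(y - x)\|_2 \le \eta\left\|\partial\Phi(x)\nabla L_k(x) + \rho\,\partial\Phi(x)\nabla^2 L_k(x)\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 + \eta\rho^2\nu/2.$$
We further have, by Lemma F.1 (which gives $\|x - \Phi(x)\|_2 \le \|\nabla L(x)\|_2/\mu$),
$$\begin{aligned} \|\partial\Phi(x)\nabla L_k(x)\|_2 &\le \|\partial\Phi(\Phi(x))\nabla L_k(x)\|_2 + \xi\|\nabla L_k(x)\|_2\|x - \Phi(x)\|_2 \\ &\le \|\partial\Phi(\Phi(x))\nabla^2 L_k(\Phi(x))(x - \Phi(x))\|_2 + \nu\|x - \Phi(x)\|_2^2 + \zeta\xi\|x - \Phi(x)\|_2^2 \\ &\le \frac{\nu}{\mu^2}\|\nabla L(x)\|_2^2 + \frac{\zeta\xi}{\mu^2}\|\nabla L(x)\|_2^2, \end{aligned}$$
$$\begin{aligned} \rho\left\|\partial\Phi(x)\nabla^2 L_k(x)\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 &\le \rho\left\|\partial\Phi(\Phi(x))\nabla^2 L_k(x)\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 + \rho\zeta\xi\|x - \Phi(x)\|_2 \\ &\le \rho\left\|\partial\Phi(\Phi(x))\nabla^2 L_k(\Phi(x))\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|_2}\right\|_2 + \rho\zeta\xi\|x - \Phi(x)\|_2 + \rho\nu\|x - \Phi(x)\|_2 \\ &\le \frac{\rho\zeta\xi}{\mu}\|\nabla L(x)\|_2 + \frac{\rho\nu}{\mu}\|\nabla L(x)\|_2. \end{aligned}$$
Combining these bounds with $\|\Phi(y) - \Phi(x)\|_2 \le \|\partial\Phi(x)(y - x)\|_2 + \xi\|y - x\|_2^2/2$, as in the proof of Lemma F.7, yields the claim. This completes the proof.

G ANALYSIS FOR EXPLICIT BIAS

Throughout this section, we assume that Assumption 3.2 holds.

G.1 A GENERAL THEOREM FOR EXPLICIT BIAS IN THE LIMIT CASE

In this subsection we provide the proof details for Section 4.1, which shows that the explicit biases of the three notions of sharpness are all different, using our new mathematical tool, Theorem G.6.

Notation for Regularizers. Let $R_\rho: \mathbb{R}^D \to \mathbb{R} \cup \{\infty\}$ be a family of regularizers parameterized by $\rho$.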
If $R_\rho$ is not well-defined at some $x$, then we let $R_\rho(x) = \infty$. This convention will be useful when analyzing the ascent-direction sharpness $R^{\mathrm{Asc}}_\rho = L^{\mathrm{Asc}}_\rho - L$, which is not defined when $\nabla L(x) = 0$. This convention does not change the minimizers of the regularized loss. Intuitively, a regularizer should always be non-negative; however, far away from the manifold, some regularizers $R_\rho(x)$ of interest can actually be negative, e.g., $R^{\mathrm{Avg}}_\rho(x) \approx \frac{\rho^2}{2D}\mathrm{Tr}(\nabla^2 L(x))$. Therefore we make the following assumption to allow the regularizer to be mildly negative.

Condition G.1. Suppose for any bounded closed set $B \subseteq U$, there exists $C > 0$ such that for sufficiently small $\rho$, $\forall x \in B$, $R_\rho(x) \ge -C\rho^2$.

Definition 4.3 (Limiting Regularizer). We define the limiting regularizer of $\{R_\rho\}$ as the function$^6$ $S: \Gamma \to \mathbb{R}$,
$$S(x) = \lim_{\rho \to 0}\lim_{r \to 0}\inf_{\|x' - x\|_2 \le r} R_\rho(x')/\rho^2.$$
The high-level intuition is that we want to use the notion of limiting regularizer to capture the explicit bias of $R_\rho$ among the manifold of minimizers $\Gamma$ as $\rho \to 0$, which is decided by the second-order term in the Taylor expansion, e.g., Equation 5 and Equation 6. In other words, the hope is that whenever the regularized loss is optimized, the final solution should lie in a neighborhood of the minimizer $x$ with the smallest value of the limiting regularizer $S(x)$. However, this hope cannot be realized without further assumptions, which motivates the following definition of a good limiting regularizer.

Definition G.2 (Good Limiting Regularizer). We say the limiting regularizer $S$ of $\{R_\rho\}$ is good around some $x^* \in \Gamma$ if $S$ is non-negative and continuous at $x^*$ and there is an open set $V_{x^*}$ containing $x^*$ such that for any $C > 0$, $\inf_{x': \|x' - x\|_2 \le C\rho} R_\rho(x')/\rho^2$ converges uniformly to $S(x)$ for all $x \in \Gamma \cap V_{x^*}$ as $\rho \to 0$. In other words, a good limiting regularizer satisfies that for any $C, \epsilon > 0$, there is some $\rho_{x^*} > 0$ such that for all $x \in \Gamma \cap V_{x^*}$ and $\rho \le \rho_{x^*}$,
$$S(x) - \inf_{\|x' - x\|_2 \le C\rho} R_\rho(x')/\rho^2 < \epsilon.$$
We say the limiting regularizer $S$ is good on $\Gamma$ if $S$ is good around every point $x \in \Gamma$. In such a case we also say $R_\rho$ admits $S$ as a good limiting regularizer on $\Gamma$.

The intuition behind the concept of a good limiting regularizer is that the value of the regularizer should not drop too fast when moving away from a minimizer $x$ within its $O(\rho)$ neighborhood. Otherwise, the minimizer of the regularized loss may be $\Omega(\rho)$ away from any minimizer, reducing the regularizer at the cost of increasing the original loss, which makes the limiting regularizer unable to capture the explicit bias of the regularizer (see Appendix G.2 for a counterexample). We emphasize that the conditions of a good limiting regularizer are natural and cover a large family of regularizers, including the worst-, ascent- and average-direction sharpness; see Theorems G.3 to G.5 below.

Theorem G.3. The worst-direction sharpness $R^{\mathrm{Max}}_\rho$ admits $\lambda_1(\nabla^2 L(\cdot))/2$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Theorem G.4. The ascent-direction sharpness $R^{\mathrm{Asc}}_\rho$ admits $\lambda_M(\nabla^2 L(\cdot))/2$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Theorem G.5. The average-direction sharpness $R^{\mathrm{Avg}}_\rho$ admits $\mathrm{Tr}(\nabla^2 L(\cdot))/(2D)$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Next we present the main mathematical tool for analyzing the explicit bias of regularizers admitting good limiting regularizers, Theorem G.6.

Theorem G.6. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma = \overline{U' \cap \Gamma}$. Then for any family of parametrized regularizers $\{R_\rho\}$ admitting a good limiting regularizer $S(x)$ on $\Gamma$ and satisfying Condition G.1, for sufficiently small $\rho$, it holds that
$$\left|\inf_{x \in U'}\big(L(x) + R_\rho(x)\big) - \inf_{x \in U'} L(x) - \rho^2\inf_{x \in U' \cap \Gamma} S(x)\right| \le o(\rho^2).$$
Moreover, for sufficiently small $\rho$, it holds uniformly for all $u \in U'$ that
$$L(u) + R_\rho(u) \le \inf_{x \in U'}\big(L(x) + R_\rho(x)\big) + O(\rho^2) \implies R_\rho(u)/\rho^2 - \inf_{x \in U' \cap \Gamma} S(x) \ge -o(1).$$

Footnote 6: Here we implicitly assume that the zeroth- and first-order terms vanish, which holds for all three sharpness notions.
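As a numerical sanity check on Theorems G.3 to G.5, the sketch below (our own illustration, not the authors' code) evaluates the three sharpness notions for the quadratic loss $L(x) = \frac{1}{2}x^\top H x$ with the hypothetical Hessian $H = \mathrm{diag}(4, 1, 0)$, so $D = 3$, $M = 2$ and $\Gamma = \mathrm{span}(e_3)$. The radius, sample sizes and the helper `r_asc` are arbitrary choices, and for the average direction we assume the perturbation direction is uniform on the unit sphere, which gives the $\mathrm{Tr}/(2D)$ scaling. The worst- and average-direction values are computed at a point of $\Gamma$; since $R^{\mathrm{Asc}}_\rho$ is undefined on $\Gamma$ (set to $+\infty$ by the convention above), the ascent value is approached from nearby off-manifold points, mimicking the infimum in Definition 4.3.

```python
import numpy as np

rng = np.random.default_rng(0)
H = np.diag([4.0, 1.0, 0.0])  # hypothetical Hessian: D = 3, M = 2, Gamma = span(e3)
D = H.shape[0]
rho = 1e-2


def loss(x):
    return 0.5 * x @ H @ x


def r_asc(x):
    """Ascent-direction sharpness L(x + rho * g/||g||) - L(x); +inf where g = 0."""
    g = H @ x
    n = np.linalg.norm(g)
    return np.inf if n == 0.0 else loss(x + rho * g / n) - loss(x)


def unit_vectors(n):
    v = rng.standard_normal((n, D))
    return v / np.linalg.norm(v, axis=1, keepdims=True)


x_star = np.array([0.0, 0.0, 0.7])       # a minimizer on Gamma
P = x_star + rho * unit_vectors(200_000)  # perturbed points x* + rho * v
R = 0.5 * np.einsum("ij,jk,ik->i", P, H, P) - loss(x_star)

worst = R.max() / rho**2    # approx lambda_1 / 2 = 2.0
average = R.mean() / rho**2  # approx Tr(H) / (2D) = 5/6
near = x_star + 1e-8 * unit_vectors(5_000)
ascent = min(r_asc(x) for x in near) / rho**2  # approx lambda_M / 2 = 0.5
print(worst, average, ascent, r_asc(x_star))
```

The three numbers differ exactly as the theorems predict: the worst direction sees the top eigenvalue, the average direction sees the trace, and the ascent direction, near the manifold, can align with the flattest non-zero curvature direction.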
Theorem G.6 says that minimizing the regularized loss $L(u) + R_\rho(u)$ is not very different from minimizing the original loss $L(u)$ and the regularizer $R_\rho(u)$ separately. To see this, we define the following optimality gaps:
$$A(u) \triangleq L(u) + R_\rho(u) - \inf_{x \in U'}\big(L(x) + R_\rho(x)\big) \ge 0,\qquad B(u) \triangleq L(u) - \inf_{x \in U'} L(x) \ge 0,\qquad C(u) \triangleq R_\rho(u)/\rho^2 - \inf_{x \in U' \cap \Gamma} S(x),$$
and Theorem G.6 implies that $A(u) - B(u) - \rho^2 C(u) = o(\rho^2)$. Moreover, $A(u)$ and $B(u)$ are non-negative by definition, and $C(u) \ge -o(1)$ is almost non-negative whenever $A(u)$ is $O(\rho^2)$-approximately optimized.

For the applications we are interested in in this paper, the good limiting regularizer $S$ can be continuously extended to the entire space $\mathbb{R}^D$. In such a case, the third optimality gap has an approximate alternative form which does not involve $R_\rho$, namely $S(u) - \inf_{x \in U' \cap \Gamma} S(x)$. Corollary G.7 shows that minimizing the regularized loss $L(u) + R_\rho(u)$ is equivalent to minimizing the limiting regularizer $S(u)$ around the manifold of local minimizers $\Gamma$.

Corollary G.7. Under the setting of Theorem G.6, let $\bar{S}$ be a continuous extension of $S$ to $\mathbb{R}^D$. For any optimality gap $\Delta > 0$, there is a function $\epsilon: \mathbb{R}^+ \to \mathbb{R}^+$ with $\lim_{\rho \to 0}\epsilon(\rho) = 0$, such that for all sufficiently small $\rho > 0$ and all $u \in U'$ satisfying $L(u) + R_\rho(u) - \inf_{x \in U'}\big(L(x) + R_\rho(x)\big) \le \Delta\rho^2$, it holds that $L(u) - \inf_{x \in U'} L(x) \le (\Delta + \epsilon(\rho))\rho^2$ and that $\bar{S}(u) - \inf_{x \in U' \cap \Gamma} S(x) \in [-\epsilon(\rho), \Delta + \epsilon(\rho)]$.

G.2 BAD LIMITING REGULARIZERS MAY NOT CAPTURE EXPLICIT BIAS

In this subsection, we provide an example where a bad limiting regularizer cannot capture the explicit bias of the regularizer as $\rho \to 0$, to justify the necessity of Definition G.2. Here a bad limiting regularizer is a limiting regularizer that is not good. Consider choosing $R_\rho(x) = L(x + \rho e) - L(x)$, where $e$ with $\|e\|_2 = 1$ is a fixed unit vector. We will show that minimizing the regularized loss $L(x) + R_\rho(x)$ does not imply minimizing the limiting regularizer of $R_\rho(x)$ on the manifold.
By Definition 4.3 and the continuity of $R_\rho$, the limiting regularizer $S$ of $R_\rho$ satisfies, for all $x \in \Gamma$,
$$S(x) = \lim_{\rho \to 0}\lim_{r \to 0}\inf_{\|x' - x\|_2 \le r} R_\rho(x')/\rho^2 = \lim_{\rho \to 0} R_\rho(x)/\rho^2 = \nabla^2 L(x)[e, e]/2 \ge 0.$$
However, for any $x \in \Gamma$, we can choose $x' = x - \rho e$; then $L(x') + R_\rho(x') = L(x' + \rho e) = L(x)$. Therefore, no matter how small $\rho$ is, minimizing $L(x) + R_\rho(x)$ can return a solution that is $\rho$-close to any point of $\Gamma$. In other words, the explicit bias of minimizing $L(x) + R_\rho(x)$ is trivial, and thus it is not equivalent to minimizing the limiting regularizer $S$ on the manifold $\Gamma$.

The reason the limiting regularizer $S$ fails to explain the explicit bias of $R_\rho$ is that $S$ is not good around any $x \in \Gamma$ satisfying $S(x) > 0$. To be more concrete, choose $C = 1$ and $\epsilon = S(x)/2$ in Definition G.2. For any $x \in \Gamma$ and sufficiently small $\rho > 0$, consider $x' = x - \rho e$. By Taylor expansion at $x$ (using $\nabla L(x) = 0$),
$$R_\rho(x') = L(x) - L(x - \rho e) = -\frac{\rho^2}{2}\nabla^2 L(x)[e, e] + o(\rho^2) = -\rho^2 S(x) + o(\rho^2).$$
This implies $\inf_{\|x'' - x\|_2 \le C\rho} R_\rho(x'') \le R_\rho(x') = -\rho^2 S(x) + o(\rho^2)$. Hence,
$$S(x) - \inf_{\|x'' - x\|_2 \le C\rho} R_\rho(x'')/\rho^2 \ge 2S(x) - o(1) > S(x)/2 = \epsilon.$$

G.3 PROOF OF THEOREM G.6

This subsection aims to prove Theorem G.6. We start with a few lemmas that will be used later.

Lemma G.8. $\overline{\Gamma} \cap U = \Gamma$.

Proof of Lemma G.8. For any point $x \in U \cap \overline{\Gamma}$, there exists $\{x_k\}_{k=1}^\infty \subseteq \Gamma$ such that $\lim_{k \to \infty} x_k = x$. Since $x \in U$ and $\Phi$ is continuous in $U$, $\Phi$ is continuous at $x$; thus $\lim_{k \to \infty}\Phi(x_k) = \Phi(x) \in \Gamma$. However, $\Phi(x_k) = x_k$ because $x_k \in \Gamma$ for all $k$. Thus $x = \Phi(x) \in \Gamma$. Hence $U \cap \overline{\Gamma} \subseteq \Gamma$. The other direction is clear because $\Gamma \subseteq U$ and $\Gamma \subseteq \overline{\Gamma}$.

Lemma G.9. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$. If $\overline{U'} \cap \Gamma \subseteq \overline{U' \cap \Gamma}$, then $\overline{U' \cap \Gamma} = \overline{U'} \cap \Gamma$.

Proof of Lemma G.9. By Lemma G.8, it holds that $\overline{U'} \cap \overline{\Gamma} = \overline{U'} \cap U \cap \overline{\Gamma} = \overline{U'} \cap \Gamma$. Note that $\overline{U' \cap \Gamma} \subseteq \overline{U'}$ and $\overline{U' \cap \Gamma} \subseteq \overline{\Gamma}$, so $\overline{U' \cap \Gamma} \subseteq \overline{U'} \cap \overline{\Gamma} = \overline{U'} \cap \Gamma$. Combined with the assumption, this completes the proof.

Lemma G.10. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma \subseteq \overline{U' \cap \Gamma}$.
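Then for all $h_2 > 0$, there exists $\rho_0 > 0$ such that for any $x \in U'$, $\mathrm{dist}(x, \Gamma) \le \rho_0 \implies \mathrm{dist}(x, U' \cap \Gamma) \le h_2$.

Proof of Lemma G.10. We prove this by contradiction. Suppose there exist $h_2 > 0$ and $\{x_k\}_{k=1}^\infty \subseteq U'$ such that $\lim_{k \to \infty}\mathrm{dist}(x_k, \Gamma) = 0$ but $\mathrm{dist}(x_k, U' \cap \Gamma) \ge h_2$ for all $k > 0$. Since $U'$ is bounded, $\overline{U'}$ is compact, and thus $\{x_k\}_{k=1}^\infty$ has at least one accumulation point $x^*$ in $\overline{U'} \subseteq U$. Since $U$ is the attraction set of $\Gamma$ under gradient flow, we know that $\Phi(x^*) \in \Gamma$. Now we claim $x^* \in \Gamma$. This is because $\lim_{k \to \infty}\mathrm{dist}(x_k, \Gamma) = 0$, and thus there exists a sequence of points on $\Gamma$, $\{y_k\}_{k=1}^\infty$, with $\lim_{k \to \infty}\|x_k - y_k\|_2 = 0$. Thus, along a subsequence converging to $x^*$, we have $x^* = \lim_{k \to \infty} y_k = \lim_{k \to \infty}\Phi(y_k) = \Phi(x^*)$, where in the last step we used that $x^* \in U$ and $\Phi$ is continuous on $U$. By the definition of $U$, $x^* \in U$ implies $\Phi(x^*) \in \Gamma$; thus $x^* \in \Gamma$. Then we would have $x^* \in \overline{U'} \cap \Gamma \subseteq \overline{U' \cap \Gamma}$, which contradicts $\mathrm{dist}(x_k, \overline{U' \cap \Gamma}) = \mathrm{dist}(x_k, U' \cap \Gamma) \ge h_2$ for all $k > 0$. This completes the proof.

Lemma G.11. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma \subseteq \overline{U' \cap \Gamma}$. Then for all $h_2 > 0$, there exists $\rho_1 > 0$ such that for any $x \in U'$, $L(x) \le \inf_{x \in U'} L(x) + \rho_1 \implies \mathrm{dist}(x, U' \cap \Gamma) \le h_2$.

Proof of Lemma G.11. We prove this by contradiction. Suppose there exist a sequence $\rho_1, \ldots, \rho_k, \ldots$ with $\rho_k \to 0$ and points $x_k \in U'$ such that $L(x_k) \le \inf_{x \in U'} L(x) + \rho_k$ and $\mathrm{dist}(x_k, U' \cap \Gamma) \ge h_2$. Since $U'$ is bounded, $\overline{U'}$ is compact, and thus $\{x_k\}_{k=1}^\infty$ has at least one accumulation point $x^*$ in $\overline{U'} \subseteq U$. Since $L$ is continuous in $U$, $L(x^*) = \lim_{k \to \infty} L(x_k) = \inf_{x \in U'} L(x)$. Thus $x^*$ is a local minimizer of $L$ and hence has zero gradient, which further implies that $x^* = \Phi(x^*)$. Thus $x^* \in \overline{U'} \cap \Gamma \subseteq \overline{U' \cap \Gamma}$, which contradicts $\mathrm{dist}(x_k, \overline{U' \cap \Gamma}) = \mathrm{dist}(x_k, U' \cap \Gamma) \ge h_2$ for all $k > 0$. This completes the proof.

Lemma G.12. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma = \overline{U' \cap \Gamma}$. Suppose the regularizers $\{R_\rho\}$ admit a limiting regularizer $S$ on $\Gamma$. Then
$$\inf_{x \in U'}\big(L(x) + R_\rho(x)\big) \le \rho^2\inf_{x \in U' \cap \Gamma} S(x) + \inf_{x \in U'} L(x) + o(\rho^2).$$

Proof of Lemma G.12. First choose a sufficiently small $\rho$ such that $\rho < h(\overline{U' \cap \Gamma})$. Choose an approximate minimizer of $S(x)$, $x_0 \in U' \cap \Gamma$, such that $S(x_0) \le \inf_{x \in U' \cap \Gamma} S(x) + \rho^2$.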
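Then by the definition of limiting regularizers (Definition 4.3) and the assumption that $U'$ is open, there exists $x_1 \in U'$ satisfying $\|x_1 - x_0\|_2 \le r_\rho < \rho^2$ and $R_\rho(x_1)/\rho^2 \le S(x_0) + \rho^2$. Thus $R_\rho(x_1) \le \rho^2 S(x_0) + \rho^4$. As $\|x_1 - x_0\|_2 \le \rho^2 < h$ and $x_0 \in U' \cap \Gamma$, the segment $\overline{x_0x_1}$ lies in $(U' \cap \Gamma)^h$. By Taylor expansion of $L$ at $x_0$, we have $L(x_1) \le L(x_0) + O(\|x_0 - x_1\|_2^2) = \inf_{x \in U' \cap \Gamma} L(x) + O(\rho^4)$. Thus it holds that
$$\inf_{x \in U'}\big(L(x) + R_\rho(x)\big) \le L(x_1) + R_\rho(x_1) \le \rho^2\inf_{x \in U' \cap \Gamma} S(x) + \inf_{x \in U'} L(x) + O(\rho^4).$$
This completes the proof.

Lemma G.13. Let $U'$ be any bounded open set such that its closure $\overline{U'} \subseteq U$ and $\overline{U'} \cap \Gamma = \overline{U' \cap \Gamma}$. Suppose the regularizers $\{R_\rho\}$ admit a good limiting regularizer $S$ on $\Gamma$. Then for all $u \in U'$,
$$\|u - \Phi(u)\|_2 = O(\rho) \implies R_\rho(u) - \rho^2\inf_{x \in U' \cap \Gamma} S(x) \ge -o(\rho^2).$$

Proof of Lemma G.13. Define $r = r(K)$ and $h = h(K)$ as the constants in Lemmas E.5 and E.6 with $K = \overline{U' \cap \Gamma}$. Note that $K$ is compact and, by Lemma G.9, $K = \overline{U'} \cap \Gamma \subseteq \Gamma$. By Lemma E.5, $K^r \cap \Gamma$ is a compact set, and so is $K^h \cap \Gamma$. Since $S$ is a good limiting regularizer for $\{R_\rho\}$, by Definition G.2, for any $x^* \in K^h \cap \Gamma$ there exists an open neighborhood $V_{x^*}$ of $x^*$ such that for any $C, \epsilon_1 > 0$, there is a $\rho_{x^*}$ such that for all $x \in V_{x^*} \cap \Gamma$ and $\rho \le \rho_{x^*}$,
$$S(x) - \inf_{\|x' - x\|_2 \le C\rho} R_\rho(x')/\rho^2 < \epsilon_1.$$
Since $K^h \cap \Gamma$ is compact, there exists a finite subset $\{x_k\}_k$ of $K^h \cap \Gamma$ such that $K^h \cap \Gamma \subseteq \bigcup_k V_{x_k}$. Hence for any $C, \epsilon_1 > 0$, with $\rho_K = \min_k \rho_{x_k} > 0$, it holds that for all $x \in K^h \cap \Gamma$ and $\rho \le \rho_K$,
$$S(x) - \inf_{\|x' - x\|_2 \le C\rho} R_\rho(x')/\rho^2 < \epsilon_1. \quad (18)$$
We can rewrite Equation 18 as: for any $C > 0$,
$$S(x) - \inf_{\|x' - x\|_2 \le C\rho} R_\rho(x')/\rho^2 \le o(1) \text{ as } \rho \to 0. \quad (19)$$
As $u \in U' \subseteq U$, we have $\Phi(u) \in \Gamma$. If $\|u - \Phi(u)\|_2 = O(\rho)$, then $\mathrm{dist}(u, \Gamma) \le O(\rho)$. By Lemma G.10, we have $\mathrm{dist}(u, K) = o(1)$. This further implies $\mathrm{dist}(\Phi(u), K) \le \mathrm{dist}(u, K) + \|\Phi(u) - u\|_2 = o(1)$. Hence $\Phi(u) \in K^h \cap \Gamma$ for sufficiently small $\rho$. Thus we can pick $x = \Phi(u)$ in Equation 19 with $C$ sufficiently large, which yields
$$\rho^2 S(\Phi(u)) \le \inf_{\|u' - \Phi(u)\|_2 \le O(\rho)} R_\rho(u') + o(\rho^2) \le R_\rho(u) + o(\rho^2), \quad (20)$$
where the last step is because $\|u - \Phi(u)\|_2 = O(\rho)$. On the other hand, we have $S(\Phi(u)) \ge \inf_{x \in U' \cap \Gamma} S(x) - o(1)$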
(21) as S is continuous on Γ and dist(U Γ, Φ(u)) = o(1). Combining Equations 20 and 21, we have Rρ(u) ρ2 infx U Γ S(x) o(ρ2). Proof of Theorem G.6. We will first lower bound L(x) + Rρ(x) for x U . Suppose CU is the constant in Condition G.1. Define C1 = q 2 CU +infx U Γ S(x)+1 µ . We discuss by cases. For sufficiently small ρ, 1. If x Kh, then by Lemma G.11, L(x) is lower bounded by a positive constant. 2. If x Kh and x Φ(x) 2 C1ρ, then by Lemma F.1, L(x) µ x Φ(x) 2 2 2 (CU + inf x U Γ S(x) + 1)ρ2 . This implies L(x) + Rρ(x) (infx U Γ S(x) + 1)ρ2 + infx U L(x). 3. If x Φ(x) 2 C1ρ, by Lemma G.13, Rρ(x) ρ2 infx U Γ S(x) o(ρ2), hence L(x) + Rρ(x) + o(ρ2) inf x U Γ S(x)ρ2 + inf x U L(x) . Concluding the three cases, we have inf x U (L(x) + Rρ(x)) inf x U Γ L(x) + inf x U Γ S(x)ρ2 o(ρ2) . By Lemma G.12, we have that inf x U (L(x) + Rρ(x)) ρ2 inf x U Γ S(x) + inf x U Γ L(x) + o(ρ2) . Combining the above two inequalities, we prove the main statement of Theorem G.6. Furthermore, if L(u) + Rρ(u) infx U (L(x) + Rρ(x)) + O(ρ2), then by the main statement and Condition G.1, we have that L(u) inf x U L(x) inf x U (L(x) + Rρ(x)) Rρ(u) inf x U L(x) + O(ρ2) ρ2 inf x U Γ S(x) + Cρ2 + O(ρ2) = O(ρ2) . Then by Lemma G.11, we have u (U Γ)h for sufficiently small ρ. By Lemma F.1, we have u Φ(u) 2 = O(ρ). By Lemma G.13, we have Rρ(u) ρ2 infx U Γ S(x) o(ρ2). Published as a conference paper at ICLR 2023 G.4 PROOFS OF COROLLARY G.7 Proof of Corollary G.7. Since L(u)+Rρ(u) inf x U L(x) + Rρ(x) ρ2 = O(ρ2), by Theorem G.6, we have that L(u) inf x U L(x) ( + o(1))ρ2, and Rρ(x) inf x U Γ S(x) [ o(1), + o(1)]. Thus it suffices to show Rρ(x) S(x) = o(ρ2). Since L(u) infx U L(x) ( + ϵ(ρ))ρ2 = o(1), by Lemma G.11, we know dist(x, U Γ) = o(1). Thus by Lemma F.1, x Φ(x) = o(1), which implies that ρ2S(Φ(x)) o(ρ2) Rρ(u). Since S is an continuous extension, S(x) S(Φ(x)) = S(x) S(Φ(x)) = O( x Φ(x) 2) = o(1). Thus we conclude that S(x) S(Φ(x)) infx U Γ S(x) + + o(1). 
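On the other hand, $S(u) \ge S(\Phi(u)) - o(1) \ge \inf_{x\in\overline{U'}\cap\Gamma} S(x) - o(1)$, where in the last step we use the fact that $\mathrm{dist}(u, \overline{U'}\cap\Gamma) = o(1)$. This completes the proof.

G.5 LIMITING REGULARIZERS FOR DIFFERENT NOTIONS OF SHARPNESS

Proof of Theorem G.3.

1. We first verify Condition G.1. For a fixed compact set $B \subseteq U$, as $\|\nabla^3L(x)\|_2$ is continuous, there exists a constant $\nu$ such that $\|\nabla^3L(x)\|_2 \le \nu$ for all $x \in B^1$. Then by Taylor expansion,
$$R^{\mathrm{Max}}_\rho(x) = \max_{\|v\|_2\le1} L(x+\rho v) - L(x) \ge \max_{\|v\|_2\le1}\Big(\rho\langle\nabla L(x), v\rangle + \frac{\rho^2}{2}v^\top\nabla^2L(x)v\Big) - \frac{\nu\rho^3}{6} \ge -\frac{\nu\rho^3}{6},$$
where the last step takes $v = 0$.
2. Now we verify that $S^{\mathrm{Max}}(x) = \lambda_1(\nabla^2L(x))/2$ is the limiting regularizer of $R^{\mathrm{Max}}_\rho$. Let $x$ be any point in $\Gamma$. By continuity of $R^{\mathrm{Max}}_\rho$,
$$\lim_{\rho\to0}\lim_{r\to0}\inf_{\|x'-x\|_2\le r}\frac{R^{\mathrm{Max}}_\rho(x')}{\rho^2} = \lim_{\rho\to0}\frac{R^{\mathrm{Max}}_\rho(x)}{\rho^2} = \frac{\lambda_1(\nabla^2L(x))}{2}.$$
3. Finally we verify the definition of a good limiting regularizer. By Assumption 3.2, $S^{\mathrm{Max}}(x) = \lambda_1(\nabla^2L(x))/2$ is nonnegative and continuous on $\Gamma$. For any $x \in \Gamma$, choose a sufficiently small open convex set $V$ containing $x$ such that $\|\nabla^3L(x')\|_2 \le \nu$ for all $x' \in V^1$. For any $x \in V\cap\Gamma$ and any $x'$ satisfying $\|x' - x\|_2 \le C\rho$, by Theorem K.3,
$$R^{\mathrm{Max}}_\rho(x') = \max_{\|v\|_2\le1} L(x'+\rho v) - L(x') \ge \max_{\|v\|_2\le1}\Big(\rho\langle\nabla L(x'), v\rangle + \frac{\rho^2}{2}v^\top\nabla^2L(x')v\Big) - \frac{\nu\rho^3}{6} \ge \frac{\rho^2\lambda_1(\nabla^2L(x'))}{2} - \frac{\nu\rho^3}{6} \ge \frac{\rho^2\lambda_1(\nabla^2L(x))}{2} - O(\rho^3).$$
This implies $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Max}}_\rho(x') \ge \rho^2\lambda_1(\nabla^2L(x))/2 - O(\rho^3)$. On the other hand, for any $x \in V\cap\Gamma$, since $\nabla L(x) = 0$,
$$R^{\mathrm{Max}}_\rho(x) = \max_{\|v\|_2\le1} L(x+\rho v) - L(x) \le \max_{\|v\|_2\le1}\Big(\rho\langle\nabla L(x), v\rangle + \frac{\rho^2}{2}v^\top\nabla^2L(x)v\Big) + \frac{\nu\rho^3}{6} = \frac{\rho^2\lambda_1(\nabla^2L(x))}{2} + O(\rho^3).$$
This implies $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Max}}_\rho(x') \le \rho^2\lambda_1(\nabla^2L(x))/2 + O(\rho^3)$. Thus, we conclude that
$$\Big|\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Max}}_\rho(x')/\rho^2 - \lambda_1(\nabla^2L(x))/2\Big| = O(\rho), \qquad \forall x \in V\cap\Gamma,$$
indicating that $S^{\mathrm{Max}}$ is a good limiting regularizer of $R^{\mathrm{Max}}_\rho$ on $\Gamma$. This completes the proof.

Proof of Theorem G.4.

1. We first prove that Condition G.1 holds. For any fixed compact set $B \subseteq U$, as $\nabla^2L$ and $\nabla^3L$ are continuous, there exist constants $\zeta$ and $\nu$ such that $\|\nabla^2L(x)\|_2 \le \zeta$ and $\|\nabla^3L(x)\|_2 \le \nu$ for all $x \in B^2$. Then by Taylor expansion, for $\rho \le 1$,
$$R^{\mathrm{Asc}}_\rho(x) = L\Big(x + \rho\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\Big) - L(x) \ge \rho\|\nabla L(x)\|_2 + \frac{\rho^2}{2}\Big(\frac{\nabla L(x)}{\|\nabla L(x)\|_2}\Big)^{\top}\nabla^2L(x)\frac{\nabla L(x)}{\|\nabla L(x)\|_2} - \frac{\nu\rho^3}{6} \ge -(\zeta + \nu/6)\rho^2.$$
2. Now we verify that $S^{\mathrm{Asc}}(x) = \lambda_M(\nabla^2L(x))/2$ is the limiting regularizer of $R^{\mathrm{Asc}}_\rho$. Let $x$ be any point in $\Gamma$, let $K = \{x\}$, and choose $h = h(K)$ as in Lemma E.5. Writing $g(x') = \nabla L(x')/\|\nabla L(x')\|_2$, for any $x' \in K^h$,
$$R^{\mathrm{Asc}}_\rho(x') = L(x' + \rho g(x')) - L(x') \ge \rho\|\nabla L(x')\|_2 + \frac{\rho^2}{2}g(x')^\top\nabla^2L(x')g(x') - \frac{\nu\rho^3}{6} \ge \frac{\rho^2}{2}g(x')^\top\nabla^2L(\Phi(x'))g(x') - O(\rho^2\|x' - \Phi(x')\|_2) - \frac{\nu\rho^3}{6}.$$
By Lemma F.4, we have
$$g(x') = \frac{\nabla^2L(\Phi(x'))(x' - \Phi(x'))}{\|\nabla^2L(\Phi(x'))(x' - \Phi(x'))\|_2} + O\Big(\frac{\nu}{\mu}\|x' - \Phi(x')\|_2\Big),$$
where the leading term is a unit vector in the column space of $\nabla^2L(\Phi(x'))$. Hence
$$R^{\mathrm{Asc}}_\rho(x') \ge \frac{\rho^2\lambda_M(\nabla^2L(\Phi(x')))}{2} - \zeta\rho^2O(\|x' - \Phi(x')\|_2) - \frac{\nu\rho^3}{6}.$$
This implies
$$\lim_{\rho\to0}\lim_{r\to0}\inf_{\|x'-x\|_2\le r}\frac{R^{\mathrm{Asc}}_\rho(x')}{\rho^2} \ge \frac{\lambda_M(\nabla^2L(x))}{2}.$$
We now show that this inequality is in fact an equality. Choose $x'_r = x + rv_M$, where $v_M$ is the eigenvector of $\nabla^2L(x)$ associated with $\lambda_M$. By Taylor expansion,
$$\nabla L(x'_r) = \nabla L(x) + \nabla^2L(x)(x'_r - x) + O(\|x'_r - x\|_2^2) = r\lambda_Mv_M + O(r^2).$$
This implies $\lim_{r\to0}\nabla L(x'_r)/\|\nabla L(x'_r)\|_2 = v_M$. We also have $\lim_{r\to0}\nabla^2L(x'_r) = \nabla^2L(x)$ and $\lim_{r\to0}\nabla L(x'_r) = 0$. Putting these together,
$$\lim_{r\to0}R^{\mathrm{Asc}}_\rho(x'_r) = \lim_{r\to0}\Big(\rho\|\nabla L(x'_r)\|_2 + \frac{\rho^2}{2}g(x'_r)^\top\nabla^2L(x'_r)g(x'_r) + O(\nu\rho^3)\Big) = \frac{\rho^2\lambda_M(\nabla^2L(x))}{2} + O(\rho^3).$$
This implies
$$\lim_{\rho\to0}\lim_{r\to0}\inf_{\|x'-x\|_2\le r}\frac{R^{\mathrm{Asc}}_\rho(x')}{\rho^2} \le \lim_{\rho\to0}\lim_{r\to0}\frac{R^{\mathrm{Asc}}_\rho(x'_r)}{\rho^2} = \frac{\lambda_M(\nabla^2L(x))}{2}.$$
Hence the limiting regularizer $S$ is exactly $\lambda_M(\nabla^2L(\cdot))/2$.
3. Finally we verify the definition of a good limiting regularizer. By Assumption 3.2, $S^{\mathrm{Asc}}(x) = \lambda_M(\nabla^2L(x))/2$ is nonnegative and continuous on $\Gamma$. For any $x \in \Gamma$, choose a sufficiently small open convex set $V$ containing $x$ such that $\|\nabla^3L(x')\|_2 \le \nu$ for all $x' \in V^1$. For any $x \in V\cap\Gamma$ and any $x'$ satisfying $\|x' - x\|_2 \le C\rho$, the same expansion as in the second part gives
$$R^{\mathrm{Asc}}_\rho(x') \ge \frac{\rho^2}{2}g(x')^\top\nabla^2L(\Phi(x'))g(x') - O(\rho^2\|x' - \Phi(x')\|_2) - \frac{\nu\rho^3}{6},$$
and by Lemma F.4, $g(x') = \frac{\nabla^2L(\Phi(x'))(x' - \Phi(x'))}{\|\nabla^2L(\Phi(x'))(x' - \Phi(x'))\|_2} + O\big(\frac{\nu}{\mu}\|x' - \Phi(x')\|_2\big)$. This implies $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Asc}}_\rho(x') \ge \rho^2\lambda_M(\nabla^2L(x))/2 - O(\rho^3)$. On the other hand, similarly to the proof of the second part, we have $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Asc}}_\rho(x') \le \rho^2\lambda_M(\nabla^2L(x))/2 + O(\rho^3)$.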
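Thus, we conclude that
$$\Big|\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Asc}}_\rho(x')/\rho^2 - \lambda_M(\nabla^2L(x))/2\Big| = O(\rho), \qquad \forall x \in V\cap\Gamma,$$
indicating that $S^{\mathrm{Asc}}$ is a good limiting regularizer of $R^{\mathrm{Asc}}_\rho$ on $\Gamma$. This completes the proof.

Proof of Theorem G.5.

1. We first verify Condition G.1. For a fixed compact set $B \subseteq U$, as $\|\nabla^3L(x)\|_2$ is continuous, there exists a constant $\nu$ such that $\|\nabla^3L(x)\|_2 \le \nu$ for all $x \in B^1$. Then by Taylor expansion,
$$R^{\mathrm{Avg}}_\rho(x) = \mathbb{E}_{g\sim N(0,I)}L\Big(x + \rho\frac{g}{\|g\|_2}\Big) - L(x) \ge \mathbb{E}_{g\sim N(0,I)}\Big[\frac{\rho^2}{2}\frac{g^\top\nabla^2L(x)g}{\|g\|_2^2}\Big] - \frac{\nu\rho^3}{6} = \frac{\rho^2\mathrm{Tr}(\nabla^2L(x))}{2D} - \frac{\nu\rho^3}{6},$$
where the first-order term vanishes because $g/\|g\|_2$ is symmetric about the origin and $\mathbb{E}[gg^\top/\|g\|_2^2] = I/D$. Since $\nabla^2L$ is bounded on $B$, the right-hand side is at least $-O(\rho^2)$.
2. Now we verify that $S^{\mathrm{Avg}}(x) = \mathrm{Tr}(\nabla^2L(x))/(2D)$ is the limiting regularizer of $R^{\mathrm{Avg}}_\rho$. Let $x$ be any point in $\Gamma$. By continuity of $R^{\mathrm{Avg}}_\rho$,
$$\lim_{\rho\to0}\lim_{r\to0}\inf_{\|x'-x\|_2\le r}\frac{R^{\mathrm{Avg}}_\rho(x')}{\rho^2} = \lim_{\rho\to0}\frac{R^{\mathrm{Avg}}_\rho(x)}{\rho^2} = \frac{\mathrm{Tr}(\nabla^2L(x))}{2D}.$$
3. Finally we verify the definition of a good limiting regularizer. By Assumption 3.2, $S^{\mathrm{Avg}}(x) = \mathrm{Tr}(\nabla^2L(x))/(2D)$ is nonnegative and continuous on $\Gamma$. For any $x \in \Gamma$, choose a sufficiently small open convex set $V$ containing $x$ such that $\|\nabla^3L(x')\|_2 \le \nu$ for all $x' \in V^1$. For any $x \in V\cap\Gamma$ and any $x'$ satisfying $\|x' - x\|_2 \le C\rho$, by Theorem K.3,
$$R^{\mathrm{Avg}}_\rho(x') = \mathbb{E}_{g}L\Big(x' + \rho\frac{g}{\|g\|_2}\Big) - L(x') \ge \mathbb{E}_{g}\Big[\frac{\rho^2}{2}\frac{g^\top\nabla^2L(x')g}{\|g\|_2^2}\Big] - \frac{\nu\rho^3}{6} = \frac{\rho^2\mathrm{Tr}(\nabla^2L(x'))}{2D} - \frac{\nu\rho^3}{6} \ge \frac{\rho^2\mathrm{Tr}(\nabla^2L(x))}{2D} - O(\rho^3).$$
This implies $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Avg}}_\rho(x') \ge \rho^2\mathrm{Tr}(\nabla^2L(x))/(2D) - O(\rho^3)$. On the other hand, for any $x \in V\cap\Gamma$,
$$R^{\mathrm{Avg}}_\rho(x) = \mathbb{E}_{g}L\Big(x + \rho\frac{g}{\|g\|_2}\Big) - L(x) \le \mathbb{E}_{g}\Big[\frac{\rho^2}{2}\frac{g^\top\nabla^2L(x)g}{\|g\|_2^2}\Big] + \frac{\nu\rho^3}{6} = \frac{\rho^2\mathrm{Tr}(\nabla^2L(x))}{2D} + O(\rho^3).$$
This implies $\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Avg}}_\rho(x') \le \rho^2\mathrm{Tr}(\nabla^2L(x))/(2D) + O(\rho^3)$. Thus, we conclude that
$$\Big|\inf_{\|x'-x\|_2\le C\rho} R^{\mathrm{Avg}}_\rho(x')/\rho^2 - \mathrm{Tr}(\nabla^2L(x))/(2D)\Big| = O(\rho), \qquad \forall x \in V\cap\Gamma,$$
indicating that $S^{\mathrm{Avg}}$ is a good limiting regularizer of $R^{\mathrm{Avg}}_\rho$ on $\Gamma$.

Theorem G.14. The stochastic worst-direction sharpness $\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}]$ admits $\mathrm{Tr}(\nabla^2L(\cdot))/2$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Proof of Theorem G.14. By Theorem E.2, Condition E.1 holds. It follows directly from Theorem G.3 that $\Lambda_k(x)$ is a good limiting regularizer for $R^{\mathrm{Max}}_{k,\rho}$ on $\Gamma_k$. Then, as $\Gamma \subseteq \Gamma_k$, $\Lambda_k(x)$ is a good limiting regularizer for $R^{\mathrm{Max}}_{k,\rho}$ on $\Gamma$. Hence $S(x) = \sum_k\Lambda_k(x)/(2M) = \mathrm{Tr}(\nabla^2L(x))/2$ is a good limiting regularizer of $\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}](x)$ on $\Gamma$.

Theorem G.15. The stochastic ascent-direction sharpness $\mathbb{E}_k[R^{\mathrm{Asc}}_{k,\rho}]$ admits $\mathrm{Tr}(\nabla^2L(\cdot))/2$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Proof of Theorem G.15. By Theorem E.2, Condition E.1 holds. It follows directly from Theorem G.4 that $\Lambda_k(x)$ is a good limiting regularizer for $R^{\mathrm{Asc}}_{k,\rho}$ on $\Gamma_k$, as the codimension of $\Gamma_k$ is 1. Then, as $\Gamma \subseteq \Gamma_k$, $\Lambda_k(x)$ is a good limiting regularizer for $R^{\mathrm{Asc}}_{k,\rho}$ on $\Gamma$. Hence $S(x) = \sum_k\Lambda_k(x)/(2M) = \mathrm{Tr}(\nabla^2L(x))/2$ is a good limiting regularizer of $\mathbb{E}_k[R^{\mathrm{Asc}}_{k,\rho}](x)$ on $\Gamma$.

Theorem G.16. The stochastic average-direction sharpness $\mathbb{E}_k[R^{\mathrm{Avg}}_{k,\rho}]$ admits $\mathrm{Tr}(\nabla^2L(\cdot))/(2D)$ as a good limiting regularizer on $\Gamma$ and satisfies Condition G.1.

Proof of Theorem G.16. By definition, $\mathbb{E}_k[R^{\mathrm{Avg}}_{k,\rho}] = R^{\mathrm{Avg}}_\rho$. The rest follows from Theorem G.5.

G.6 PROOF OF THEOREMS 4.2 AND 5.3

To end this section, we prove the two theorems presented in the main text. The proofs are straightforward once the framework of good limiting regularizers is established.

Proof of Theorem 4.2. Apply Corollary G.7 to $R^{\mathrm{type}}$. The mapping from $R^{\mathrm{type}}$ to the good limiting regularizers $S^{\mathrm{type}}$ is characterized by Theorems G.3 to G.5.

Proof of Theorem 5.3. Apply Corollary G.7 to $R^{\mathrm{type}}$. The mapping from $R^{\mathrm{type}}$ to the good limiting regularizers $S^{\mathrm{type}}$ is characterized by Theorems G.14 to G.16.

H ANALYSIS OF FULL-BATCH SAM ON QUADRATIC LOSS (PROOF OF THEOREM 4.8)

The goal of this section is to prove Theorem 4.8. In this section, we use $A \preceq B$ to indicate that $B - A$ is positive semi-definite.

Theorem 4.8. Suppose $A$ is a positive definite symmetric matrix with a unique top eigenvalue. Consider running full-batch SAM (Equation 3) on the loss $L(x) := \frac12x^\top Ax$ as in Equation 11 below:
$$x(t+1) = x(t) - \eta A\Big(x(t) + \rho\frac{Ax(t)}{\|Ax(t)\|_2}\Big). \tag{11}$$
Then, if $\eta\lambda_1(A) < 1$, for almost every $x(0)$, $x(t)$ converges in direction to $v_1(A)$ up to a sign flip and $\lim_{t\to\infty}\|x(t)\|_2 = \frac{\eta\rho\lambda_1(A)}{2 - \eta\lambda_1(A)}$.

Proof of Theorem 4.8.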
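We first rewrite the iterate as
$$x(t+1) = x(t) - \eta Ax(t) - \eta\rho\frac{A^2x(t)}{\|Ax(t)\|_2}.$$
Define $\tilde x(t) \triangleq \nabla L(x(t))/\rho = Ax(t)/\rho$; then
$$\tilde x(t+1) = \tilde x(t) - \eta A\tilde x(t) - \eta\frac{A^2\tilde x(t)}{\|\tilde x(t)\|_2}. \tag{22}$$
We suppose $A \in \mathbb{R}^{D\times D}$ and use $\lambda_i, v_i$ to denote $\lambda_i(A), v_i(A)$. Further, we define
$$P^{(j:D)} \triangleq \sum_{i=j}^D v_iv_i^\top,\quad I_j \triangleq \big\{\tilde x \,\big|\, \|P^{(j:D)}\tilde x\|_2 \le \eta\lambda_j^2\big\},\quad \tilde x_i(t) \triangleq \langle\tilde x(t), v_i\rangle,\quad S \triangleq \Big\{t \,\Big|\, \|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1},\ t > T_1\Big\}.$$
By Lemma H.1, $I_j$ is an invariant set for the update rule (Equation 22). Our proof consists of two steps.

(1) Entering Invariant Set. Lemma H.3 implies that there exists a constant $T_1 > 0$ such that $\|P^{(j:D)}\tilde x(t)\|_2 \le \eta\lambda_j^2$ for all $t > T_1$ and all $j \in [D]$.

(2) Alignment to Top Eigenvector. Lemmas H.10 and H.11 show that $\|\tilde x(t)\|_2$ and $|\tilde x_1(t)|$ both converge to $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$, so $\tilde x(t)$ aligns with $v_1$ up to sign. Since $x(t) = \rho A^{-1}\tilde x(t)$, this implies that $x(t)$ converges in direction to $v_1$ and $\|x(t)\|_2 \to \frac{\rho}{\lambda_1}\cdot\frac{\eta\lambda_1^2}{2 - \eta\lambda_1} = \frac{\eta\rho\lambda_1}{2 - \eta\lambda_1}$, which is our final result.

H.1 ENTERING INVARIANT SET

In this subsection, we prove the following three lemmas.
1. Lemma H.1 shows that $I_j$ is an invariant set for the update rule (Equation 22).
2. Lemma H.2 shows that under the update rule (Equation 22), for iterates not in $I_j$, the component $P^{(j:D)}\tilde x(t)$ shrinks exponentially in $\ell_2$ norm.
3. Lemma H.3 combines Lemmas H.1 and H.2 to show that for sufficiently large $t$, $\tilde x(t) \in \bigcap_j I_j$.

Lemma H.1. For $t \ge 0$, if $\eta\lambda_1(A) < 1$ and $\tilde x(t) \in I_j$, then $\tilde x(t+1) \in I_j$.

Proof of Lemma H.1. By Equation 22,
$$P^{(j:D)}\tilde x(t+1) = \Big(I - P^{(j:D)}\eta A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2}\Big)P^{(j:D)}\tilde x(t).$$
Hence
$$\|P^{(j:D)}\tilde x(t+1)\|_2 \le \Big\|P^{(j:D)} - \eta P^{(j:D)}A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2}\Big\|_2\,\|P^{(j:D)}\tilde x(t)\|_2.$$
Because $\|\tilde x(t)\|_2 \ge \|P^{(j:D)}\tilde x(t)\|_2$ and $\lambda_i \le \lambda_j$ for $i \ge j$,
$$P^{(j:D)}\Big(1 - \eta\lambda_j - \frac{\eta\lambda_j^2}{\|P^{(j:D)}\tilde x(t)\|_2}\Big) \preceq P^{(j:D)}\Big(1 - \eta\lambda_j - \frac{\eta\lambda_j^2}{\|\tilde x(t)\|_2}\Big) \preceq P^{(j:D)} - \eta P^{(j:D)}A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2} \preceq P^{(j:D)}.$$
Hence
$$\Big\|P^{(j:D)} - \eta P^{(j:D)}A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2}\Big\|_2 \le \max\Big(1,\ \eta\lambda_j + \frac{\eta\lambda_j^2}{\|P^{(j:D)}\tilde x(t)\|_2} - 1\Big).$$
It follows that
$$\|P^{(j:D)}\tilde x(t+1)\|_2 \le \max\big(\|P^{(j:D)}\tilde x(t)\|_2,\ \eta\lambda_j^2 - (1 - \eta\lambda_j)\|P^{(j:D)}\tilde x(t)\|_2\big) \le \eta\lambda_j^2,$$
where the last inequality uses $\tilde x(t) \in I_j$ and $1 - \eta\lambda_j \ge 0$. This is exactly the statement $\tilde x(t+1) \in I_j$, which completes the proof.

Lemma H.2.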
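For $t \ge 0$, if $\eta\lambda_1(A) < 1$ and $\tilde x(t) \notin I_j$, then
$$\|P^{(j:D)}\tilde x(t+1)\|_2 \le \max\Big(1 - \eta\lambda_D - \frac{\eta\lambda_D^2}{\|\tilde x(t)\|_2},\ \eta\lambda_j\Big)\|P^{(j:D)}\tilde x(t)\|_2 \le \max(1 - \eta\lambda_D, \eta\lambda_j)\,\|P^{(j:D)}\tilde x(t)\|_2. \tag{23}$$

Proof of Lemma H.2. Note that
$$\|P^{(j:D)}\tilde x(t+1)\|_2 = \Big\|\Big(I - P^{(j:D)}\eta A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2}\Big)P^{(j:D)}\tilde x(t)\Big\|_2 \le \Big\|P^{(j:D)} - \eta P^{(j:D)}A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2}\Big\|_2\,\|P^{(j:D)}\tilde x(t)\|_2.$$
As $\tilde x(t) \notin I_j$, we have $\|\tilde x(t)\|_2 \ge \|P^{(j:D)}\tilde x(t)\|_2 > \eta\lambda_j^2$, hence
$$\frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2} \preceq \frac{\eta P^{(j:D)}A^2}{\eta\lambda_j^2} \preceq P^{(j:D)}.$$
This implies that
$$-\eta\lambda_jP^{(j:D)} \preceq P^{(j:D)} - \eta P^{(j:D)}A - P^{(j:D)} \preceq P^{(j:D)} - \eta P^{(j:D)}A - \frac{\eta P^{(j:D)}A^2}{\|\tilde x(t)\|_2} \preceq P^{(j:D)}\Big(1 - \eta\lambda_D - \frac{\eta\lambda_D^2}{\|\tilde x(t)\|_2}\Big).$$
Hence we have
$$\|P^{(j:D)}\tilde x(t+1)\|_2 \le \max\Big(1 - \eta\lambda_D - \frac{\eta\lambda_D^2}{\|\tilde x(t)\|_2},\ \eta\lambda_j\Big)\|P^{(j:D)}\tilde x(t)\|_2 \le \max(1 - \eta\lambda_D, \eta\lambda_j)\,\|P^{(j:D)}\tilde x(t)\|_2.$$
This completes the proof.

Lemma H.3. Choosing
$$T_1 = \max_j\frac{\ln\max\big(\|\tilde x(0)\|_2/(\eta\lambda_j^2),\ 1\big)}{\ln\big(1/\max(1 - \eta\lambda_D, \eta\lambda_j)\big)},$$
we have $\tilde x(t) \in I_j$ for all $t \ge T_1$ and all $j \in [D]$.

Proof of Lemma H.3. We prove this by contradiction. Suppose there exist $j \in [D]$ and $T \ge T_1$ such that $\tilde x(T) \notin I_j$. By Lemma H.1, $\tilde x(t) \notin I_j$ for all $t < T$. Then by Lemma H.2,
$$\|P^{(j:D)}\tilde x(T)\|_2 \le \max(1 - \eta\lambda_D, \eta\lambda_j)^T\,\|P^{(j:D)}\tilde x(0)\|_2 \le \eta\lambda_j^2,$$
which leads to a contradiction.

H.2 ALIGNMENT TO TOP EIGENVECTOR

In this subsection, we prove the following lemmas towards showing that $\tilde x(t)$ converges in direction to $v_1(A)$ up to a sign flip.
1. Corollary H.4 shows that for almost every learning rate $\eta$ and initialization $x_{\mathrm{init}}$, $\tilde x_1(t) \ne 0$ for every $t \ge 0$. This condition is important because if $\tilde x_1(t) = 0$ at some step $t$, then $\tilde x_1(t') = 0$ for all $t' \ge t$, and thus alignment is impossible.
2. Lemma H.5 shows that under the update rule (Equation 22), $t \notin S$ implies $t + 1 \in S$ for sufficiently large $t$, where $S = \{t \mid \|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1},\ t > T_1\}$.
3. Lemma H.9, a combination of Lemmas H.6 and H.7, shows that under the update rule (Equation 22), $|\tilde x_1(t)|$ increases for $t \in S$.
4. Lemma H.10 shows that $\|\tilde x(t)\|_2$ converges to $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ under Equation 22.
5. Lemma H.11 shows that $|\tilde x_1(t)|$ converges to $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ under Equation 22.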
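The invariance property of Lemma H.1, on which the alignment lemmas listed above rely, can be sanity-checked numerically. The sketch below is only an illustration, not part of the proof: it assumes a diagonal $A$ with eigenvalues sorted in decreasing order (so $v_j$ is the $j$-th standard basis vector), samples rescaled iterates $\tilde x$ whose tail $P^{(j:D)}\tilde x$ lies inside $I_j$, applies one step of Equation 22, and counts how many images leave $I_j$. The helper names `rescaled_sam_step` and `count_invariance_violations`, the eigenvalues, $\eta$, and the sample counts are hypothetical choices.

```python
import numpy as np


def rescaled_sam_step(x_tilde, lam, eta):
    """One step of Equation 22 for a diagonal A = diag(lam):
    x~ <- x~ - eta * A x~ - eta * A^2 x~ / ||x~||_2."""
    norm = np.linalg.norm(x_tilde)
    return x_tilde - eta * lam * x_tilde - eta * lam**2 * x_tilde / norm


def count_invariance_violations(lam, eta, trials=2000, seed=0):
    """Sample points of I_j (0-based j) and count images that leave I_j.

    Assumes lam is sorted in decreasing order, so P^{(j:D)} keeps the
    coordinates j, j+1, ..., D-1.
    """
    rng = np.random.default_rng(seed)
    d = len(lam)
    violations = 0
    for _ in range(trials):
        j = int(rng.integers(d))                 # which set I_j to test
        radius = eta * lam[j] ** 2               # I_j = {x~ : ||P^{(j:D)} x~||_2 <= eta * lam_j^2}
        x = rng.standard_normal(d)
        x[j:] *= rng.uniform(0.0, 1.0) * radius / np.linalg.norm(x[j:])  # tail inside I_j
        x[:j] *= rng.uniform(0.0, 10.0)          # leading coordinates are unconstrained
        y = rescaled_sam_step(x, lam, eta)
        if np.linalg.norm(y[j:]) > radius * (1.0 + 1e-12):  # small float tolerance
            violations += 1
    return violations


lam = np.array([2.0, 1.0, 0.5, 0.1])   # assumed eigenvalues; eta * lam_1 = 0.6 < 1
print(count_invariance_violations(lam, eta=0.3))
```

A count of zero is consistent with Lemma H.1; sampling only probes the claim and does not replace the proof.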
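We first prove that $\tilde x_1(t) \ne 0$ for all $t$ holds for almost every learning rate $\eta$ and initialization $x_{\mathrm{init}}$ (Corollary H.4), using a much more general result (Theorem D.3).

Corollary H.4. Except for countably many $\eta \in \mathbb{R}^+$, for almost all initializations $x_{\mathrm{init}} = x(0)$, it holds that $\tilde x_1(t) \ne 0$ for all natural numbers $t$.

Proof of Corollary H.4. Let $F_n(x) \equiv F(x) \triangleq A\big(x + \rho\frac{Ax}{\|Ax\|_2}\big)$ for $n \in \mathbb{N}^+$, $x \in \mathbb{R}^D$, and let $Z = \{x \in \mathbb{R}^D \mid \langle x, v_1\rangle = 0\}$. One can check that $F$ is $C^1$ on $\mathbb{R}^D\setminus Z$ and that $Z$ is a zero-measure set. Applying Theorem D.3 yields the corollary.

Lemma H.5. For $t \ge 0$, if $\|\tilde x(t)\|_2 > \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ and $\tilde x(t) \in \bigcap_j I_j$, then
$$\|\tilde x(t+1)\|_2 \le \max\Big(\frac{\eta\lambda_1^2}{2 - \eta\lambda_1} - \frac{\eta\lambda_D^4}{2\lambda_1^2},\ \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\Big).$$

Proof of Lemma H.5. Note that
$$\tilde x(t+1) = \Big(I - \eta A - \frac{\eta A^2}{\|\tilde x(t)\|_2}\Big)\tilde x(t) = \frac{1}{\|\tilde x(t)\|_2}\sum_{j=1}^D\big((1 - \eta\lambda_j)\|\tilde x(t)\|_2 - \eta\lambda_j^2\big)\tilde x_j(t)v_j.$$
Consider the following two cases.

1. If $\big|(1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big| \ge \big|(1 - \eta\lambda_i)\|\tilde x(t)\|_2 - \eta\lambda_i^2\big|$ for all $i$, then
$$\|\tilde x(t+1)\|_2 \le \big|(1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big| = \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2.$$
2. If there exists $i$ such that $\big|(1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big| < \big|(1 - \eta\lambda_i)\|\tilde x(t)\|_2 - \eta\lambda_i^2\big|$, suppose WLOG that $i$ is the smallest such index. Since $\eta\lambda^2 + \eta\lambda\|\tilde x(t)\|_2$ is increasing in $\lambda$,
$$\eta\lambda_i^2 - (1 - \eta\lambda_i)\|\tilde x(t)\|_2 < \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2 = \big|(1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big|,$$
so the assumed inequality forces $-\eta\lambda_i^2 + (1 - \eta\lambda_i)\|\tilde x(t)\|_2 > \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2$. Equivalently,
$$\|\tilde x(t)\|_2 > \frac{\eta\lambda_1^2 + \eta\lambda_i^2}{2 - \eta\lambda_1 - \eta\lambda_i}. \tag{24}$$
Combining this with $\tilde x(t) \in I_1$, i.e., $\|\tilde x(t)\|_2 \le \eta\lambda_1^2$, we have $\eta < \frac{\lambda_1 - \lambda_i}{\lambda_1^2}$. Now consider the following vectors:
$$v^{(1)}(t) \triangleq \big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\tilde x(t),$$
$$v^{(2)}(t) \triangleq \big((2 - \eta\lambda_1 - \eta\lambda_i)\|\tilde x(t)\|_2 - \eta\lambda_i^2 - \eta\lambda_1^2\big)P^{(i:D)}\tilde x(t),$$
$$v^{(2+j)}(t) \triangleq \big((\eta\lambda_{i+j-1} - \eta\lambda_{i+j})\|\tilde x(t)\|_2 - \eta\lambda_{i+j}^2 + \eta\lambda_{i+j-1}^2\big)P^{(i+j:D)}\tilde x(t), \qquad 1 \le j \le D - i.$$
For $k < i$, the coefficient of $\tilde x_k(t)v_k$ in the sum above has magnitude at most $\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2$ by the minimality of $i$; for $k \ge i$, it equals this quantity plus the nonnegative coefficients collected in $v^{(2)}(t)$ and in $v^{(2+j)}(t)$ for $j \le k - i$. Hence
$$\|\tilde x(t+1)\|_2 \le \frac{1}{\|\tilde x(t)\|_2}\Big(\|v^{(1)}(t)\|_2 + \|v^{(2)}(t)\|_2 + \sum_{j=1}^{D-i}\|v^{(2+j)}(t)\|_2\Big).$$
By assumption, $\tilde x(t) \in \bigcap_j I_j$, hence
$$\|v^{(1)}(t)\|_2 = \big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\|\tilde x(t)\|_2,$$
$$\|v^{(2)}(t)\|_2 \le \eta\big((2 - \eta\lambda_1 - \eta\lambda_i)\|\tilde x(t)\|_2 - \eta\lambda_i^2 - \eta\lambda_1^2\big)\lambda_i^2,$$
$$\|v^{(2+j)}(t)\|_2 \le \eta\big((\eta\lambda_{i+j-1} - \eta\lambda_{i+j})\|\tilde x(t)\|_2 - \eta\lambda_{i+j}^2 + \eta\lambda_{i+j-1}^2\big)\lambda_{i+j}^2, \qquad 1 \le j \le D - i.$$
Using the AM-GM inequality, we have
$$\lambda_{i+j-1}\lambda_{i+j}^2 \le \frac{\lambda_{i+j-1}^3 + 2\lambda_{i+j}^3}{3}, \qquad \lambda_{i+j-1}^2\lambda_{i+j}^2 \le \frac{\lambda_{i+j-1}^4 + \lambda_{i+j}^4}{2}.$$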
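Hence
$$\|v^{(2+j)}(t)\|_2 \le \eta\big((\eta\lambda_{i+j-1} - \eta\lambda_{i+j})\|\tilde x(t)\|_2 - \eta\lambda_{i+j}^2 + \eta\lambda_{i+j-1}^2\big)\lambda_{i+j}^2 \le \eta^2\|\tilde x(t)\|_2\frac{\lambda_{i+j-1}^3 - \lambda_{i+j}^3}{3} + \eta^2\frac{\lambda_{i+j-1}^4 - \lambda_{i+j}^4}{2}, \qquad 1 \le j \le D - i,$$
and summing over $j$ (the sums telescope),
$$\sum_{j=1}^{D-i}\|v^{(2+j)}(t)\|_2 \le \eta^2\|\tilde x(t)\|_2\frac{\lambda_i^3 - \lambda_D^3}{3} + \eta^2\frac{\lambda_i^4 - \lambda_D^4}{2}.$$
Putting everything together,
$$\begin{aligned}
\|\tilde x(t+1)\|_2 &\le \eta\lambda_1^2 + \eta\lambda_i^2(2 - \eta\lambda_1 - \eta\lambda_i) + \eta^2\frac{\lambda_i^3 - \lambda_D^3}{3} - (1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \frac{\eta^2\lambda_i^2(\lambda_i^2 + \lambda_1^2)}{\|\tilde x(t)\|_2} + \frac{\eta^2(\lambda_i^4 - \lambda_D^4)}{2\|\tilde x(t)\|_2}\\
&\le \eta\lambda_1^2 + \eta\lambda_i^2\Big(2 - \eta\lambda_1 - \frac23\eta\lambda_i\Big) - (1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \frac{\eta^2\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)}{\|\tilde x(t)\|_2} - \frac{\eta^2\lambda_D^4}{2\|\tilde x(t)\|_2}\\
&\le \eta\lambda_1^2 + \eta\lambda_i^2\Big(2 - \eta\lambda_1 - \frac23\eta\lambda_i\Big) - (1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \frac{\eta^2\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)}{\|\tilde x(t)\|_2} - \frac{\eta\lambda_D^4}{2\lambda_1^2},
\end{aligned}$$
where the last step uses $\|\tilde x(t)\|_2 \le \eta\lambda_1^2$. Write $s^\star = \eta\lambda_i\sqrt{(\frac12\lambda_i^2 + \lambda_1^2)/(1 - \eta\lambda_1)}$ for the minimizer over $s > 0$ of $(1 - \eta\lambda_1)s + \eta^2\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)/s$. We further discuss three cases.

1. If $s^\star < \frac{\eta\lambda_1^2 + \eta\lambda_i^2}{2 - \eta\lambda_1 - \eta\lambda_i}$, then by Equation 24, $\|\tilde x(t)\|_2 > \frac{\eta\lambda_1^2 + \eta\lambda_i^2}{2 - \eta\lambda_1 - \eta\lambda_i} > s^\star$, and
$$\begin{aligned}
\|\tilde x(t+1)\|_2 &\le \eta\lambda_1^2 + \eta\lambda_i^2\Big(2 - \eta\lambda_1 - \frac23\eta\lambda_i\Big) - (1 - \eta\lambda_1)\frac{\eta\lambda_1^2 + \eta\lambda_i^2}{2 - \eta\lambda_1 - \eta\lambda_i} - \eta\lambda_i^2\Big(\frac12\lambda_i^2 + \lambda_1^2\Big)\frac{2 - \eta\lambda_1 - \eta\lambda_i}{\lambda_1^2 + \lambda_i^2} - \frac{\eta\lambda_D^4}{2\lambda_1^2}\\
&\le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1} - \frac{\eta\lambda_D^4}{2\lambda_1^2}.
\end{aligned}$$
The first inequality holds because $(1 - \eta\lambda_1)s + \eta^2\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)/s$ is monotonically increasing in $s$ for $s > s^\star$. The last inequality is due to Lemma K.9.
2. If $\eta\lambda_1^2 \ge s^\star \ge \frac{\eta\lambda_1^2 + \eta\lambda_i^2}{2 - \eta\lambda_1 - \eta\lambda_i}$, then
$$\|\tilde x(t+1)\|_2 \le \eta\lambda_1^2 + \eta\lambda_i^2\Big(2 - \eta\lambda_1 - \frac23\eta\lambda_i\Big) - 2\eta\lambda_i\sqrt{\Big(\frac12\lambda_i^2 + \lambda_1^2\Big)(1 - \eta\lambda_1)} - \frac{\eta\lambda_D^4}{2\lambda_1^2} \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1} - \frac{\eta\lambda_D^4}{2\lambda_1^2}.$$
The first inequality follows from the AM-GM inequality, and the last one is due to Lemma K.11.
3. If $\eta\lambda_1^2 < s^\star$, then $\|\tilde x(t)\|_2 \le \eta\lambda_1^2 < s^\star$, and
$$\|\tilde x(t+1)\|_2 \le \eta\lambda_1^2 + \eta\lambda_i^2\Big(2 - \eta\lambda_1 - \frac23\eta\lambda_i\Big) - (1 - \eta\lambda_1)\eta\lambda_1^2 - \frac{\eta\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)}{\lambda_1^2} - \frac{\eta\lambda_D^4}{2\lambda_1^2} \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1} - \frac{\eta\lambda_D^4}{2\lambda_1^2}.$$
The first inequality holds because $(1 - \eta\lambda_1)s + \eta^2\lambda_i^2(\frac12\lambda_i^2 + \lambda_1^2)/s$ is monotonically decreasing in $s$ for $s < s^\star$. The last inequality is due to Lemma K.10.

In all three sub-cases, $\|\tilde x(t+1)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1} - \frac{\eta\lambda_D^4}{2\lambda_1^2}$, which together with the first of the two outer cases proves the lemma.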
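The conclusion of Theorem 4.8, which the lemmas of this subsection build towards, can be observed in a short simulation of Equation 11. Restricted to the $v_1$ axis, Equation 11 becomes the scalar map $c \mapsto (1 - \eta\lambda_1)c - \eta\rho\lambda_1\,\mathrm{sign}(c)$, whose period-two orbit $\pm c^\star$ satisfies $c^\star = \eta\rho\lambda_1/(2 - \eta\lambda_1)$, matching the limit in the theorem. The sketch below assumes a diagonal positive-definite $A$ with a unique top eigenvalue, $\eta\lambda_1 < 1$, and a fixed initialization with $\langle x(0), v_1\rangle \ne 0$; the function name and all constants are illustrative choices, not taken from the paper.

```python
import numpy as np


def full_batch_sam_quadratic(A, x0, eta, rho, steps):
    """Run Equation 11: x <- x - eta * A (x + rho * A x / ||A x||_2)."""
    x = np.asarray(x0, dtype=float)
    for _ in range(steps):
        g = A @ x                                   # gradient of L(x) = x^T A x / 2
        x = x - eta * (A @ (x + rho * g / np.linalg.norm(g)))
    return x


A = np.diag([1.0, 0.5, 0.2])   # assumed example: unique top eigenvalue, v_1 = e_1
eta, rho = 0.1, 0.05           # eta * lambda_1 = 0.1 < 1
x = full_batch_sam_quadratic(A, [1.0, 1.0, 1.0], eta, rho, steps=20000)

lam1 = 1.0
predicted = eta * rho * lam1 / (2 - eta * lam1)    # limit of ||x(t)||_2 in Theorem 4.8
print("||x(t)||_2 =", np.linalg.norm(x), " predicted =", predicted)
print("|<x(t), v_1>| / ||x(t)||_2 =", abs(x[0]) / np.linalg.norm(x))
```

Consecutive iterates flip sign along $v_1$, so the norm and the alignment ratio, rather than $x(t)$ itself, are the quantities that settle.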
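Lemma H.6. If $\|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$, then $|\tilde x_1(t+1)| \ge |\tilde x_1(t)|$.

Proof of Lemma H.6. Note that $|\tilde x_1(t+1)| = \big|1 - \eta\lambda_1 - \frac{\eta\lambda_1^2}{\|\tilde x(t)\|_2}\big|\,|\tilde x_1(t)|$ and that $\frac{\eta\lambda_1^2}{\|\tilde x(t)\|_2} \ge 2 - \eta\lambda_1$. It follows that $1 - \eta\lambda_1 - \frac{\eta\lambda_1^2}{\|\tilde x(t)\|_2} \le -1$. Hence $|\tilde x_1(t+1)| \ge |\tilde x_1(t)|$.

Lemma H.7. For any $t \ge 0$, if $\|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ and $\tilde x(t) \in \bigcap_j I_j$, then $\|\tilde x(t+1)\|_2 \le \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2$.

Proof of Lemma H.7. Note that
$$\Big\|I - \eta A - \frac{\eta A^2}{\|\tilde x(t)\|_2}\Big\|_2 = \max_{1\le j\le D}\Big|1 - \eta\lambda_j - \frac{\eta\lambda_j^2}{\|\tilde x(t)\|_2}\Big| = \frac{\eta\lambda_1^2}{\|\tilde x(t)\|_2} - (1 - \eta\lambda_1),$$
where the maximum is attained at $j = 1$: the $j = 1$ term is at most $-1$, while the others lie in $\big[1 - \eta\lambda_1 - \eta\lambda_1^2/\|\tilde x(t)\|_2,\ 1\big]$. The proof is completed by noting that $\|\tilde x(t+1)\|_2 \le \big\|I - \eta A - \eta A^2/\|\tilde x(t)\|_2\big\|_2\,\|\tilde x(t)\|_2$.

Lemma H.8. For any $t \ge 0$, if $\|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{1 - \eta\lambda_1}$, then
$$\|\tilde x(t+1)\|_2 \le \big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\sqrt{\frac{|\tilde x_1(t)|^2}{\|\tilde x(t)\|_2^2} + \bigg(\max_{j\in[2:D]}\frac{|(1 - \eta\lambda_j)\|\tilde x(t)\|_2 - \eta\lambda_j^2|}{\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2}\bigg)^2\bigg(1 - \frac{|\tilde x_1(t)|^2}{\|\tilde x(t)\|_2^2}\bigg)}.$$

Proof of Lemma H.8. We discuss the movement along $v_1$ and orthogonal to $v_1$ separately. First,
$$\|P^{(2:D)}\tilde x(t+1)\|_2 = \Big\|\Big(I - P^{(2:D)}\eta A - \frac{\eta P^{(2:D)}A^2}{\|\tilde x(t)\|_2}\Big)P^{(2:D)}\tilde x(t)\Big\|_2 \le \max_{j\in[2:D]}\Big|1 - \eta\lambda_j - \frac{\eta\lambda_j^2}{\|\tilde x(t)\|_2}\Big|\,\|P^{(2:D)}\tilde x(t)\|_2.$$
Second, $|\tilde x_1(t+1)| = \big(\frac{\eta\lambda_1^2}{\|\tilde x(t)\|_2} - 1 + \eta\lambda_1\big)|\tilde x_1(t)|$. Combining the two bounds with $\|P^{(2:D)}\tilde x(t)\|_2^2 = \|\tilde x(t)\|_2^2 - |\tilde x_1(t)|^2$ yields the claimed inequality.

Lemma H.9. For $t, t' \in S$ with $0 \le t \le t'$, we have $|\tilde x_1(t)| \le |\tilde x_1(t')|$.

Proof of Lemma H.9. For $t \in S$, by Lemma H.5, either $t + 1 \in S$, or $t + 1 \notin S$ and $t + 2 \in S$. We discuss by case.
1. If $t + 1 \in S$, Lemma H.6 shows that $|\tilde x_1(t)| \le |\tilde x_1(t+1)|$.
2. If $t + 1 \notin S$ and $t + 2 \in S$, then
$$|\tilde x_1(t+2)| = \frac{\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t+1)\|_2\big)}{\|\tilde x(t)\|_2\,\|\tilde x(t+1)\|_2}\,|\tilde x_1(t)|.$$
Since
$$\frac{\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t+1)\|_2\big)}{\|\tilde x(t)\|_2\,\|\tilde x(t+1)\|_2} \ge 1 \iff \frac{\eta^2\lambda_1^4 - \eta\lambda_1^2(1 - \eta\lambda_1)(\|\tilde x(t)\|_2 + \|\tilde x(t+1)\|_2)}{(2\eta\lambda_1 - \eta^2\lambda_1^2)\|\tilde x(t)\|_2\,\|\tilde x(t+1)\|_2} \ge 1 \iff \frac{\eta^2\lambda_1^4 - \eta\lambda_1^2(1 - \eta\lambda_1)\|\tilde x(t)\|_2}{(2\eta\lambda_1 - \eta^2\lambda_1^2)\|\tilde x(t)\|_2 + \eta\lambda_1^2(1 - \eta\lambda_1)} \ge \|\tilde x(t+1)\|_2,$$
combined with Lemma H.7, we only need to prove
$$\frac{\eta^2\lambda_1^4 - \eta\lambda_1^2(1 - \eta\lambda_1)\|\tilde x(t)\|_2}{(2\eta\lambda_1 - \eta^2\lambda_1^2)\|\tilde x(t)\|_2 + \eta\lambda_1^2(1 - \eta\lambda_1)} \ge \eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2.$$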
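Through some calculation, this is equivalent to
$$\big((2 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big)\big((1 - \eta\lambda_1)\|\tilde x(t)\|_2 - \eta\lambda_1^2\big) \ge 0,$$
which holds for $\|\tilde x(t)\|_2 \le \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$, since both factors are then nonpositive. Combining the two cases and using induction, we obtain the desired result.

Lemma H.10. $\|\tilde x(t)\|_2$ converges to $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ as $t \to \infty$.

Proof of Lemma H.10. By Lemma H.9, $|\tilde x_1(t)|$ increases monotonically for $t \in S$. By Lemma H.5, $S$ is infinite. By Lemma H.3, $\tilde x(t) \in I_1$ for sufficiently large $t$, so $|\tilde x_1(t)|$ is bounded. Combining the three facts, $|\tilde x_1(t)|$ converges along $t \in S$. Formally, for every $\epsilon > 0$ there exists $T_\epsilon > 0$ such that $|\tilde x_1(t')|/|\tilde x_1(t)| < 1 + \epsilon$ for all $t, t' \in S$ with $t' > t > T_\epsilon$. By Lemma H.5, for all $t \in S$, either $t + 1 \in S$ or $t + 2 \in S$; we discuss by case for $t \ge T_\epsilon$.
1. If $t + 1 \in S$, then
$$1 + \epsilon \ge \frac{|\tilde x_1(t+1)|}{|\tilde x_1(t)|} = \frac{\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2}{\|\tilde x(t)\|_2}.$$
2. If $t + 1 \notin S$ and $t + 2 \in S$, then
$$\begin{aligned}
1 + \epsilon \ge \frac{|\tilde x_1(t+2)|}{|\tilde x_1(t)|} &= \frac{\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t+1)\|_2\big)}{\|\tilde x(t)\|_2\,\|\tilde x(t+1)\|_2}\\
&\ge \frac{\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2)\big)}{\|\tilde x(t)\|_2\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)}\\
&= \frac{\eta\lambda_1^2 - (1 - \eta\lambda_1)\big(\eta\lambda_1^2 - (1 - \eta\lambda_1)\|\tilde x(t)\|_2\big)}{\|\tilde x(t)\|_2}.
\end{aligned}$$
Here the inequality applies Lemma H.7.

Let $m_\epsilon \triangleq \min\Big(\frac{\eta\lambda_1^2}{2 - \eta\lambda_1 + \epsilon},\ \frac{\eta^2\lambda_1^3}{(2 - \eta\lambda_1)\eta\lambda_1 + \epsilon}\Big)$. Concluding the two cases, $\|\tilde x(t)\|_2 \ge m_\epsilon$ for all $t > T_\epsilon$ with $t \in S$. For $t \notin S$ with $t > T_\epsilon$, we have $\|\tilde x(t)\|_2 > \frac{\eta\lambda_1^2}{2 - \eta\lambda_1} \ge m_\epsilon$. Hence $\|\tilde x(t)\|_2 \ge m_\epsilon$ for all $t > T_\epsilon$. Further, by Lemmas H.5 and H.7, $\|\tilde x(t)\|_2 \le \eta\lambda_1^2 - (1 - \eta\lambda_1)m_\epsilon$ for all $t > T_\epsilon + 1$. Combining both bounds and letting $\epsilon \to 0$, we have $\lim_{t\to\infty}\|\tilde x(t)\|_2 = \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$.

Lemma H.11. $|\tilde x_1(t)|$ converges to $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$ as $t \to \infty$.

Proof of Lemma H.11. Notice that
$$\|P^{(2:D)}\tilde x(t+1)\|_2 \le \max\Big(\Big|1 - \eta\lambda_2 - \frac{\eta\lambda_2^2}{\|\tilde x(t)\|_2}\Big|,\ \Big|1 - \eta\lambda_D - \frac{\eta\lambda_D^2}{\|\tilde x(t)\|_2}\Big|\Big)\|P^{(2:D)}\tilde x(t)\|_2.$$
When $\|\tilde x(t)\|_2 > \frac{\eta\lambda_2^2}{2 - \eta\lambda_2 - \delta}$,
$$-1 + \delta \le 1 - \eta\lambda_2 - \frac{\eta\lambda_2^2}{\|\tilde x(t)\|_2} \le 1 - \eta\lambda_D - \frac{\eta\lambda_D^2}{\|\tilde x(t)\|_2} \le 1 - \eta\lambda_D,$$
so that $\|P^{(2:D)}\tilde x(t+1)\|_2 \le \max(1 - \eta\lambda_D, 1 - \delta)\,\|P^{(2:D)}\tilde x(t)\|_2$. Since the top eigenvalue is unique, $\frac{\eta\lambda_1^2}{2 - \eta\lambda_1} > \frac{\eta\lambda_2^2}{2 - \eta\lambda_2}$, so by Lemma H.10 this condition holds for some $\delta > 0$ and all sufficiently large $t$. Hence $\|P^{(2:D)}\tilde x(t)\|_2$ shrinks exponentially for large $t$, showing that $\lim_{t\to\infty}|\tilde x_1(t)| = \frac{\eta\lambda_1^2}{2 - \eta\lambda_1}$.

I ANALYSIS FOR FULL-BATCH SAM ON GENERAL LOSS (PROOF OF THEOREM 4.5)

The goal of this section is to prove the following theorem.

Theorem 4.5 (Main).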
Let $\{x(t)\}$ be the iterates of full-batch SAM (Equation 3) with $x(0) = x_{\mathrm{init}} \in U$. Under Assumptions 3.2 and 4.4, for all $\eta, \rho$ such that $\eta\ln(1/\rho)$ and $\rho/\eta$ are sufficiently small, the dynamics of SAM can be characterized in the following two phases:

Phase I (Theorem I.1): full-batch SAM (Equation 3) follows gradient flow with respect to $L$ until entering an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$ in $O(\ln(1/\rho)/\eta)$ steps;

Phase II (Theorem I.3): under a mild non-degeneracy assumption (Assumption I.2) on the initial point of Phase II, full-batch SAM (Equation 3) tracks the solution $X$ of Equation 7, the Riemannian gradient flow with respect to the loss $\lambda_1(\nabla^2L(\cdot))$, in an $O(\eta\rho)$ neighborhood of the manifold $\Gamma$. Quantitatively, the approximation error between the iterates $x$ and the corresponding limiting flow $X$ is $O(\eta\ln(1/\rho))$, that is,
$$\big\|x(\lceil T_3/(\eta\rho^2)\rceil) - X(T_3)\big\|_2 = O(\eta\ln(1/\rho)).$$
Moreover, the angle between $\nabla L(x(\lceil T_3/(\eta\rho^2)\rceil))$ and the top eigenspace of $\nabla^2L(x(\lceil T_3/(\eta\rho^2)\rceil))$ is $O(\rho)$.

Readers may refer to Appendix E for notation. To prove the theorem, we separate the dynamics of SAM on a general loss $L$ into two phases. Define
$$R_j(x) \triangleq \sqrt{\sum_{i=j}^M\lambda_i^2(x)\langle v_i(x), x - \Phi(x)\rangle^2} - \eta\rho\lambda_j^2(x), \qquad j \in [M],\ x \in U,$$
which is the length of the projection of $\nabla^2L(\Phi(x))(x - \Phi(x))$ onto the span of the bottom nonzero eigenvectors $v_j(x), \ldots, v_M(x)$, minus $\eta\rho\lambda_j^2(x)$. We will provide a fine-grained convergence bound on $R_j(x)$.

Theorem I.1 (Phase I). Let $\{x(t)\}$ be the iterates defined by SAM (Equation 3) with $x(0) = x_{\mathrm{init}} \in U$. Then under Assumption 3.2 there exists a positive number $T_1$ independent of $\eta$ and $\rho$ such that for any $T_1' > T_1$, for all $\eta, \rho$ such that $(\eta + \rho)\ln(1/\eta\rho)$ is sufficiently small, we have
$$\max_{T_1\ln(1/\eta\rho)\le\eta t\le T_1'\ln(1/\eta\rho)}\ \max_{j\in[M]}\max\{R_j(x(t)), 0\} = O(\eta\rho^2),$$
$$\max_{T_1\ln(1/\eta\rho)\le\eta t\le T_1'\ln(1/\eta\rho)}\|\Phi(x(t)) - \Phi(x_{\mathrm{init}})\|_2 = O((\eta + \rho)\ln(1/\eta\rho)).$$
Theorem I.1 implies that SAM converges to an $O(\eta\rho)$ neighborhood of $\Gamma$.
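The two-phase picture of Theorem 4.5 can be illustrated on a toy loss with a one-dimensional manifold of minimizers. The sketch below assumes the hypothetical loss $L(x, y) = \frac12(1 + y^2)x^2$, which is not taken from the paper: its minimizers form $\Gamma = \{x = 0\}$, and on $\Gamma$ the Hessian is $\mathrm{diag}(1 + y^2, 0)$, so the sharpness is $\lambda_1 = 1 + y^2$. Running full-batch SAM (Equation 3) on it, $x$ should quickly settle into an oscillation of size $O(\eta\rho)$ around $\Gamma$ (Phase I), after which $y$ drifts towards $0$ on the slower $\Theta(1/(\eta\rho^2))$ step scale, reducing $\lambda_1$ as Phase II predicts. The helper names and the values of $\eta$, $\rho$, and the initialization are illustrative assumptions.

```python
import numpy as np


def toy_grad(w):
    """Gradient of the hypothetical toy loss L(x, y) = 0.5 * (1 + y**2) * x**2.

    The minimizers form Gamma = {x = 0}; on Gamma the Hessian is
    diag(1 + y**2, 0), so lambda_1 = 1 + y**2 is the sharpness.
    """
    x, y = w
    return np.array([(1.0 + y**2) * x, y * x**2])


def full_batch_sam(w0, eta, rho, steps):
    """Full-batch SAM: w <- w - eta * grad(w + rho * grad(w) / ||grad(w)||_2)."""
    w = np.asarray(w0, dtype=float)
    for _ in range(steps):
        g = toy_grad(w)
        g_norm = np.linalg.norm(g)
        if g_norm == 0.0:              # exactly on Gamma: the ascent direction is undefined
            break
        w = w - eta * toy_grad(w + rho * g / g_norm)
    return w


eta, rho = 0.05, 0.1
w = full_batch_sam([0.5, 1.0], eta, rho, steps=20000)
print("|x| (distance to Gamma):", abs(w[0]), " eta * rho:", eta * rho)
print("lambda_1 on Gamma, 1 + y^2:", 1.0 + w[1] ** 2, " (initially 2.0)")
```

Because the $y$-gradient at the perturbed point is roughly $y\rho^2$ once $|x| = O(\eta\rho)$, each step shrinks $y$ by a factor of about $1 - \eta\rho^2$, which is why the sharpness reduction needs on the order of $1/(\eta\rho^2)$ steps.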
Notice that in the time frame defined by Theorem I.1, $x(t)$ effectively operates in a local regime around $\Phi(x(\lceil T_1\ln(1/\eta\rho)/\eta\rceil))$. This allows us to approximate $L$ by its quadratic Taylor expansion at $\Phi(x(\lceil T_1\ln(1/\eta\rho)/\eta\rceil))$ and prove Theorem I.3 below. Towards proving Theorem I.3, we need one assumption about the trajectory of SAM, Assumption I.2.

Assumption I.2. There exists a step $t$ satisfying $T_1\ln(1/\eta\rho)/\eta \le t \le O(\ln(1/\eta\rho)/\eta)$ such that $|\langle x(t) - \Phi(x(t)), v_1(x(t))\rangle| \ge \Omega(\rho^2)$ and $\|x(t) - \Phi(x(t))\|_2 \le \lambda_1(t)\eta\rho - \Omega(\rho^2)$, where $T_1$ is the constant defined in Theorem I.1.

We remark that this assumption is very mild: we only need the two conditions in Assumption I.2 to hold for some step within $\Theta(1/\eta)$ steps after Phase I ends, and from then on our analysis of Phase II shows that these two conditions hold until Phase II ends.

Theorem I.3 (Phase II). Let $\{x(t)\}$ be the iterates defined by SAM (Equation 3) under Assumptions 3.2 and 4.4, for all $\eta, \rho$ such that $\eta\ln(1/\rho)$ and $\rho/\eta$ are sufficiently small. Further assume that (1) $\max_{j\in[M]}\max\{R_j(x(0)), 0\} = O(\eta\rho^2)$, (2) $\|\Phi(x(0)) - \Phi(x_{\mathrm{init}})\|_2 = O((\eta + \rho)\ln(1/\eta\rho))$, (3) $|\langle x(0) - \Phi(x(0)), v_1(x(0))\rangle| \ge \Omega(\rho^2)$, and (4) $\|x(0) - \Phi(x(0))\|_2 \le \lambda_1(0)\eta\rho - \Omega(\rho^2)$. Then the iterates $x(t)$ track the solution $X$ of Equation 7. Quantitatively, for $t = \lceil T_3/(\eta\rho^2)\rceil$, we have
$$\|\Phi(x(t)) - X(\eta\rho^2t)\|_2 = O(\eta\ln(1/\rho)).$$
Moreover, the angle between $\nabla L(x(t))$ and the top eigenspace of $\nabla^2L(\Phi(x(t)))$ is at most $O(\rho)$. Quantitatively,
$$|\langle x(t) - \Phi(x(t)), v_1(x(t))\rangle| = \Theta(\eta\rho), \qquad \max_{j\in[2:M]}|\langle x(t) - \Phi(x(t)), v_j(x(t))\rangle| = O(\eta\rho^2).$$

In this section we define $K$ as $\{X(t) \mid 0 \le t \le T_3\}$, where $X$ is the solution of Equation 7. To simplify the proof, we assume WLOG that $L(x) = 0$ for $x \in \Gamma$.

I.1 PHASE I (PROOF OF THEOREM I.1)

Proof of Theorem I.1. The proof consists of three major parts.
1. Tracking Gradient Flow. Lemma I.4 shows the existence of a step $t_{\mathrm{GF}} = O(1/\eta)$ such that $x(t_{\mathrm{GF}})$ is in a subset of $K^h$ and $\Phi(x(t_{\mathrm{GF}}))$ is $O(\eta + \rho)$ close to $\Phi(x_{\mathrm{init}})$.
2. Decreasing Loss.
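Lemma I.6 shows the existence of a step $t_{\mathrm{DEC}} = O(\ln(1/\rho)/\eta)$ such that $x(t_{\mathrm{DEC}})$ is in an $O(\rho)$ neighborhood of $\Gamma$ and $\Phi(x(t_{\mathrm{DEC}}))$ is $O((\eta + \rho)\ln(1/\rho))$ close to $\Phi(x_{\mathrm{init}})$.
3. Entering Invariant Set. Lemmas I.11 and I.13 show the existence of a step $t_{\mathrm{INV}} = O(\ln(1/\rho\eta)/\eta)$ such that for any $t$ satisfying $t_{\mathrm{INV}} \le t \le t_{\mathrm{INV}} + \Theta(\ln(1/\eta)/\eta)$, we have $x(t) \in \bigcap_{k\in[M]}I_k$ and $\Phi(x(t))$ is $O((\eta + \rho)\ln(1/\eta\rho))$ close to $\Phi(x_{\mathrm{init}})$.

I.1.1 TRACKING GRADIENT FLOW

Lemma I.4 shows that the iterates $x(t)$ track gradient flow to an $O(1)$ neighborhood of $\Gamma$.

Lemma I.4. Under the conditions of Theorem I.1, there exists $t_{\mathrm{GF}} = O(1/\eta)$ such that the iterate $x(t_{\mathrm{GF}})$ is $O(1)$ close to the manifold $\Gamma$ and $\Phi(x(t_{\mathrm{GF}}))$ is $O(\eta + \rho)$ close to $\Phi(x_{\mathrm{init}})$. Quantitatively,
$$L(x(t_{\mathrm{GF}})) \le \frac{\mu h^2}{32}, \qquad \|x(t_{\mathrm{GF}}) - \Phi(x(t_{\mathrm{GF}}))\|_2 \le \frac h4, \qquad \|\Phi(x(t_{\mathrm{GF}})) - \Phi(x_{\mathrm{init}})\|_2 = O(\eta + \rho).$$

Proof of Lemma I.4. Choose $C = \frac14\sqrt{\mu/\zeta}$. Since $\Phi(x_{\mathrm{init}}) = \lim_{T\to\infty}\phi(x_{\mathrm{init}}, T)$, there exists $T > 0$ such that $\|\phi(x_{\mathrm{init}}, T) - \Phi(x_{\mathrm{init}})\|_2 \le Ch/2$. Note that
$$x(t+1) = x(t) - \eta\nabla L\Big(x(t) + \rho\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\Big) = x(t) - \eta\nabla L(x(t)) + O(\eta\rho).$$
By Corollary L.3 with $b(x) = -\nabla L(x)$, $p = \eta$ and $\epsilon = O(\rho)$, the iterates $x(t)$ track the gradient flow $\phi(x_{\mathrm{init}}, T)$ in $O(1/\eta)$ steps. Quantitatively, for $t_{\mathrm{GF}} = \lceil T/\eta\rceil$,
$$\|x(t_{\mathrm{GF}}) - \phi(x_{\mathrm{init}}, T)\|_2 = O(\epsilon + p) = O(\eta + \rho).$$
This implies $x(t_{\mathrm{GF}}) \in K^h$; hence, by Taylor expansion of $\Phi$,
$$\|\Phi(x(t_{\mathrm{GF}})) - \Phi(x_{\mathrm{init}})\|_2 = \|\Phi(x(t_{\mathrm{GF}})) - \Phi(\phi(x_{\mathrm{init}}, T))\|_2 \le O(\|x(t_{\mathrm{GF}}) - \phi(x_{\mathrm{init}}, T)\|_2) \le O(\eta + \rho).$$
This implies
$$\|x(t_{\mathrm{GF}}) - \Phi(x(t_{\mathrm{GF}}))\|_2 \le \|x(t_{\mathrm{GF}}) - \phi(x_{\mathrm{init}}, T)\|_2 + \|\phi(x_{\mathrm{init}}, T) - \Phi(x_{\mathrm{init}})\|_2 + \|\Phi(x_{\mathrm{init}}) - \Phi(x(t_{\mathrm{GF}}))\|_2 \le \frac{Ch}{2} + O(\eta + \rho) \le Ch \le \frac h4.$$
By Taylor expansion, we conclude that $L(x(t_{\mathrm{GF}})) \le \zeta\|x(t_{\mathrm{GF}}) - \Phi(x(t_{\mathrm{GF}}))\|_2^2/2 \le \zeta C^2h^2/2 = \mu h^2/32$.

I.1.2 DECREASING LOSS

Lemma I.6 shows that the iterates $x(t)$ converge to an $O(\rho)$ neighborhood of $\Gamma$ in $O(\ln(1/\rho)/\eta)$ steps.

Lemma I.5. Under the conditions of Theorem I.1, if $x(t) \in K^h$ and $\|\nabla L(x(t))\|_2 \ge 4\zeta\rho$, then $L(x(t+1))$ decreases relative to $L(x(t))$; quantitatively,
$$L(x(t+1)) \le L(x(t))(1 - \eta\mu/8).$$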
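Moreover, the movement of the projection of the iterates onto the manifold is bounded; quantitatively, $\|\Phi(x(t+1)) - \Phi(x(t))\|_2 \le O(\eta^2)$.

Proof of Lemma I.5. As $x(t) \in K^h$ and $L$ is $\mu$-PL in $K^h$, we have $\nabla L(x(t)) \ne 0$. As $x(t) \in K^h$, by Lemma F.7 and Taylor expansion, $\|x(t) - x(t+1)\|_2 = O(\eta)$; hence, for sufficiently small $\eta$, the segment $\overline{x(t)x(t+1)}$ lies in $K^r$. By a similar argument, the segment from $x(t)$ to $x(t) + \rho\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}$ lies in $K^r$. Then by Taylor expansion of $L$,
$$L(x(t+1)) \le L(x(t)) - \eta\Big\langle\nabla L(x(t)),\ \nabla L\Big(x(t) + \rho\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\Big)\Big\rangle + \frac{\zeta\eta^2}{2}\Big\|\nabla L\Big(x(t) + \rho\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\Big)\Big\|_2^2. \tag{25}$$
By Taylor expansion of $\nabla L$, we have
$$\Big\|\nabla L\Big(x(t) + \rho\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|_2}\Big) - \nabla L(x(t))\Big\|_2 \le \zeta\rho.$$
Plugging this into Equation 25, we have
$$L(x(t+1)) \le L(x(t)) - \eta\|\nabla L(x(t))\|_2^2 + \eta\zeta\rho\|\nabla L(x(t))\|_2 + \zeta\eta^2\|\nabla L(x(t))\|_2^2 + \zeta^3\eta^2\rho^2. \tag{26}$$
As $\|\nabla L(x(t))\|_2 \ge 4\zeta\rho$ and $\eta$ is small, the following terms are bounded:
$$\zeta\eta^2\|\nabla L(x(t))\|_2^2 \le \frac12\eta\|\nabla L(x(t))\|_2^2, \qquad \eta\zeta\rho\|\nabla L(x(t))\|_2 \le \frac14\eta\|\nabla L(x(t))\|_2^2, \qquad \zeta^3\eta^2\rho^2 \le \zeta^2\eta\rho^2 \le \frac{1}{16}\eta\|\nabla L(x(t))\|_2^2.$$
Plugging these into Equation 26, by Lemma F.2,
$$L(x(t+1)) \le L(x(t)) - \frac{1}{16}\eta\|\nabla L(x(t))\|_2^2 \le L(x(t))(1 - \eta\mu/8).$$
As $x(t) \in K^h$, by Taylor expansion, $\|\nabla L(x(t))\|_2 \le \zeta h$. Hence by Lemma F.7 and Taylor expansion,
$$\|\Phi(x(t+1)) - \Phi(x(t))\|_2 \le \xi\eta\rho\|\nabla L(x(t))\|_2 + \nu\eta\rho^2 + \xi\eta^2\|\nabla L(x(t))\|_2^2 + \xi\zeta^2\eta^2\rho^2 \le O(\eta^2),$$
which completes the proof.

Lemma I.6. Under the conditions of Theorem I.1, assume there exists $t_{\mathrm{GF}}$ such that $L(x(t_{\mathrm{GF}})) \le \frac{\mu h^2}{32}$ and $x(t_{\mathrm{GF}}) \in K^{h/4}$. Then there exists $t_{\mathrm{DEC}} = t_{\mathrm{GF}} + O(\ln(1/\rho)/\eta)$ such that $x(t_{\mathrm{DEC}})$ is in an $O(\rho)$ neighborhood of $\Gamma$; quantitatively, $\|\nabla L(x(t_{\mathrm{DEC}}))\|_2 \le 4\zeta\rho$. Moreover, the movement of $\Phi(x(\cdot))$ on the manifold is bounded:
$$\|\Phi(x(t_{\mathrm{GF}})) - \Phi(x(t_{\mathrm{DEC}}))\|_2 = O(\eta\ln(1/\rho)).$$

Proof of Lemma I.6. Choose $t_{\mathrm{DEC}}$ as the minimal $t \ge t_{\mathrm{GF}}$ such that $\|\nabla L(x(t))\|_2 \le 4\zeta\rho$. Define $C = \log_{1-\eta\mu/8}(64\rho^2/h^2) = O(\ln(1/\rho)/\eta)$. We first perform an induction on $t \le \min\{t_{\mathrm{DEC}}, t_{\mathrm{GF}} + C\} = t_{\mathrm{GF}} + O(\ln(1/\rho)/\eta)$ to show that
$$L(x(t)) \le (1 - \eta\mu/8)^{t - t_{\mathrm{GF}}}L(x(t_{\mathrm{GF}})), \qquad \|\Phi(x(t)) - \Phi(x(t_{\mathrm{GF}}))\|_2 = O(\eta^2(t - t_{\mathrm{GF}})).$$
For $t = t_{\mathrm{GF}}$, the result holds trivially.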
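Suppose the induction hypothesis holds for $t$. Then by Lemma F.1 and Taylor expansion,
$$\|\Phi(x(t)) - x(t)\|_2 \le \sqrt{\frac{2L(x(t))}{\mu}} \le \sqrt{\frac{2L(x(t_{\mathrm{GF}}))}{\mu}} \le \frac h4.$$
Then we have
$$\mathrm{dist}(K, x(t)) \le \mathrm{dist}(K, x(t_{\mathrm{GF}})) + \|x(t_{\mathrm{GF}}) - \Phi(x(t_{\mathrm{GF}}))\|_2 + \|\Phi(x(t_{\mathrm{GF}})) - \Phi(x(t))\|_2 + \|\Phi(x(t)) - x(t)\|_2 \le \frac{3h}{4} + O(\eta^2(t - t_{\mathrm{GF}})) = \frac{3h}{4} + O(\eta\ln(1/\rho)) \le h.$$
That is, $x(t) \in K^h$. As $t < t_{\mathrm{DEC}}$, we have $\|\nabla L(x(t))\|_2 > 4\zeta\rho$. Then by Lemma I.5,
$$L(x(t+1)) \le (1 - \eta\mu/8)L(x(t)) \le (1 - \eta\mu/8)^{t+1-t_{\mathrm{GF}}}L(x(t_{\mathrm{GF}})),$$
$$\|\Phi(x(t+1)) - \Phi(x(t_{\mathrm{GF}}))\|_2 \le \|\Phi(x(t+1)) - \Phi(x(t))\|_2 + \|\Phi(x(t)) - \Phi(x(t_{\mathrm{GF}}))\|_2 \le O(\eta^2(t + 1 - t_{\mathrm{GF}})),$$
which completes the induction.

Now suppose $t_{\mathrm{DEC}} > t_{\mathrm{GF}} + C = t_{\mathrm{GF}} + \Omega(\ln(1/\rho)/\eta)$. As a result of the induction,
$$L(x(t_{\mathrm{GF}} + C)) \le (1 - \eta\mu/8)^CL(x(t_{\mathrm{GF}})) = \frac{64\rho^2}{h^2}L(x(t_{\mathrm{GF}})) \le 2\mu\rho^2.$$
By Lemma F.2, $\|\nabla L(x(t_{\mathrm{GF}} + C))\|_2 \le \zeta\sqrt{2L(x(t_{\mathrm{GF}} + C))/\mu} \le 2\zeta\rho < 4\zeta\rho$, which contradicts the minimality of $t_{\mathrm{DEC}}$. Hence $t_{\mathrm{DEC}} \le t_{\mathrm{GF}} + C = t_{\mathrm{GF}} + O(\ln(1/\rho)/\eta)$. By the induction,
$$\|\Phi(x(t_{\mathrm{DEC}})) - \Phi(x(t_{\mathrm{GF}}))\|_2 = O(\eta^2(t_{\mathrm{DEC}} - t_{\mathrm{GF}})) = O(\eta\ln(1/\rho)).$$
This completes the proof.

I.1.3 ENTERING INVARIANT SET

We first introduce some notation required for the proofs in this and the following subsection. Define
$$\hat x = x - \Phi(x), \qquad A(x) = \nabla^2L(\Phi(x)), \qquad \bar x = A(x)\hat x, \qquad \bar x_j = \langle\bar x, v_j(x)\rangle, \qquad P^{(j:D)}(x) = \sum_{i=j}^Dv_i(x)v_i(x)^\top.$$
Note that $\bar x \approx \nabla L(x)$ for $x$ near the manifold $\Gamma$. We also write $\bar x(t)$, $A(t)$ and $\hat x(t)$ for $\overline{x(t)}$, $A(x(t))$ and $\widehat{x(t)}$. Recall that the original definition of $R_j(x)$ is
$$R_j(x) = \sqrt{\sum_{i=j}^M\lambda_i^2(x)\langle v_i(x), x - \Phi(x)\rangle^2} - \eta\rho\lambda_j^2(x).$$
With the above notation, we can rephrase it as $R_j(x) = \|P^{(j:D)}(x)\bar x\|_2 - \eta\rho\lambda_j^2(x)$. We additionally define the approximate invariant set $I_j$ as
$$I_j = \big\{x : \|P^{(j:D)}(x)\bar x\|_2 \le \eta\rho\lambda_j^2(x) + O(\eta\rho^2)\big\}.$$

Lemma I.7. Assume $t$ satisfies $x(t) \in K^h$. Then
$$\frac{\mu}{2}\|x(t) - \Phi(x(t))\|_2 \le \|\bar x(t)\|_2 \le \zeta\|x(t) - \Phi(x(t))\|_2.$$

Proof of Lemma I.7. First, by Lemma F.4, $\Phi(x(t)) \in K^r$, hence
$$\|\bar x(t)\|_2 = \|\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t)))\|_2 \le \zeta\|x(t) - \Phi(x(t))\|_2,$$
$$\|\bar x(t)\|_2 = \|\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t)))\|_2 \ge \mu\|P_{\Phi(x(t)),\Gamma}(x(t) - \Phi(x(t)))\|_2.$$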
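By Lemma F.4 and Lemma E.6, we have
$$\|x(t) - \Phi(x(t))\|_2 \le \|P^\perp_{\Phi(x(t)),\Gamma}(x(t) - \Phi(x(t)))\|_2 + \|P_{\Phi(x(t)),\Gamma}(x(t) - \Phi(x(t)))\|_2 \le \frac{\nu\zeta}{4\mu^2}\|x(t) - \Phi(x(t))\|_2^2 + \frac1\mu\|\bar x(t)\|_2 \le \frac12\|x(t) - \Phi(x(t))\|_2 + \frac1\mu\|\bar x(t)\|_2.$$
Hence $\|x(t) - \Phi(x(t))\|_2 \le \frac{2}{\mu}\|\bar x(t)\|_2$.

Lemma I.8. Assume $t$ satisfies $x(t) \in K^h$ and $\|\bar x(t)\|_2 = O(\rho)$. Then $\|\Phi(x(t+1)) - \Phi(x(t))\|_2 = O(\eta\rho^2)$.

Proof of Lemma I.8. By Lemma I.7, $\|x(t) - \Phi(x(t))\|_2 = O(\rho)$. By Lemma F.7,
$$\|\Phi(x(t+1)) - \Phi(x(t))\|_2 \le \zeta\xi\eta\rho\|x(t) - \Phi(x(t))\|_2 + \zeta^2\xi\eta^2\|x(t) - \Phi(x(t))\|_2^2 + \nu\eta\rho^2 + \xi\zeta^2\eta^2\rho^2 = O(\eta\rho^2).$$

Lemma I.9. Assume $t$ satisfies $x(t) \in K^{h/2}$ and $\|x(t) - \Phi(x(t))\|_2 = O(\rho)$. Define $x'$ by $x'(t) = x(t)$ and, for $\tau \ge t$,
$$x'(\tau+1) = x'(\tau) - \eta\nabla^2L(\Phi(x(t)))(x'(\tau) - \Phi(x(t))) - \eta\rho\nabla^2L(\Phi(x(t)))\frac{\nabla^2L(\Phi(x(t)))(x'(\tau) - \Phi(x(t)))}{\|\nabla^2L(\Phi(x(t)))(x'(\tau) - \Phi(x(t)))\|_2}.$$
Then $\|x'(t+1) - x(t+1)\|_2 = O(\eta\rho^2)$, and if further $\|x(t+1) - \Phi(x(t+1))\|_2 = \Omega(\eta\rho)$, then $\|x'(t+2) - x(t+2)\|_2 = O(\eta\rho^2)$.

Proof of Lemma I.9. By $\|x(t) - \Phi(x(t))\|_2 = O(\rho)$, $x(t) \in K^{h/2}$ and Lemma F.7, we have $\|x(t+1) - x(t)\|_2 = O(\eta\rho)$, and hence $x(t+1) \in K^{3h/4}$. This also implies $\|x(t+1) - \Phi(x(t+1))\|_2 = O(\rho)$. Similarly, $x(t+2) \in K^{3h/4}$. For $k \in \{0, 1\}$, by Taylor expansion,
$$\begin{aligned}
x(t+k+1) &= x(t+k) - \eta\nabla L(x(t+k)) - \eta\rho\nabla^2L(x(t+k))\frac{\nabla L(x(t+k))}{\|\nabla L(x(t+k))\|_2} + O(\eta\rho^2)\\
&= x(t+k) - \eta\nabla^2L(\Phi(x(t+k)))(x(t+k) - \Phi(x(t+k))) + O(\eta\rho^2) - \eta\rho\nabla^2L(\Phi(x(t+k)))\frac{\nabla L(x(t+k))}{\|\nabla L(x(t+k))\|_2} + O(\eta\rho^2)\\
&= x(t+k) - \eta\nabla^2L(\Phi(x(t+k)))(x(t+k) - \Phi(x(t+k))) - \eta\rho\nabla^2L(\Phi(x(t+k)))\frac{\nabla L(x(t+k))}{\|\nabla L(x(t+k))\|_2} + O(\eta\rho^2).
\end{aligned}$$
Now by Lemmas I.7 and I.8, $\|\Phi(x(t+k)) - \Phi(x(t))\|_2 = O(\eta\rho^2)$, so
$$x(t+k+1) = x(t+k) - \eta\nabla^2L(\Phi(x(t)))(x(t+k) - \Phi(x(t))) - \eta\rho\nabla^2L(\Phi(x(t)))\frac{\nabla L(x(t+k))}{\|\nabla L(x(t+k))\|_2} + O(\eta\rho^2). \tag{27}$$
We first prove the first claim. For $k = 0$, $\|x(t+k) - x'(t+k)\|_2 = 0$, and by Lemma F.4 and Equation 27,
$$x(t+1) = x(t) - \eta\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t))) - \eta\rho\nabla^2L(\Phi(x(t)))\frac{\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t)))}{\|\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t)))\|_2} + O(\eta\rho^2) = x'(t+1) + O(\eta\rho^2).$$
The second claim is slightly more complex. By the first claim and Lemma F.4, we have
$$\frac{\nabla L(x(t+1))}{\|\nabla L(x(t+1))\|_2} = \frac{\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))}{\|\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))\|_2} + O(\|x(t+1) - \Phi(x(t+1))\|_2). \tag{28}$$
We first show that $\|\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))\|_2$ is of the order of $\|x(t+1) - \Phi(x(t+1))\|_2 = \Omega(\eta\rho)$, so that the normalized gradient term is stable with respect to small perturbations:
$$\begin{aligned}
\|\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))\|_2 &\ge \|P_{\Phi(x(t+1)),\Gamma}\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))\|_2\\
&= \|\nabla^2L(\Phi(x(t+1)))P_{\Phi(x(t+1)),\Gamma}(x(t+1) - \Phi(x(t+1)))\|_2\\
&\ge \mu\|P_{\Phi(x(t+1)),\Gamma}(x(t+1) - \Phi(x(t+1)))\|_2\\
&\ge \mu\big(\|x(t+1) - \Phi(x(t+1))\|_2 - \|P^\perp_{\Phi(x(t+1)),\Gamma}(x(t+1) - \Phi(x(t+1)))\|_2\big)\\
&\ge \mu\Big(\|x(t+1) - \Phi(x(t+1))\|_2 - \frac{\nu\zeta}{4\mu^2}\|x(t+1) - \Phi(x(t+1))\|_2^2\Big)\\
&\ge \frac\mu2\|x(t+1) - \Phi(x(t+1))\|_2 = \Omega(\eta\rho).
\end{aligned}$$
By Lemma F.7, $\|\Phi(x(t+1)) - \Phi(x(t))\|_2 = O(\eta\rho^2)$. Further, by the first claim and Lemma I.8,
$$\begin{aligned}
&\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1))) - \nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t)))\\
&= \nabla^2L(\Phi(x(t)))(x(t+1) - \Phi(x(t+1))) - \nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t))) + O(\|x(t+1) - \Phi(x(t+1))\|_2\,\|\Phi(x(t+1)) - \Phi(x(t))\|_2)\\
&= \nabla^2L(\Phi(x(t)))(x(t+1) - \Phi(x(t+1))) - \nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t))) + O(\eta\rho^3)\\
&= \nabla^2L(\Phi(x(t)))(x(t+1) - x'(t+1)) + \nabla^2L(\Phi(x(t)))(\Phi(x(t)) - \Phi(x(t+1))) + O(\eta\rho^3) = O(\eta\rho^2).
\end{aligned}$$
Since both vectors have norm $\Omega(\eta\rho)$, this implies
$$\frac{\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))}{\|\nabla^2L(\Phi(x(t+1)))(x(t+1) - \Phi(x(t+1)))\|_2} = \frac{\nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t)))}{\|\nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t)))\|_2} + O(\rho).$$
Combining this with Equation 28,
$$\frac{\nabla L(x(t+1))}{\|\nabla L(x(t+1))\|_2} = \frac{\nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t)))}{\|\nabla^2L(\Phi(x(t)))(x'(t+1) - \Phi(x(t)))\|_2} + O(\rho).$$
By this approximation and Equation 27, $x(t+2) = x'(t+2) + O(\eta\rho^2)$.

Lemma I.10. Assume $t$ satisfies $x(t) \in K^{3h/4}$ and $\|\bar x(t)\|_2 = O(\rho)$. Then
$$\Big\|\bar x(t+1) - \bar x(t) + \eta A(t)\bar x(t) + \eta\rho\frac{A^2(t)\bar x(t)}{\|\bar x(t)\|_2}\Big\|_2 = O(\eta\rho^2).$$

Proof of Lemma I.10. By Lemma I.9, we know
$$\Big\|x(t+1) - x(t) + \eta\bar x(t) + \eta\rho\frac{A(t)\bar x(t)}{\|\bar x(t)\|_2}\Big\|_2 \le O(\eta\rho^2).$$
This implies
$$\Big\|A(t)(x(t+1) - \Phi(x(t))) - \bar x(t) + \eta A(t)\bar x(t) + \eta\rho\frac{A^2(t)\bar x(t)}{\|\bar x(t)\|_2}\Big\|_2 = \Big\|A(t)\Big(x(t+1) - x(t) + \eta\bar x(t) + \eta\rho\frac{A(t)\bar x(t)}{\|\bar x(t)\|_2}\Big)\Big\|_2 \le \zeta\Big\|x(t+1) - x(t) + \eta\bar x(t) + \eta\rho\frac{A(t)\bar x(t)}{\|\bar x(t)\|_2}\Big\|_2 = O(\eta\rho^2). \tag{29}$$
We also have, by Lemma I.8,
$$\|\bar x(t+1) - A(t)(x(t+1) - \Phi(x(t)))\|_2 = \|(A(t+1) - A(t))(x(t+1) - \Phi(x(t+1))) + A(t)(\Phi(x(t)) - \Phi(x(t+1)))\|_2 = O(\eta\rho^2).$$
Plugging in Equation 29, we have
$$\Big\|\bar x(t+1) - \bar x(t) + \eta A(t)\bar x(t) + \eta\rho\frac{A^2(t)\bar x(t)}{\|\bar x(t)\|_2}\Big\|_2 = O(\eta\rho^2).$$

Lemma I.11.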
Under condition of Theorem I.1, assuming there exists t DEC such that x(t DEC) Kh/2 and L(x(t DEC)) 4ζρ, then there exists t DEC2 = t DEC + O(ln(1/η)/η), such that x(t DEC2) is in I1 K3h/4. Furthermore, for any t satisfying t DEC2 t t DEC2 + Θ(ln(1/η)/η), we have that x(t) I1 K3h/4 and Φ(x(t)) Φ(x(t DEC)) = O(ρ2 ln(1/η)). Proof of Lemma I.11. For simplicity, denote C = ln1 ηµ ηµ3 4ζ2 +Θ(ln(1/ρ)/η) = O(ln(1/η)/η). Here the quantity Θ(ln(1/ρ)/η) is the same quantity in the statement of the lemma. We will prove the induction hypothesis for t DEC t t DEC + 2C, x(t 1) ηρλ2 1(t), t > t DEC x(t) (1 ηµ) x(t 1) , x(t 1) ηρλ2 1(t 1), t > t DEC x(t) ηρλ2 1(t) + O(ηρ2), Φ(x(t)) Φ(x(t DEC)) O(ηρ2(t t DEC)), x(t) K3h/4. Published as a conference paper at ICLR 2023 The induction hypothesis holds trivially for t = t DEC. Assume the induction hypothesis holds for t t. By Lemmas F.1 and I.7, x(t DEC) 2 ζ x(t DEC) Φ(x(t DEC)) ζ µ L(x(t DEC)) 4ζ2 µ ρ. Combining with the induction hypothesis, we have x(t) By x(t) K3h/4 and Lemma I.8, we have that Φ(x(t + 1)) Φ(x(t)) O(ηρ2) . Hence we have that Φ(x(t + 1)) Φ(x(t DEC)) Φ(x(t + 1)) Φ(x(t)) + Φ(x(t)) Φ(x(t DEC)) O(ηρ2(t + 1 t DEC)). (30) This proves the third statement of the induction hypothesis. By x(t) = O(ρ) and Lemma I.10, we have that x(t + 1) x(t) + ηA(t) x(t) + ηρA2(t) x(t) x(t) 2 = O(ηρ2) . Analogous to the proof of Lemmas H.1 and H.2, we have 1. If x(t) > ηρλ2 1(t), we would have x(t) ηA(t) x(t) ηρA2(t) x(t) x(t) I ηA(t) ηρA2(t) 1 x(t) x(t) max{ηλ1, 1 ηλD ηρλ2 D 1 x(t) } max{(1 ηλD) x(t) ηρλ2 D, ηλ1 x(t) } max{(1 ηµ) x(t) ηρµ2, ηζ x(t) } Hence we have x(t + 1) max{(1 ηµ) x(t) ηρµ2, ηζ x(t) } + O(ηρ2) (1 ηµ) x(t) . 2. If x(t) 2 ηρλ2 1(t), then by Lemma H.1, we have that x(t) ηA(t) x(t) ηρA2(t) x(t) x(t) 2 ηρλ2 1(t) . Hence by Lemma K.1 x(t + 1) ηρλ2 1(t) + O(ηρ2) ηρλ2 1(t + 1) + O(ηρ2) . Concluding the two cases, we have shown the first and second claim of the induction hypothesis holds. Hence we can show that x(t + 1) 4ζ2 µ ρ. 
Then by Lemma I.7, we have that x(t + 1) Φ(x(t + 1)) 8ζ2 As t t DEC + 2C = t DEC + O(ln(1/η)/η), by Equation 30, Φ(x(t + 1)) Φ(x(t DEC)) O( ρ2 ln η) . This implies dist(x(t + 1), K) dist(x(t DEC), K) + x(t DEC) Φ(x(t DEC)) + Φ(x(t DEC)) Φ(x(t + 1)) + x(t + 1) Φ(x(t + 1)) =h/2 + O(ρ2 ln(1/η)) + O(ρ) 3h/4. This proves the fourth claim of the inductive hypothesis. The induction is complete. Published as a conference paper at ICLR 2023 Now define t DEC2 the minimal t t DEC, such that x(t) ηρλ2 1(t). If t DEC2 > t DEC + C, then by the induction, Lemmas F.1 and I.7, x(t DEC + C) (1 ηµ)C x(t DEC) 4ζ2 x(t DEC) 4ζ2 ζ x(t DEC) Φ(x(t DEC) 4ζ L(t DEC) λ2 1(t DEC + C)ηρ . This is a contradiction. Hence we have t DEC2 t DEC + C. By the induction hypothesis x(t DEC2) I1 K3h/4. Furthermore by induction, for any t satisfying t DEC2 t t DEC + 2C, we have that x(t) ηρλ2 1(t) + O(ηρ2) . By the induction hypothesis x(t) I1 K3h/4 and Φ(x(t)) Φ(x(t DEC)) = O(ρ2 ln(1/ρ)). Lemma I.12. Under condition of Theorem I.1, assuming t satisfy that x(t) I1 K3h/4, then we have that ( Rk(x(t)) 0 Rk(x(t + 1)) + λ2 k(t + 1)ηρ (1 ηµ)(Rk(x(t)) + λ2 k(t)ηρ), Rk(x(t)) 0 Rk(x(t + 1)) O(ηρ2). Proof of Lemma I.12. As x(t) I1, x(t) 2 ζηρ + O(ηρ2). As x(t) 2 = O(ρ), we have x(t)x(t + 1) Kh and Φ(x(t))Φ(x(t + 1)) Kr. We will begin with a quantization technique separating [M] into disjoint continuous subset S1, ..., Sp such that i = j, min k Si,l Sj |λk(t) λl(t)| ρ . By Lemmas I.8 and K.1, we have that for any n [M], |λk(t) λk(t + 1)| = O( 2L(Φ(x(t))) 2L(Φ(x(t + 1))) ) = O( Φ(x(t)) Φ(x(t + 1)) ) This implies min k Si,l Sj |λk(t + 1) λl(t + 1)| ρ O(ηρ2) 0.99ρ . Define P (t) S(i) X k Si vn(t)vn(t)T . By Theorem K.3, for any k, P (t) Sk P (t+1) Sk O( 2L(Φ(x(t))) 2L(Φ(x(t + 1))) ρ ) = O(ηρ) . By Lemma I.10, we have that x(t + 1) x(t) + ηA(t) x(t) + ηρA2(t) x(t) x(t) 2 = O(ηρ2) . We will write x (t + 1) as shorthand of x(t) ηA(t) x(t) ηρA2(t) x(t) x(t) . 
Now we discuss by cases, Published as a conference paper at ICLR 2023 1. If q Pp i=j P (t) S(i) x(t) 2 > maxk Sj λ2 k(t)ηρ > µ2ηρ, by Lemma H.3, v u u t i=j P (t) S(i) x(t + 1) 2 i=j P (t) S(i)x (t + 1) 2 + O(ηρ2) max{ 1 ηλD(t + 1) i=j P (t) S(i) x(t) ηρλD(t + 1)2 Pp i=j P (t) S(i) x(t) η max k Sj λk(t + 1) i=j P (t) S(i) x(t) } + O(ηρ2) i=j P (t) S(i) x(t) ηρµ3 i=j P (t) S(i) x(t) } + O(ηρ2) . This further implies v u u t i=j P (t+1) S(i) x(t + 1) 2 i=j P (t) S(i) x(t + 1) 2 + O(ηρ x(t + 1) ) i=j P (t) S(i) x(t) ηρµ3 i=j P (t) S(i) x(t) } + O(ηρ2) i=j P (t) S(i) x(t) . 2. If q Pp i=j P (t) S(i) x(t) 2 maxk Sj λ2 k(t)ηρ, then by Lemma H.1, we have that i=j P (t) S(i)x (t + 1) 2 ηρ max k Sj λ2 k(t) . Hence we have that v u u t i=j P (t) S(i) x(t + 1) 2 i=j P (t) S(i)x (t + 1) 2 + O(ηρ2) max k Sj λ2 k(t)ηρ + O(ηρ2) max k Sj λ2 k(t + 1)ηρ + O(ηρ2) . This further implies v u u t i=j P (t+1) S(i) x(t + 1) 2 i=j P (t) S(i) x(t + 1) 2 + O(ηρ x(t + 1) ) max k Sj λ2 k(t)ηρ + O(ηρ2) max k Sj λ2 k(t + 1)ηρ + O(ηρ2) . Finally taking into quantization error, as all the eigenvalue in the same group at most differ Dρ, for any i Sj, we have that λ2 i (t + 1) + maxk Sj λ2 k(t + 1) 2Dζρ + D2ρ2. Hence the previous discussion concludes as 1. If Rk(x(t)) 0 Rk(x(t + 1)) + λ2 k(t + 1)ηρ (1 ηµ)(Rk(x(t)) + λ2 k(t).ηρ) Published as a conference paper at ICLR 2023 2. If Rk(x(t)) < 0 Rk(x(t + 1)) O(ηρ2). Lemma I.13. Under condition of Theorem I.1, assuming there exists t DEC2 such that for any t satisfying t DEC2 t t DEC2 + Θ(ln(1/η)/η), we have that x(t) I1 K3h/4 . Then there exists t INV = t DEC2 + O(ln(1/η)/η)) such that for any t satisfying t INV t t INV + Θ(ln(1/η)/η), we have that x(t) ( k [M]Ik) K7h/8 . Φ(x(t)) Φ(x(t DEC2)) = O(ρ2 ln(1/η)) . Proof of Lemma I.13. The proof is almost identical with Lemma I.11 replacing the first two iterative hypothesis to Lemma I.12 and is omitted here. I.2 PHASE II (PROOF OF THEOREM I.3) Proof of Theorem I.3. 
Let t ALIGN = O(ln(1/ρ)/η) be the quantity defined in Lemma I.19. We will inductively prove the following induction hypothesis P(t) holds for t ALIGN t T3/ηρ2 + 1, x(t) Kh/2, t ALIGN τ t | x(τ) Φ(x(τ)), v1(x(τ)) | = Θ(ηρ), t ALIGN τ t max j [2:M] | x(τ) Φ(x(τ)), vj(x(τ)) | = O(ηρ2), t ALIGN τ t Φ(x(τ)) X(ηρ2τ) = O(η ln(1/ρ)), t ALIGN τ t P(t ALIGN) holds due to Lemma I.19. Now suppose P(t) holds, then x(t + 1) Kh. By Lemma I.19 again, | x(t+1) Φ(x(t+1)), v1(x(t+1)) | = Θ(ηρ) and maxj [2:M] | x(t+1) Φ(x(t+1)), vj(x(t+1)) | = O(ηρ2) holds. Now by Lemma I.20, Φ(x(τ + 1)) Φ(x(τ)) + ηρ2P Φ(x(τ)),Γ λ1(t)/2 = O(ηρ3 + η2ρ2) , t ALIGN τ t. By Corollary L.3, let b(x) = Φ(x) λ1( 2L(x))/2, p = ηρ2 and ϵ = O(η + ρ), it holds that Φ(x(τ)) X(ηρ2τ) =O( Φ(x(t ALIGN)) Φ(xinit) + T3ηρ2 + (ρ + η)T3) =O(η ln(1/ρ)), t ALIGN τ t + 1 This implies x(t+1) X(ηρ2(t+1)) 2 x(t+1) Φ(x(t+1)) 2 + Φ(x(t+1)) X(ηρ2(t+1)) 2 = O(η ln(1/ρ)) < h/2. Hence x(t + 1) Kh/2. Combining with P(t) holds, we have that P(t + 1) holds. The induction is complete. Now P( T3/ηρ2 ) is equivalent to our theorem. I.2.1 ALIGNMENT TO TOP EIGENVECTOR We will continue to use the notations introduced in Appendix I.1.3. We further define S = {t| x(t) ηλ2 1 2 ηλ1 ρ + O(ηρ2)} , T = {t| x(t) 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 U = {t|Ω(ρ2) x1(t) 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 Here the constant in O depends on the constant in Ij and will be made clear in Lemma I.16. For s S, define next(s) as the smallest integer greater than s in S. Published as a conference paper at ICLR 2023 Lemma I.14. Under the condition of Theorem I.3, there exist constants C1, C2 < 1 independent of η and ρ, if x1(t) 2 1 2 ηλ1(t) + ηλ2(t)2 2 ηλ2(t) ρ and x(t) ( j [M]Ij) K7h/8, then x(t) 2 C1 ηλ2 1 2 ηλ1 ρ x(t + 1) 2 C2 ηλ2 1 2 ηλ1 ρ Proof of Lemma I.14. By Lemma I.10, if we write x (t+1) as shorthand of x(t) ηA(t) x(t) ηρA2(t) x(t) x(t) , then x(t + 1) x (t + 1) = O(ηρ2). Define Iquad j as {x|Rj(x) 0}. 
Then we can find a surrogate xsur(t) such that xsur(t) ( j [M]Iquad j ) Kh and xsur(t) x(t) 2 = O(ηρ2). We will write x sur(t + 1) as shorthand of xsur(t) ηA(t)xsur(t) ηρA2(t) xsur(t) ζ2 + 1 + (1 1 ζ2 + 1 ) max{ζ2 µ2 As h(1) < 1, we can choose C1 < 1, such that h(C1) < 1. We can further choose C2 = max{(h(C1) + 1)/2, 1 µ2 We will discuss by cases x(t) 2 ηλ4 1 λ2 1(1 ηλD) + (λ2 1 λ2 D)(1 ηλ1)ρ Then x(t) 2 ηλ2 1 2 ηλ1 ρ = λ2 1(2 ηλ1) λ2 1(1 ηλD) + (λ2 1 λ2 D)(1 ηλ1) = λ2 1(2 ηλ1) λ2 1(2 ηλ1 ηλD) λ2 D(1 ηλ1) 1 λ2 D λ2 1 1 ηλ1 2 ηλ1 1 + λ2 D λ2 1 1 ηλ1 2 ηλ1 1 + µ2 In such case we have x(t) xsur(t) xsur(t) = O(ρ) . Then we have x sur(t + 1) x (t + 1) = O(ηρ2). By Lemma H.5, we have that x(t + 1) 2 x(t + 1) x (t + 1) 2 + x(t + 1) x sur(t + 1) + x sur(t + 1) max( ηλ2 1 2 ηλ1 ρ ηρ λ4 D 2λ2 1 , ηρλ2 1 (1 ηλ1) x(t) 2) + O(ηρ2) max(1 λ4 D(2 ηλ1) 2λ4 1 , (2 ηλ1) (1 ηλ1)(1 + µ2 3ζ2 )) ηλ2 1 2 ηλ1 ρ 3ζ2 ) ηλ2 1 2 ηλ1 ρ C2 ηλ2 1 2 ηλ1 ρ . x(t) 2 ηλ4 1 λ2 1(1 ηλD) + (λ2 1 λ2 D)(1 ηλ1)ρ ηλ2 1 1 ηλ1 ρ. Then we have | ηρλ2 D + (1 ηλD) x(t) 2| ηρλ2 1 (1 ηλ1) x(t) 2 λ2 1 λ2 D λ2 1 . |ηρλ2 2 (1 ηλ2) x(t) 2| ηρλ2 1 (1 ηλ1) x(t) 2 λ2 2 λ2 1 . Published as a conference paper at ICLR 2023 By Lemma H.8, x (t + 1) 2 (ηρλ2 1 (1 ηλ1) x(t) 2) x2 1(t) 2 x(t) 2 2 + (1 x2 1(t) 2 x(t) 2 2 ) max{λ2 1 λ2 D λ2 1 , λ2 2 λ2 1 } (ηρλ2 1 (1 ηλ1) x(t) 2) x2 1(t) 2 x(t) 2 2 + (1 x2 1(t) 2 x(t) 2 2 ) max{ζ2 µ2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 For x(t) 2 ηλ2 1 2 ηλ1 ρC1, λ2 2(2 ηλ1) λ2 1(2 ηλ2) + 1 /C1 1 λ2 2 λ2 1 + 1 /C1 1 2C1 After plugging in, we have that x(t + 1) 2 x (t + 1) 2 + O(ηρ2) h(C1) ηλ2 1 2 ηλ1 ρ + O(ηρ2) C2 ηλ2 1 2 ηλ1 ρ. This concludes the proof. Lemma I.15. Under the condition of Theorem I.3, for any t 0 satisfying that (1) x(t) ( j [M]Ij) Kh, (2) t S, it holds that t + 1 S. Moreover, if | x1(t)| Ω(ρ2) and x(t) 2 ηρλ2 1 Ω(ρ2), then it holds that x1(t + 1) Ω(ρ2). Proof of Lemma I.15. As t S, it holds that x(t) ηλ2 1 2 ηλ1 ρ + Θ(ηρ2). 
By Lemma I.10, if we write x (t + 1) as shorthand of x(t) ηA(t + 1) x(t) ηρA2(t) x(t) x(t) , then x(t + 1) x (t + 1) = O(ηρ2). Define Iquad j as {x|Rj(x) 0}. Then we can find a surrogate xsur(t) such that xsur(t) ( j [M]Iquad j ) Kh, and xsur(t) x(t) 2 = O(ηρ2). We will write x sur(t + 1) as shorthand of xsur(t) ηA(t)xsur(t) ηρA2(t) xsur(t) As x(t) = Ω(ηρ), we have xsur(t) x(t) x(t) 2 = O(ρ) . Hence we have that x(t+1) x sur(t+1) = x(t+1) x (t+1) + x (t+1) x sur(t+1) = O(ηρ2) Notice we have xsur(t) 2 ηλ2 1 2 ηλ1 ρ for properly chosen function in the definition S, hence, by Lemma H.5 x sur(t + 1) 2 ηλ2 1 2 ηλ1 ρ. This further implies t + 1 S. We also have | x sur(t + 1), v1 | = | xsur(t), v1 ηλ1 xsur(t), v1 ηρλ2 1 xsur(t), v1 = | xsur(t), v1 |(ηλ1 + ηρλ2 1 xsur(t) 1) We will discuss by cases. Let C satisfies that C = q 1 2 (λ4 2 + λ4 1). Published as a conference paper at ICLR 2023 1. If xsur(t) Cηρ, then as we have λ2 1 C | x sur(t + 1), v1 | | xsur(t), v1 |(λ2 1 C 1) Ω(ρ2). 2. If xsur(t) Cηρ, then as x(t) I2, we have that | xsur(t), v1 | Ω(ηρ). Then as xsur(t) x(t) 2 + O(ηρ2) λ2 1ηρ Ω(ρ2), we have that | x sur(t + 1), v1 | | xsur(t), v1 |( λ2 1ηρ λ2 1ηρ Ω(ρ2) 1) Ω(ρ2). By previous approximation results, we have that x1(t + 1) Ω(ρ2). Lemma I.16. Under the condition of Theorem I.3, for any t 0 satisfying that (1) x(t) ( j [M]Ij) K15h/16, (2) t S, it holds that next(t) is well defined and next(t) t + 2. Proof of Lemma I.16. Following similar argument in Lemma H.1, we have that x(t+1) ( j [M]Ij) Kh. If t + 1 S, then we can apply Lemma I.15 to show that t + 2 S. Lemma I.17. Under the condition of Theorem I.3, there exists constant C > 0 independent of η and ρ, assuming that (1) x(t) ( j [M]Ij) K7h/8, (2) t S, (3) Ω(ρ2) x1(t) , then x1(next(t)) x1(t) O(ηρ2) . Proof of Lemma I.17. This is by standard approximation as in previous proof and Lemma H.9. Lemma I.18. 
Under the condition of Theorem I.3, there exists constant C > 0 independent of η and ρ, assuming that (1) x(t) ( j [M]Ij) K7h/8, (2) t S (3) Ω(ρ2) x1(t) 1 2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 x1(next(t)) min{(1 + Cη) x1(t) , 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 or x1(next(next(t))) min{(1 + Cη) x1(t) , 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 Proof of Lemma I.18. In this proof, we will sometime drop the t in λk(t) or A(t). Applying Lemma I.16, we have next(t) and next(next(t)) are well-defined. We can suppose x1(next(t)) 2 1 2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 , else the result holds already. By assumption, we have x1(t) Ω(ρ2). Using Lemma I.10, x(t + 1) x(t) + ηA x(t) + ηρA2 x(t) x(t) O(ηρ2) . x (t + 1) = x(t) + ηA x(t) + ηρA2 x(t) as the one step update of SAM on the quadratic approximation of the general loss. Now using Lemma I.14 and the induction hypothesis, we have for some C1 and C2 smaller than 1, x(t) C1 ηλ2 1 2 ηλ1 ρ x (t + 1) C2 ηλ2 1 2 ηλ1 ρ. We will discuss by cases, 1 If x(t) C1 ηλ2 1 2 ηλ1 ρ If next(t) = t + 1, then x 1(t + 1) x1(t) = ηρλ2 1 (1 ηλ1) x(t) x(t) (2 C1) ηλ1 + C1ηλ1 Published as a conference paper at ICLR 2023 As we have x1(t) = Ω(ρ2), we have x 1(t + 1) = Ω(ρ2), then as x1(t + 1) x 1(t + 1) = O(ηρ2), this implies x1(t + 1) x 1(t + 1) O(ηρ2) 1 C1 x1(t + 1) O(ηρ2) 1 C1 + 1) x1(t) . If next(t) = t + 2, define x (t + 2) = x (t + 1) ηAx (t + 1) ηρA2 x (t+1) x (t+1) , as x1(t + 1) = Ω(ηρ), by Lemma I.9, we have x (t + 2) x(t + 2) = O(ηρ2). x1(t) = (ηρλ2 1 (1 ηλ1) x(t) )(ηρλ2 1 (1 ηλ1) x (t + 1) ) x(t) x (t + 1) (ηρλ2 1 (1 ηλ1) x(t) ) ηρλ2 1 (1 ηλ1) ηρλ2 1 (1 ηλ1) x(t) x(t) (ηρλ2 1 (1 ηλ1) x(t) ) = ηρλ2 1 (1 ηλ1) ηρλ2 1 (1 ηλ1) x(t) (1 ηλ1)2 + ηλ1 C1 (2 ηλ1) 1 + 4Cη. Combining with | x1(t)| Ω(ρ2), we have that x1(next(t)) (1 + Cη) x1(t) 2 Case 2 x(t) > C1 ηλ2 1 2 ηλ1 ρ, then x(t + 1) C2 ηλ2 1 2 ηλ1 ρ, next(t) = t + 1 By Lemma I.17, x1(t + 1) (1 Cη) x1(t) . As x(next(t)) C2 ηλ2 1 2 ηλ1 , similar to the first case, x1(next(next(t))) (1 + 4Cη) x1(next(t)) (1 + Cη) x1(t) . 
In conclusion, if x1(t) 1 2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 ρ, we would have there exists C > 0 x1(next(t)) (1 + Cη) x1(t) or x1(next(next(t))) (1 + Cη) x1(t) . Lemma I.19. Under the condition of Theorem I.3, there exists constant T2 > 0 independent of η and ρ, we would have that when t = t ALIGN = T2 ln(1/ρ)/η , | x(t) Φ(x(t)), v1(x(t)) | = Θ(ηρ) , max j [2:M] | x(t) Φ(x(t)), vj(x(t)) | = O(ηρ2) . Further if x(t ) Kh holds for t = 0, 1, ..., t LOCAL, then for t satisfying t ALIGN t t LOCAL | x(t) Φ(x(t)), v1(x(t)) | = Θ(ηρ) , max j [2:M] | x(t) Φ(x(t)), vj(x(t)) | = O(ηρ2) . Proof of Lemma I.19. Let C be the constant defined in Lemma I.18. By Lemma I.15, we can suppose WLOG x1(0) ρ2 and 0 S. Define C1 log1+Cη( ηλ2 1 2 ηλ1 /ρ) C2 C1 + lnmax{1 µ2 ζ2 = O(log(1/ρ)/η). We will choose t ALIGNMID as the minimal t S, such that x1(t) 1 2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 Published as a conference paper at ICLR 2023 Then by induction and Lemmas I.17 and I.18, we easily have that for t min{C2 + 1, t ALIGNMID} and t S, we have that x(t) K7h/8 ( j Ij) , x1(t) min{(1 + Cη)t/4 x1(0) , 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 or x1(next(t)) min{(1 + Cη)t/4 x1(0) , 1 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 The detailed induction is analogous to previous inductive argument and is omitted. If t ALIGNMID C1, then we have for the minimal t C1 and t S x1(t) ηλ2 1 2 ηλ1 ρ . This is a contradiction and we have that t ALIGNMID C1. By Lemma I.17, x1(next(t)) x1(t) O(ηρ2) for x1(t) 1 2 ηλ2 1 2 ηλ1 + ηλ2 2 2 ηλ2 ρ and t S and then by Lemma I.18, x(t) x1(t)) 1 ηλ2 1 2 ηλ1 + 3 ηλ2 2 2 ηλ2 for C2 t t ALIGNMID. We will then show that for t t ALIGNMID + C1 iteration, P (2:D) x(t + 1) O(ηρ2). 
For C1 t t ALIGNMID, 1 ηλ2 ηρ λ2 2 x(t) 1 ηλD ηρ λ2 D x(t) 1 λ2 D 2λ2 1 1 µ2 Notice that, 1 ηλ2 ηρ λ2 2 x(t) 1 ηλ2 ηρ λ2 2 x(t) 1 ηλ2 4λ2 2 λ2 1 + 3λ2 2 (2 ηλ2) 1 + 2(λ2 1 λ2 2) λ2 1 + 3λ2 2 1 + 2 P (2:D)(t)x (t + 1) 2 max{1 µ2 2ζ2 } P (2:D)(t) x(t) 2 Now by Lemma K.1 and Theorem K.3, P (2:D)(t) P (2:D)(t + 1) O(ηρ2) v1(t) v1(t + 1) O(ηρ2) λ1(t) λ1(t + 1) O(ηρ2) By Lemma I.10, we have that x (t + 1) x(t + 1) = O(ηρ2). Combining the above, it holds that P (2:D)(t + 1) x(t + 1) max{1 µ2 4ζ2 } P (2:D)(t) x(t) + O(ηρ2) Hence when t = t ALIGN = t ALIGNMID + C2, x(t) x1(t) Ω(ηρ) , P (2:D)(t) x(t) O(ηρ2) . By x(t) I1, we easily have x1(t) = O(ηρ). Hence we conclude that x1(t) = Θ(ηρ) , P (2:D)(t) x(t) = O(ηρ2) . The second claim is just another induction similar to previous steps and is omitted as well. Published as a conference paper at ICLR 2023 I.2.2 TRACKING RIEMANNIAN GRADIENT FLOW We are now ready to show that Φ(x(t)) will track the solution of Equation 7. The main principal of this proof has been introduced in Section 4.3. Lemma I.20. Under the condition of Theorem I.3, for any t satisfying that x(t) Kh, x1(t) = Θ(ηρ), P (2:D)(t) x(t) = O(ηρ2) , it holds that Φ(x(t + 1)) Φ(x(t)) + ηρ2P Φ(x(t)),Γ λ1(t)/2 O(ηρ3 + η2ρ2) . Proof of Lemma I.20. To begin with, we can approximate Φ(x(t + 1)) Φ(x(t)) by its first order Taylor Expansion, by Lemma F.7, Φ(x(t + 1)) Φ(x(t)) Φ(x(t))(x(t + 1) x(t)) = O( x(t + 1) x(t) 2) = O(η2ρ2) . Then by plugging in the update rule and another Taylor Expansion, Φ(x(t))(x(t + 1) x(t)) ηρ Φ(x(t)) 2L (x) L (x) L (x) ηρ2 Φ(x(t)) 2L (x)[ L (x) L (x) , L (x) L (x) ]/2 2 = O(ηρ3). Using Lemma F.3, we have ηρ Φ(x(t)) 2L (x) L (x) L (x) = ηρ L (x) 2Φ(x(t)) L (x) L (x) , L (x) L (x) = O(ηρ L (x) ) . Putting together, we have that Φ(x(t + 1)) Φ(x(t)) ηρ2 Φ(x(t)) 2L (Φ(x(t)))[ L (x) L (x) , L (x) L (x) ]/2 O(η2ρ2 + ηρ3) + O(ηρ L (x) ) . 
As we have x(t) = Θ(ηρ), hence by Lemmas F.2 and I.7, Φ(x(t + 1)) Φ(x(t)) ηρ2 Φ(x(t)) 2L (Φ(x(t)))[ L (x(t)) L (x(t)) , L (x(t)) L (x(t)) ]/2 O(ηρ3 + η2ρ2) Finally, we have that ηρ2 Φ(x(t)) 2L (Φ(x(t)))[ L (x(t)) L (x(t)) , L (x(t)) L (x(t)) ]/2 ηρ2 Φ(x(t)) 2L (Φ(x(t)))[v1(t), v1(t)]/2 O(ηρ3) as the angle between L(x) L(x) and v1(t) is O(ρ). By Lemma F.3, it holds that Φ(x(t)) 2L (Φ(x(t)))[v1(t), v1(t)] = P X,Γ (λ1(t)) Putting together we have that, Φ(x(t + 1)) Φ(x(t)) + ηρ2P X,Γ λ1(t)/2 O(ηρ3 + η2ρ2). It completes the proof. Published as a conference paper at ICLR 2023 I.3 PROOF OF THEOREM 4.5 Proof of Theorem 4.5. By Theorem I.1, there exists constant T1 independent of η, ρ, such that for any T 1 > T1 independent of η, ρ, it holds that max T1 ln(1/ηρ) ηt T 1 ln(1/ηρ) max j [M] Rj(x(t)) = O(ηρ2). max T1 ln(1/ηρ) ηt T 1 ln(1/ηρ) Φ(x(t)) Φ(xinit) = O((η + ρ) ln(1/ηρ)). By Assumption I.2, there exists step T1 ln(1/ηρ) ηt PHASE T 1 ln(1/ηρ), such that max j [M] Rj(x(t PHASE)) = O(ηρ2), Φ(x(t PHASE)) Φ(xinit) = O((η + ρ) ln(1/ηρ)), | x(t PHASE) Φ(x(t PHASE)), v1(x(t PHASE)) | Ω(ρ2). x(t PHASE) 2 λ1(t PHASE)ηρ Ω(ρ2). Hence by Theorem I.3, if we consider a translated process with x (t) = x(t + t PHASE), we would have for any T3 such that the solution X of Equation 7 is well defined, we have that for t = T3 Φ(x (t)) X(ηρ2t) 2 = O(η ln(1/ρ)) . This implies for t satisfying X(ηρ2(t t PHASE)) is well-defined, Φ(x(t)) X(ηρ2(t t PHASE)) 2 = O(η ln(1/ρ)). Finally, as X(ηρ2(t t PHASE)) X(ηρ2t) 2 = O(ηρ2t PHASE) = O(ρ ln(1/ηρ)) = O(η ln(1/ρ)). We have that Φ(x(t)) X(ηρ2t) 2 = O(η ln(1/ρ)). The alignment result is a direct consequence of Theorem I.3. I.4 PROOFS OF COROLLARIES 4.6 AND 4.7 Proof of Corollary 4.6. We will do a Taylor expansion on LMax ρ . By Theorem I.1 and I.3, we have x( T3/ηρ2 )) X(T3) = O(η + ρ) and x( T3/ηρ2 )) Φ(x( T3/ηρ2 ))) 2 = O(ηρ). For convenience, we denote x( T3/ηρ2 ) by x. 
RMax ρ (x) = max v 2 1 ρv T L(x) + ρ2v T 2L(x)v/2 + O(ρ3) Since max v 2 1 v T L(x) 2 = O( x Φ(x) 2) = O(ηρ), it holds that RMax ρ (x) = ρ2 max v 2 1 v T 2L(x)v/2 + O(η2ρ2 + ρ3) = ρ2λ1( 2L(x)) + O(η2ρ2 + ρ3) = ρ2λ1( 2L(X(T3))) + O(ηρ2), which completes the proof. Proof of Corollary 4.7. We choose T such that X(Tϵ) is sufficiently close to X( ), such that λ1(X(Tϵ)) λ1(X( )) + ϵ/2. By Corollary 4.6 (let T3 = Tϵ), we have that for all ρ, η such that η ln(1/ρ) and ρ/η are sufficiently small, RMax ρ (x( Tϵ/(ηρ2) )) ρ2λ1(X(Tϵ))/2 o(1). This further implies RMax ρ (x( Tϵ/(ηρ2) )) ρ2λ1(X( ))/2 ϵρ2+o(1). We also have L(x( Tϵ/(ηρ2) )) infx U L(x) = o(1). Then we can leverage Theorem G.6 and Theorem G.3 to get the desired bound. Published as a conference paper at ICLR 2023 I.5 DERIVATIONS FOR SECTION 4.3 We will first show our derivation of Equation 9. In Phase II, x(t) is O(ηρ)-close to the manifold Γ and therefore it can be shown that x(t) Φ(x(t)) 2 = O(ηρ) holds for every step in Phase II. This also implies that x(t + 1) x(t) 2 = O(ηρ) (See Lemma F.7). Using Taylor expansion around x(t), we have that Φ(x(t + 1)) Φ(x(t)) = Φ(x(t))(x(t + 1) x(t)) + O( x(t + 1) x(t) 2 2) = η Φ(x(t)) L x(t) + ρ L(x(t)) L(x(t)) 2 + O(η2ρ2) . (31) For any x RD, applying Taylor expansion on L x + ρ L(x) L(x) 2 around x, we have that L x + ρ L(x) L(x) 2 = L(x) + ρ 2L(x) L(x) L(x) 2 + ρ2 2 2( L)(x) L(x) L(x) 2 , L(x) L(x) 2 + O(ρ3). (32) Using Equation 32 with x = x(t), plugging in Equation 31 and then rearranging, we have that Φ(x(t + 1)) Φ(x(t)) + ηρ2 2 Φ(x(t)) 2( L)(x(t)) L(x(t)) L(x(t)) 2 , L(x(t)) L(x(t)) 2 = η Φ(x(t)) L(x(t)) ηρ Φ(x(t)) 2L(x(t)) L(x(t)) L(x(t)) 2 + O(η2ρ2 + ηρ3) . By Lemma 3.1, we have that Φ(x(t)) L(x(t)) = 0. Furthermore, by Lemma F.5, we have that Φ(Φ(x(t))) 2L(Φ(x(t))) = 0. This implies that Φ(x(t)) 2L(x(t)) = Φ(Φ(x(t))) 2L(Φ(x(t))) + O( x(t) Φ(x(t)) 2) = O(ηρ) . 
Thus we conclude that Φ(x(t + 1)) Φ(x(t)) = ηρ2 2 Φ(x(t)) 2( L)(x(t)) L(x(t)) L(x(t)) 2 , L(x(t)) L(x(t)) 2 +O(η2ρ2 + ηρ3) . (9) We will then show our derivation of Equation 10 Φ(x(t + 1)) Φ(x(t)) 2 Φ(x(t)) 2( L)(x(t)) L(x(t)) L(x(t)) 2 , L(x(t)) L(x(t)) 2 + O(η2ρ2 + ηρ3) 2 Φ(x(t)) 2( L)(x(t)) v1( 2L(x(t))), v1( 2L(x(t))) + O(η2ρ2 + ηρ3) 2 Φ(x(t)) λ1( 2L(x(t))) + O(η2ρ2 + ηρ3) 2 Φ(Φ(x(t))) λ1( 2L(Φ(x(t)))) + O(η2ρ2 + ηρ3), (10) where the second to last step we use the property of the derivative of eigenvalue (Lemma K.7) and the last step is due to Taylor expansion of Φ( ) λ1( 2L( )) at Φ(x(t)) and the fact that Φ(x(t)) x(t) = O(ηρ). We will finally show our derivation of Equation 12. The update of the gradient (Equation 12) can be viewed as an O(ηρ2)-perturbed version of the update of the iterate in the quadratic case. Note O(ηρ2) is a higher order term comparing to the other two terms, which are on the order of Θ(η2ρ) and Θ(ηρ) respectively. By controlling the error terms, the mechanism and analysis of the implicit alignment between Hessian and gradient still apply to the general case. We can also show that once this alignment happens, it will be kept until the end of our analysis, which is Θ(η 1ρ 2) steps. Published as a conference paper at ICLR 2023 Finally, we derive Equation 12 by Taylor expansion. We first apply Taylor expansion (Equation 32) on the update rule of the iterate of SAM (Equation 3): x(t + 1) = x(t) η L(x(t)) ηρ 2L(x(t)) L(x(t)) L(x(t)) 2 + O(ηρ2). (33) Since phase II happens in an O(ηρ)-neighborhood of manifold Γ, we have x(t + 1) x(t) 2 = O(ηρ). Then by Equation 33 and Taylor expansion on L(x(t + 1)) at x(t), we have that L(x(t + 1))= L(x(t)) 2L(x(t)) x(t + 1) x(t) + O(η2ρ2) = L(x(t)) η 2L(x(t)) L(x(t)) + ρ 2L(x(t)) L(x(t)) L(x(t))) 2 + O(ηρ2) . (34) J ANALYSIS FOR 1-SAM (PROOF OF THEOREM 5.4) The goal of this section is to prove the following theorem. Theorem 5.4. 
Let {x(t)} be the iterates of 1-SAM (Equation 13) and x(0) = xinit U, then under Setting 5.1, for almost every xinit, for all η and ρ such that (η +ρ) ln(1/ηρ) is sufficiently small, with probability at least 1 O(ρ) over the randomness of the algorithm, the dynamics of 1-SAM (Equation 13) can be split into two phases: Phase I (Theorem J.1): 1-SAM follows Gradient Flow with respect to L until entering an O(ηρ) neighborhood of the manifold Γ in O(ln(1/ρη)/η) steps; Phase II (Theorem J.2): 1-SAM tracks the solution of Equation 14, X, the Riemannian gradient flow with respect to Tr( 2L( )) in an O(ηρ) neighborhood of manifold Γ. Quantitatively, the approximation error between the iterates x and the corresponding limiting flow X is O(η1/2 + ρ), that is, x( T3/(ηρ2) ) X(T3) 2 = O(η1/2 + ρ). As mentioned in our proof setups in Appendix E, we will prove Theorem 5.4 under a more general (and weaker) condition, namely Condition E.1 and Assumption 3.2. The only usage of Setting 5.1 in the proof is Theorems 5.2 and E.2, which are restated below. Theorem 5.2. Loss L, set Γ and integer M defined in Setting 5.1 satisfy Assumption 3.2. Theorem E.2. Setting 5.1 implies Condition E.1. Condition E.1. Total loss L = 1 M PM k=1 Lk. For each k [M], Lk is C4, and there exists a (D 1)- dimensional C2-submanifold of RD, Γk, where for all x Γk, x is a global minimizer of Lk, Lk(x) = 0 and rank( 2Lk(x)) = 1. Moreover, Γ = M k=1Γk for Γ defined in Assumption 3.2. Analogous to the full-batch setting, we will split the trajectory into two phases. Theorem J.1 (Phase I). Let {x(t)} be the iterates defined by SAM (Equation 13) and x(0) = xinit U, then under Assumption 3.2 and E.1, for almost every xinit, there exists a constant T1, it holds for sufficiently small (η + ρ) ln 1/ηρ, we have with probability 1 O(ρ), there exists t T1 ln(1/ηρ)/η, such that x(t) Φ(x(t)) 2 = O(ηρ) and Φ(xinit) Φ(x(t)) 2 = O(η1/2 + ρ). 
Theorem J.1 shows that SAM will converges to an O(ηρ) neighborhood of the manifold without getting far away from Φ(x(0)), where we can perform a local analysis on the trajectory of Φ(x(t)). Under Assumptions 3.2 and E.1, we have Tr( 2Lk(x)) = λ1( 2Lk(x)) is differentiable for x Γi. Hence Tr( 2L(x)) = PM k=1 Tr( 2Lk(x)) is also differentiable and we have (14) is well defined for some finite time T2. Theorem J.2 (Phase II). Let {x(t)} be the iterates defined by SAM (Equation 13) under Assumptions 3.2 and E.1, assuming (1) x(0) Φ(x(0)) 2 = O(ηρ) and (2) Φ(xinit) Φ(x(0)) 2 = O(η1/2 + ρ), then for almost every x(0), for any T2 > 0 till which solution of (14) X exists, for sufficiently small (η +ρ) ln 1/(ηρ), we have with probability 1 O(ηρ), for all ηρ2t < T2, Φ(x(t)) X(ηρ2t) 2 = O(η1/2 + ρ) and x(t) Φ(x(t)) 2 = O(ηρ). Published as a conference paper at ICLR 2023 Combining Theorems E.2, J.1 and J.2, the proof of Theorem 5.4 is clear and we deferred it to Appendix J.3. Now we recall our notations for stochastic setting with batch size one. Notations for Stochastic Setting: Since Lk is rank-1 on Γk for each k [M], we can write it as Lk(x) = Λk(x)wk(x)w k (x) for any x Γk, where wk is a continuous function on Γk with pointwise unit norm. Given the loss function Lk, its gradient flow is denoted by mapping ϕk : RD [0, ) RD. Here, ϕk(x, τ) denotes the iterate at time τ of a gradient flow starting at x and is defined as the unique solution of ϕk(x, τ) = x R τ 0 Lk(ϕk(x, t))dt, x RD. We further define the limiting map Φk as Φk(x) = limτ ϕk(x, τ), that is, Φk(x) denotes the convergent point of the gradient flow starting from x. Similar to Definition 3.3, we define Uk = {x RD|Φ(x) exists and Φk(x) Γk} be the attraction set of Γi. We have that each Uk is open and Φk is C 2 on Uk by Lemma B.15 in Arora et al. (2022). In this section we will define K as {X(t) | t [0, T3]} where X is the solution of (14). We will denote h(K) in Lemma E.6 by h. 
Using Theorem D.3, we will assume the update is always well defined. J.1 PHASE I (PROOF OF THEOREM J.1) Proof of Theorem J.1. The proof consists of two steps. 1. Tracking Gradient Flow. By Lemma J.3, with probability 1 ρ2, there exists step t GF = O(1/η) such that x(t GF) Φ(x(t GF)) 2 h/4. Φ(x(t GF)) Φ(xinit) 2 = O(η1/2 + ρ). 2. Decreasing Loss. By Lemma J.7, with probability 1 O(ρ), there exists step t DEC = t GF + O(ln(1/ρ)/η) = O(ln(1/ρ)/η) such that L(x(t DEC)) 2 = O(ρ). Φ(x(t DEC)) Φ(xinit) 2 Φ(x(t DEC)) Φ(x(t GF)) 2 + Φ(x(t GF)) Φ(xinit) 2 = O(η1/2 + ρ). Then by Lemma J.12, with probability 1 O(ρ), there exists step t DEC2 = t DEC + O(ln(1/ηρ)/η) = O(ln(1/ηρ)/η), it holds that x(t DEC2) Φ(x(t DEC2)) 2 = O(ηρ). Φ(x(t DEC2)) Φ(xinit) 2 Φ(x(t DEC2)) Φ(x(t DEC)) 2 + Φ(x(t DEC)) Φ(xinit) 2 = O(η1/2 + ρ). Concluding, let T1 be the constant satisfying t DEC2 T1 ln(1/ηρ)/η, then we have for t = t DEC2 T1 ln(1/ηρ)/η such that x(t) Φ(x(t)) 2 = O(ηρ). Φ(x(t)) Φ(xinit) 2 = O(η1/2 + ρ). J.1.1 TRACKING GRADIENT FLOW Lemma J.3 shows that the iterates x(t) tracks gradient flow to an O(1) neighbor of Γ. Lemma J.3. Under condition of Theorem J.1, with probability 1 O(ρ2), there exists t GF = O(1/η), such that the iterate x(t GF) is O(1) close to the manifold Γ and Φ(x(t GF)) is O(η1/2 + ρ) is close to Φ(xinit). Quantitatively, L(x(t GF)) µh2 32 x(t GF) Φ(x(t GF)) h/4 , Φ(x(t GF)) Φ(xinit) = O(η1/2 + ρ) . Proof of Lemma J.3. Choose C = 1 Published as a conference paper at ICLR 2023 There exists T > 0, such that ϕ(xinit, T) Φ(xinit) 2 Ch/2 . x(t + 1) = x(t) η Lk(x(t) + ρ Lk (x(t)) Lk (x(t)) ) = x(t) η Lk(x(t)) + O(ηρ) . By Theorem L.1, let b(x) = L(x),p = η and ϵ = O(ρ), for sufficiently small η and ρ, the iterates x(t) tracks gradient flow ϕ(xinit, T) in O(1/η) steps in expectation, Quantitatively, with probability 1 ρ2, for t GF = T0 η , we have that x(t GF) ϕ(xinit, T0) 2 = O( p + ϵ) O(η1/2 + ρ) . 
This implies x(t GF) Kh, hence by Taylor Expansion on Φ, Φ(x(t GF)) Φ(xinit) 2 = Φ(x(t GF)) Φ(ϕ(xinit, T)) 2 O( x(t GF) ϕ(xinit, T) 2) O(η1/2 + ρ) . This implies x(t GF) Φ(x(t GF)) 2 x(t GF) ϕ(xinit, T0) 2 + ϕ(xinit, T0) Φ(xinit) 2 + Φ(xinit) Φ(x(t GF)) 2 Ch/2 + O(η1/2 + ρ) Ch h/4 . By Taylor Expansion, L(x(t GF)) ζ x(t GF) Φ(x(t GF)) 2 2/2 µh2 J.1.2 DECREASING LOSS Lemma J.4. Under condition of Theorem J.1, assuming x(t0) Kh/4 and for any t satisfying t0 t t0 + O(ln(1/ηρ)/η), max t0 τ t0+O(ln(1/ηρ)/η) L(x(τ)) µh2 16 , it holds that x(τ) Kh, t0 τ t. Moreover, we have that Φ(x(t)) Φ(x(t0)) = O((η + ρ) ln(1/ηρ)). Proof of Lemma J.4. We will prove by induction. For τ = t0, the result holds trivially. Suppose the result holds for t 1, then for any τ satisfying t0 τ t 1, by Lemmas F.1 and F.8, Φ(x(τ + 1)) Φ(x(τ)) ξηρ L (x(τ)) 2 + νηρ2 + ξη2 L (x(τ)) 2 2 + ξζ2η2ρ2 = O(η2 + ηρ) . Also by Lemma F.1, x(t) Φ(x(t)) 2 h/2 2, this implies, dist(K, x(t)) dist(K, x(t0)) + x(t0) Φ(x(t0)) 2 + Φ(x(t0)) Φ(x(t)) + Φ(x(t)) x(t) 0.99h + O(η2(t t GF)) = 0.99h + O(η ln(1/ηρ)) h . Published as a conference paper at ICLR 2023 Lemma J.5. Under condition of Theorem J.1, if x(τ) Kh, then we have that E[L(x(τ + 1))|x(τ)] L(x(τ)) ηµ 2 L(x(τ)) . Moreover it holds that, E[ln L(x(τ + 1))|x(τ)] ln E[L(x(τ + 1))|x(τ)] ln L(x(τ)) ηµ Proof of Lemma J.5. By Lemma F.8 and Taylor Expansion, E[L(x(τ + 1))|x(τ)] =E L x(τ) η Lk[x(τ) + ρ Lk (x(τ)) Lk (x(τ)) ] |x(τ) E L(x(τ)) η L (x(τ)) , Lk x(τ) + ρ Lk (x(τ)) 2 Lk[x(τ) + ρ Lk (x(τ)) Lk (x(τ)) ] 2 2 L(x(τ)) η L (x(τ)) 2 2 + ηρζ L (x(τ)) 2 + ζη2E[ Lk(x(τ)) 2 2] + ζ3η2ρ2 2 L (x(τ)) 2 2 2 L(x(τ)) . Lemma J.6. Under condition of Theorem J.1, assuming x(t0) Kh/4 and L(x(t0)) µh2 32 , then with probability 1 O(ρ), for any t satisfying t0 t t0 + O(ln(1/ηρ)/η), it holds that x(t) Kh. Moreover, we have that Φ(x(t)) Φ(x(t0)) = O((η + ρ) ln(1/ηρ)). Proof of Lemma J.6. 
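By the union bound and Lemma J.4,
$$\begin{aligned}
\mathbb{P}\Big(\exists\, t_0\le t\le t_0 + O\big(\tfrac{\ln(1/\eta\rho)}{\eta}\big),\ L(x(t)) \ge \tfrac{\mu h^2}{16}\Big)
&\le \sum_{t=t_0}^{t_0+O(\ln(1/\eta\rho)/\eta)} \mathbb{P}\Big(L(x(t)) \ge \tfrac{\mu h^2}{16} \text{ and } L(x(\tau)) \le \tfrac{\mu h^2}{16},\ \forall\, t_0\le\tau\le t-1\Big)\\
&\le \sum_{t=t_0}^{t_0+O(\ln(1/\eta\rho)/\eta)} \mathbb{P}\Big(L(x(t)) \ge \tfrac{\mu h^2}{16} \text{ and } x(\tau) \in K^h,\ \forall\, t_0\le\tau\le t-1\Big).
\end{aligned}$$
Consider each term and apply the union bound again, over the last time $\tau$ before $t$ at which $L(x(\tau)) \le \mu h^2/32$:
$$\mathbb{P}\Big(L(x(t)) \ge \tfrac{\mu h^2}{16} \text{ and } x(\tau) \in K^h,\ \forall\, t_0\le\tau\le t-1\Big) \le \sum_{\tau=t_0}^{t-1}\mathbb{P}\Big(L(x(t)) \ge \tfrac{\mu h^2}{16},\ L(x(\tau)) \le \tfrac{\mu h^2}{32},\ \tfrac{\mu h^2}{16} > L(x(\tau')) > \tfrac{\mu h^2}{32}\ \forall\,\tau+1\le\tau'\le t-1,\ x(\tau') \in K^h\ \forall\,\tau\le\tau'\le t-1\Big).$$
Each of these terms is bounded by
$$\mathbb{P}\Big(L(x(t)) \ge \tfrac{\mu h^2}{16},\ L(x(\tau')) > \tfrac{\mu h^2}{32}\ \forall\,\tau+1\le\tau'\le t-1,\ x(\tau') \in K^h\ \forall\,\tau\le\tau'\le t-1\ \Big|\ L(x(\tau)) \le \tfrac{\mu h^2}{32}\Big).$$

Define a coupled process $\tilde L$ by $\tilde L(\tau+1) = \ln L(x(\tau+1))$ and
$$\tilde L(\tau') = \begin{cases}\ln L(x(\tau')), & \text{if } \tilde L(\tau'-1) = \ln L(x(\tau'-1)) \ge \ln(\mu h^2/32),\\ \tilde L(\tau'-1) - \eta\mu/2, & \text{otherwise.}\end{cases}$$
Then clearly the probability above is at most $\mathbb{P}\big(\tilde L(t) \ge \ln(\mu h^2/16)\big)$.

Consider a fixed $\tau'$ satisfying $\tau+1 \le \tau' \le t$. By Lemma J.5, $\mathbb{E}[\tilde L(\tau'+1) \mid x(\tau')] \le \tilde L(\tau') - \eta\mu/2$. Hence $\tilde L(t) + \eta\mu t/2$ is a supermartingale. Further, if $L(x(\tau'-1)) \ge \mu h^2/32$, then
$$|L(x(\tau'-1)) - L(x(\tau'))| = O(\|x(\tau'-1) - x(\tau')\|) = O(\eta).$$
Using the smoothness of $\log$ at $\mu h^2/32$, which is a positive constant, $|\tilde L(\tau'+1) - \tilde L(\tau')| \le O(\eta) \le C\eta$, where $C$ is a constant independent of $\eta$. This implies $L(x(\tau+1)) \le \mu h^2/(16\sqrt{2})$. Now, by the Azuma-Hoeffding bound (Lemma K.4),
$$\mathbb{P}\big(\tilde L(t) - \tilde L(\tau+1) + (t-\tau-1)\eta\mu/2 > a\big) \le 2\exp\Big(-\frac{a^2}{8(t-\tau-1)(C+\mu)^2\eta^2}\Big).$$
With $a = \ln(\mu h^2/16) - \tilde L(\tau+1) + (t-\tau-1)\eta\mu/2 \ge (\ln 2 + (t-\tau-1)\eta\mu)/2$, we have
$$\mathbb{P}\big(\tilde L(t) > \ln(\mu h^2/16)\big) \le 2\exp\Big(-\frac{(\ln 2 + (t-\tau-1)\eta\mu)^2}{32(t-\tau-1)(C+\mu)^2\eta^2}\Big) \le 2\exp\Big(-\frac{\mu\ln 2}{8(C+\mu)^2\eta}\Big),$$
where the last step uses $(u+v)^2 \ge 4uv$. Hence
$$\mathbb{P}\Big(\exists\, t_0\le t\le t_0 + O\big(\tfrac{\ln(1/\eta\rho)}{\eta}\big),\ L(x(t)) \ge \tfrac{\mu h^2}{16}\Big) \le O\Big(2\exp\Big(-\frac{\mu\ln 2}{8(C+\mu)^2\eta}\Big)\frac{\ln^2(1/\eta\rho)}{\eta^2}\Big) \le \rho.$$
Hence, with probability $1-\rho$, $L(x(t)) \le \mu h^2/16$ for all $t_0 \le t \le t_0 + O(\ln(1/\eta\rho)/\eta)$; combining this with Lemma J.4 completes the proof.

Lemma J.7.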
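Under the condition of Theorem J.1, assume there exists $t_{\mathrm{GF}}$ such that $L(x(t_{\mathrm{GF}})) \le \mu h^2/32$ and $x(t_{\mathrm{GF}}) \in K^{h/4}$. Then with probability $1-O(\rho)$, there exists $t_{\mathrm{DEC}} = t_{\mathrm{GF}} + O(\ln(1/\rho)/\eta)$ such that $x(t_{\mathrm{DEC}})$ is in an $O(\rho)$ neighborhood of $\Gamma$; quantitatively, $\|\nabla L(x(t_{\mathrm{DEC}}))\|_2 \le 4\zeta\rho$. Moreover, the movement of the projection $\Phi(x(\cdot))$ on the manifold is bounded:
$$\|\Phi(x(t_{\mathrm{GF}})) - \Phi(x(t_{\mathrm{DEC}}))\|_2 = O((\eta+\rho)\ln(1/\rho)).$$

Proof of Lemma J.7. For simplicity of writing, define $T_1 \triangleq \frac{2}{\eta\mu}\ln\frac{h^2}{256\rho^3} = O(\ln(1/\rho)/\eta)$. By Lemma J.6, we may assume $x(t) \in K^h$ for $t_{\mathrm{GF}} \le t \le T_1 + t_{\mathrm{GF}}$. Define the indicator $A(t) = \mathbf{1}[\|\nabla L(x(\tau))\|_2 \ge 4\zeta\rho,\ \forall\, t_{\mathrm{GF}} \le \tau \le t]$. By Lemma J.5,
$$\mathbb{E}[L(x(t+1))A(t+1)] \le \mathbb{E}[L(x(t+1))A(t)] \le \Big(1-\frac{\eta\mu}{2}\Big)\mathbb{E}[L(x(t))A(t)].$$
With $T_2 = T_1 + t_{\mathrm{GF}}$, using Lemma F.2, we conclude
$$8\mu\rho^2\,\mathbb{E}A(T_2+1) \le \mathbb{E}[L(x(T_2+1))A(T_2+1)] \le \Big(1-\frac{\eta\mu}{2}\Big)^{T_1}L(x(t_{\mathrm{GF}})) \le 8\mu\rho^3,$$
so $\mathbb{E}A(T_2+1) \le \rho$. This implies $A(T_2+1) = 0$ with probability $1-O(\rho)$, which indicates the existence of $t_{\mathrm{DEC}}$. The second claim is a direct application of Lemma J.6.

Lemma J.8 (A general version of Lemma 5.5). Under Assumption 3.2 and Condition E.1, for $x \in K^h$ and $p \in C$ with $\nabla^2L_k(p) = \Lambda_k(p)w_k(p)w_k(p)^\top$, there exists $s \in \{-1, 1\}$ such that
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = s\,w_k(p) + O(\|x-p\|_2).$$
Further, if $|w_k^\top(x-p)| \ge \|x-p\|_2^{3/2}$, then $s = \operatorname{sign}(w_k^\top(x-p))$. This implies
$$\Big(\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|}\Big)^{\!\top}(x-p) \ge s\,w_k^\top(x-p) - O(\|x-p\|_2^2) \ge |w_k^\top(x-p)| - O(\|x-p\|_2^{3/2}).$$

Proof of Lemma J.8. We calculate the direction of $\nabla L_k(x)/\|\nabla L_k(x)\|$ using two different approximations and compare them to get our result.

1. According to Lemma F.4,
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = \frac{\nabla^2L_k(\Phi_k(x))(x-\Phi_k(x))}{\|\nabla^2L_k(\Phi_k(x))(x-\Phi_k(x))\|_2} + O(\|x-\Phi_k(x)\|_2).$$
Suppose $\nabla^2L_k(\Phi_k(x)) = \Lambda_k(\Phi_k(x))w_k(\Phi_k(x))w_k(\Phi_k(x))^\top$; then
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = \pm w_k(\Phi_k(x)) + O(\|x-\Phi_k(x)\|_2).$$
As $\nabla^2L_k(p) = \Lambda_k(p)w_k(p)w_k(p)^\top$, the Davis-Kahan theorem (Theorem K.3) gives $s \in \{-1, 1\}$ such that $\|w_k(\Phi_k(x)) - s\,w_k(p)\|_2 = O(\|\Phi_k(x) - p\|_2)$. Therefore
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = s\,w_k(p) + O(\|\Phi_k(x) - p\|_2 + \|x-p\|_2).$$
According to Lemma F.1, $\|x - \Phi_k(x)\|_2 \le \|\nabla L_k(x)\|_2/\mu$. This implies
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = s\,w_k(p) + O(\|x-p\|_2). \tag{35}$$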
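Equation (35) is our first statement.

2. By Taylor expansion at $p$, writing $\Lambda_k, w_k$ for $\Lambda_k(p), w_k(p)$,
$$\nabla L_k(x) = \Lambda_k w_k w_k^\top(x-p) + O(\nu\|x-p\|_2^2).$$
That being said, when $|w_k^\top(x-p)| \ge \|x-p\|_2^{3/2}$, we have
$$\|\nabla L_k(x) - \Lambda_k w_k w_k^\top(x-p)\|_2 \le O(\|x-p\|_2^2), \qquad \|\nabla L_k(x)\|_2 \ge \|\Lambda_k w_k w_k^\top(x-p)\|_2 - O(\|x-p\|_2^2) \ge \Omega(\|x-p\|_2^{3/2}).$$
Concluding,
$$\Big\|\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} - \frac{\Lambda_k w_k w_k^\top(x-p)}{\|\Lambda_k w_k w_k^\top(x-p)\|_2}\Big\|_2 \le O(\|x-p\|_2^{1/2}).$$
Hence
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = \operatorname{sign}(w_k^\top(x-p))\,w_k + O(\|x-p\|_2^{1/2}). \tag{36}$$
Comparing (35) and (36), we have $s = \operatorname{sign}(w_k(p)^\top(x-p))$ when $|w_k^\top(x-p)| \ge \|x-p\|_2^{3/2}$.

Lemma J.9. Under the condition of Theorem J.1, for any constant $C > 0$ independent of $\eta, \rho$, there exist constants $C_1 > C_2 > 0$ independent of $\eta, \rho$ such that, if $x(t) \in K^h$ and $C_1\eta\rho \le \|x(t) - \Phi(x(t))\| \le C\rho$, then
$$\mathbb{E}_k[\|x(t+1) - \Phi(x(t+1))\|_2 \mid x(t)] \le \|x(t) - \Phi(x(t))\|_2 - C_2\eta\rho.$$

Proof of Lemma J.9. By Lemma F.2, $\|x(t) - \Phi(x(t))\| = O(\rho)$. Hence, by Taylor expansion,
$$\begin{aligned}
x(t+1) &= x(t) - \eta\nabla L_k\Big(x(t) + \rho\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big)\\
&= x(t) - \eta\nabla L_k(x(t)) - \eta\rho\nabla^2L_k(x(t))\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|} + O(\eta\rho^2)\\
&= x(t) - \eta\nabla L_k(x(t)) - \eta\rho\Lambda_k w_k w_k^\top\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|} + O(\eta\rho^2).
\end{aligned}$$
Here $\Lambda_k, w_k$ denote $\Lambda_k(\Phi(x(t))), w_k(\Phi(x(t)))$. Given $\|x(t) - \Phi(x(t))\| = O(\rho)$, Lemma F.8 yields $\|\Phi(x(t+1)) - \Phi(x(t))\|_2 = O(\eta\rho^2)$ and $\|x(t+1) - x(t)\|_2 = O(\eta\rho)$. This implies $x(t+1) \in K^r$. Further, by Taylor expansion, $\nabla L_k(x(t)) = \Lambda_k w_k w_k^\top(x(t) - \Phi(x(t))) + O(\rho^2)$. By Lemma J.8, for some $s_k(t) \in \{-1, 1\}$,
$$w_k w_k^\top\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|} = s_k(t)w_k + O(\|x(t) - \Phi(x(t))\|_2).$$
We also have
$$s_k(t) = \operatorname{sign}\big(w_k^\top(x(t) - \Phi(x(t)))\big) \quad\text{if}\quad |w_k^\top(x(t) - \Phi(x(t)))| \ge \|x(t) - \Phi(x(t))\|_2^{3/2}. \tag{37}$$
Concluding,
$$x(t+1) - \Phi(x(t+1)) = (x(t) - \Phi(x(t))) - \eta\Lambda_k w_k w_k^\top(x(t) - \Phi(x(t))) - \eta\rho\Lambda_k s_k(t)w_k + O(\eta\rho^2).$$
After taking the square and the expectation over $k$,
$$\begin{aligned}
\mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|_2^2 \mid x(t)] \le{}& \|x(t) - \Phi(x(t))\|_2^2 + \frac{2\eta^2}{M}\sum_{k=1}^M\Lambda_k^2|w_k^\top(x(t) - \Phi(x(t)))|^2 + \frac{2\eta^2\rho^2}{M}\sum_{k=1}^M\Lambda_k^2\\
&- \frac{2\eta\rho}{M}\sum_{k=1}^M\Lambda_k s_k(t)w_k^\top(x(t) - \Phi(x(t))) + O(\eta\rho^2\|x(t) - \Phi(x(t))\| + \eta^2\rho^3).
\end{aligned}$$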
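The alignment statement of Lemma J.8 (and of Lemma 5.5) can be checked numerically. The sketch below assumes a hypothetical scalar model $f_k(x) = a^\top x + \frac12(b^\top x)^2$ with squared loss, picks the label so that $p$ is a zero-loss point, and measures how far the normalized per-sample gradient at $x = p + r\,d$ is from $\operatorname{sign}(w_k^\top(x-p))\,w_k$, where $w_k = \nabla f_k(p)/\|\nabla f_k(p)\|$ is the top eigenvector of $\nabla^2 L_k(p)$ (Lemma J.15). The direction $d$ is chosen with $w_k^\top d = 0.6$, so the sign condition $|w_k^\top(x-p)| \ge \|x-p\|_2^{3/2}$ holds for $r \le 0.36$; the error should then shrink linearly in $r$.

```python
import numpy as np

rng = np.random.default_rng(1)
D = 6
a = rng.normal(size=D)
b = rng.normal(size=D)
p = rng.normal(size=D)


def f(x):
    """Hypothetical scalar model output f_k(x) = a.x + 0.5 * (b.x)^2."""
    return a @ x + 0.5 * (b @ x) ** 2


def grad_f(x):
    return a + (b @ x) * b


y = f(p)  # choose the label so that p is a zero-loss point of L_k


def grad_Lk(x):
    """Gradient of the per-sample loss L_k(x) = 0.5 * (f(x) - y)^2."""
    return (f(x) - y) * grad_f(x)


# At p the Hessian of L_k is grad f(p) grad f(p)^T, so its top eigenvector is:
w = grad_f(p) / np.linalg.norm(grad_f(p))

# Unit direction d with w-component 0.6 (u is a unit vector orthogonal to w).
u = rng.normal(size=D)
u -= (u @ w) * w
u /= np.linalg.norm(u)
d = 0.6 * w + 0.8 * u

results = []
for r in [1e-1, 1e-2, 1e-3]:
    x = p + r * d
    g = grad_Lk(x)
    s = np.sign(w @ (x - p))
    err = np.linalg.norm(g / np.linalg.norm(g) - s * w)
    # Elementary bound for this particular model: err <= 2 ||b||^2 r / ||grad f(p)||.
    bound = 2 * (b @ b) * r / np.linalg.norm(grad_f(p))
    results.append((r, err, s, bound))
    print(f"r = {r:.0e}  error = {err:.2e}  error / r = {err / r:.3f}")
```

The ratio `error / r` stays roughly constant across the three radii, which is the $O(\|x-p\|_2)$ rate in (35).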
We now examine each positive term carefully:
$$\frac{2\eta^2}{M}\sum_{k=1}^M\Lambda_k^2|w_k^\top(x(t) - \Phi(x(t)))|^2 \le 2\zeta^2\eta^2\|x(t) - \Phi(x(t))\|_2^2 = O(\eta^2\rho^2), \qquad \frac{2\eta^2\rho^2}{M}\sum_{k=1}^M\Lambda_k^2 \le 2\zeta^2\eta^2\rho^2 = O(\eta^2\rho^2).$$
This implies
$$\mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|_2^2 \mid x(t)] \le \|x(t) - \Phi(x(t))\|_2^2 - \frac{2\eta\rho}{M}\sum_{k=1}^M\Lambda_k s_k(t)w_k^\top(x(t) - \Phi(x(t))) + O(\eta\rho^2\|x(t) - \Phi(x(t))\| + \eta^2\rho^2).$$
We now lower bound $\sum_{k=1}^M\Lambda_k s_k(t)w_k^\top(x(t) - \Phi(x(t)))$. By Equation (37),
$$\sum_{k=1}^M\Lambda_k s_k(t)w_k^\top(x(t) - \Phi(x(t))) \ge \sum_{k=1}^M\Lambda_k|w_k^\top(x(t) - \Phi(x(t)))| - 2\sum_{k=1}^M\Lambda_k\|x(t) - \Phi(x(t))\|_2^{3/2} \ge \sum_{k=1}^M\Lambda_k|w_k^\top(x(t) - \Phi(x(t)))| - O(\|x(t) - \Phi(x(t))\|_2^{3/2}).$$
For $\sum_{k=1}^M\Lambda_k|w_k^\top(x(t) - \Phi(x(t)))|$, by Lemma F.4,
$$\begin{aligned}
\sum_{k=1}^M\Lambda_k|w_k^\top(x(t) - \Phi(x(t)))| &\ge \sqrt{\sum_{k=1}^M\Lambda_k^2|w_k^\top(x(t) - \Phi(x(t)))|^2} \ge \sqrt{(x(t) - \Phi(x(t)))^\top\nabla^2L(\Phi(x(t)))^2(x(t) - \Phi(x(t)))}\\
&= \|\nabla^2L(\Phi(x(t)))(x(t) - \Phi(x(t)))\|_2 \ge \mu\|(I - \partial\Phi(\Phi(x(t))))(x(t) - \Phi(x(t)))\|_2\\
&\ge \mu\|x(t) - \Phi(x(t))\|_2 - O(\|x(t) - \Phi(x(t))\|_2^2).
\end{aligned}$$
Concluding, $\sum_{k=1}^M\Lambda_k s_k(t)w_k^\top(x(t) - \Phi(x(t))) \ge \mu\|x(t) - \Phi(x(t))\|_2/2$. So
$$\mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|_2^2 \mid x(t)] \le \|x(t) - \Phi(x(t))\|_2^2 - \frac{\mu\eta\rho}{M}\|x(t) - \Phi(x(t))\|_2 + O(\eta\rho^2\|x(t) - \Phi(x(t))\| + \eta^2\rho^2) \le (\|x(t) - \Phi(x(t))\|_2 - C_2\eta\rho)^2.$$
The last inequality holds if $\|x(t) - \Phi(x(t))\|_2 > C_1\eta\rho$. Finally, by Jensen's inequality,
$$\mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|_2 \mid x(t)] \le \|x(t) - \Phi(x(t))\|_2 - C_2\eta\rho.$$

Lemma J.10. Under the condition of Theorem J.1, for any constant $C > 0$ independent of $\eta, \rho$, there exists a constant $C_3 > 0$ independent of $\eta, \rho$ such that, if $x(t) \in K^h$ and $\|x(t) - \Phi(x(t))\| \le C\rho$, then
$$\big|\|x(t+1) - \Phi(x(t+1))\|_2 - \|x(t) - \Phi(x(t))\|_2\big| \le C_3\eta\rho.$$

Proof of Lemma J.10. This is a direct application of Lemma F.8.

Lemma J.11. Under the condition of Theorem J.1, assume $x(t_0) \in K^{h/2}$ and $\|x(t_0) - \Phi(x(t_0))\| \le f(\eta,\rho)$ for some fixed function $f$ with $\Omega(\eta\rho\ln^2(1/\eta\rho)) \le f(\eta,\rho) \le O(\rho)$. Then with probability $1-O(\rho)$, for any $t$ satisfying $t_0 \le t \le t_0 + O(\ln(1/\eta\rho)/\eta)$, it holds that $\|x(t) - \Phi(x(t))\| \le 2f(\eta,\rho)$. Moreover, $\|\Phi(x(t)) - \Phi(x(t_0))\| = O((\eta+\rho)\ln(1/\eta\rho))$.

Proof of Lemma J.11.
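By Lemma J.6, with probability $1-O(\rho)$, $x(t) \in K^h$ for any $t$ satisfying $t_0 \le t \le t_0 + O(\ln(1/\eta\rho)/\eta)$; we assume this holds in the following deduction. By the union bound,
$$\mathbb{P}\Big(\exists\, t_0\le t\le t_0 + O\big(\tfrac{\ln(1/\eta\rho)}{\eta}\big),\ \|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho)\Big) \le \sum_{t=t_0}^{t_0+O(\ln(1/\eta\rho)/\eta)}\mathbb{P}\Big(\|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho) \text{ and } \|x(\tau) - \Phi(x(\tau))\| \le 2f(\eta,\rho),\ \forall\, t_0\le\tau\le t-1\Big).$$
Consider each term and apply the union bound again:
$$\mathbb{P}\Big(\|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho) \text{ and } \|x(\tau) - \Phi(x(\tau))\| \le 2f(\eta,\rho),\ \forall\, t_0\le\tau\le t-1\Big) \le \sum_{\tau=t_0}^{t-1}\mathbb{P}\Big(\|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho),\ \|x(\tau) - \Phi(x(\tau))\| \le f(\eta,\rho),\ f(\eta,\rho) \le \|x(\tau') - \Phi(x(\tau'))\| \le 2f(\eta,\rho)\ \forall\,\tau+1\le\tau'\le t-1\Big).$$
Each of these terms is bounded by
$$\mathbb{P}\Big(\|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho),\ f(\eta,\rho) \le \|x(\tau') - \Phi(x(\tau'))\| \le 2f(\eta,\rho)\ \forall\,\tau+1\le\tau'\le t-1\ \Big|\ \|x(\tau) - \Phi(x(\tau))\| \le f(\eta,\rho)\Big). \tag{38}$$
Now let $C$ be the positive constant satisfying $2f(\eta,\rho) \le C\rho$, let $C_1, C_2$ be the constants corresponding to $C$ in Lemma J.9, and let $C_3$ be the constant corresponding to $C$ in Lemma J.10. By definition $C_3 > C_2$. Write $y(\tau') = \|x(\tau') - \Phi(x(\tau'))\|_2$ and define a coupled process $\tilde y$ by $\tilde y(\tau+1) = y(\tau+1)$ and
$$\tilde y(\tau') = \begin{cases}\|x(\tau') - \Phi(x(\tau'))\|_2, & \text{if } \tilde y(\tau'-1) = \|x(\tau'-1) - \Phi(x(\tau'-1))\|_2 > f(\eta,\rho),\\ \tilde y(\tau'-1) - C_2\eta\rho, & \text{otherwise.}\end{cases}$$
Clearly, Equation (38) is bounded by $\mathbb{P}(\tilde y(t) \ge 2f(\eta,\rho))$. Since $\mathbb{E}[\tilde y(\tau')] \le \tilde y(\tau'-1) - C_2\eta\rho$ by Lemma J.9 and $|\tilde y(\tau') - \tilde y(\tau'-1)| \le C_3\eta\rho$ by Lemma J.10, the process $\tilde y(\tau') + C_2\eta\rho\tau'$ is a supermartingale. By the Azuma-Hoeffding bound (Lemma K.4),
$$\mathbb{P}\big(\tilde y(t) \ge \tilde y(\tau+1) - C_2\eta\rho(t-\tau-1) + h\big) \le 2\exp\Big(-\frac{h^2}{4(t-\tau-1)(C_3+C_2)^2\eta^2\rho^2}\Big).$$
Choosing $h = C_2\eta\rho(t-\tau-1) - \|x(\tau+1) - \Phi(x(\tau+1))\| + 2f(\eta,\rho)$ and using $\|x(\tau+1) - \Phi(x(\tau+1))\| \le f(\eta,\rho) + C_3\eta\rho \le 3f(\eta,\rho)/2$,
$$\begin{aligned}
\mathbb{P}(\tilde y(t) \ge 2f(\eta,\rho)) &\le 2\exp\Big(-\frac{(C_2\eta\rho(t-\tau-1) - \|x(\tau+1) - \Phi(x(\tau+1))\| + 2f(\eta,\rho))^2}{4(t-\tau-1)(C_3+C_2)^2\eta^2\rho^2}\Big)\\
&\le 2\exp\Big(-\frac{(C_2\eta\rho(t-\tau-1) + f(\eta,\rho)/2)^2}{4(t-\tau-1)(C_3+C_2)^2\eta^2\rho^2}\Big) \le 2\exp\Big(-\frac{C_2f(\eta,\rho)}{2(C_3+C_2)^2\eta\rho}\Big) \le \eta^{10}\rho^{10}.
\end{aligned}$$
We then have
$$\mathbb{P}\Big(\exists\, t_0\le t\le t_0 + O\big(\tfrac{\ln(1/\eta\rho)}{\eta}\big),\ \|x(t) - \Phi(x(t))\| \ge 2f(\eta,\rho)\Big) \le \rho.$$

Lemma J.12. Under the condition of Theorem J.1, assume there exists $t_{\mathrm{DEC}}$ such that $x(t_{\mathrm{DEC}}) \in K^{h/2}$ and $\|\nabla L(x(t_{\mathrm{DEC}}))\| \le 4\zeta\rho$. Then with probability $1-O(\rho)$, there exists $t_{\mathrm{DEC2}} = t_{\mathrm{DEC}} + O(\ln(1/\eta\rho)/\eta)$ such that $\|x(t_{\mathrm{DEC2}}) - \Phi(x(t_{\mathrm{DEC2}}))\| \le O(\eta\rho)$.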
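Furthermore, for any $t$ satisfying $t_{\mathrm{DEC2}} \le t \le t_{\mathrm{DEC2}} + \Theta(\ln(1/\eta\rho)/\eta)$, we have $\|\Phi(x(t)) - \Phi(x(t_{\mathrm{DEC}}))\| = O(\rho^2\ln(1/\eta\rho))$.

Proof of Lemma J.12. With probability $1-O(\rho)$, we have $x(t) \in K^h$ (Lemma J.6) and $\|x(t) - \Phi(x(t))\| \le C\rho$ for some constant $C$ (Lemma J.11) for every $t$ satisfying $t_{\mathrm{DEC}} \le t \le t_{\mathrm{DEC}} + O(\ln(1/\eta\rho)/\eta)$; we assume this holds in the following deduction. The second statement then follows directly from Lemma F.8. Let $C_1, C_2$ be the constants in Lemma J.9 corresponding to $C$. For simplicity of writing, define $T_1 \triangleq \frac{C}{C_2\eta}\ln\frac{C}{C_1\eta\rho^2} = O(\ln(1/\eta\rho)/\eta)$. Define the indicator $A(t) = \mathbf{1}[\|x(\tau) - \Phi(x(\tau))\| \ge C_1\eta\rho,\ \forall\, t_{\mathrm{DEC}} \le \tau \le t]$. By Lemma J.9,
$$\begin{aligned}
\mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|A(t+1)] &\le \mathbb{E}[\|x(t+1) - \Phi(x(t+1))\|A(t)] \le \mathbb{E}[\|x(t) - \Phi(x(t))\|A(t)] - C_2\eta\rho\,\mathbb{E}[A(t)]\\
&\le \Big(1 - \frac{C_2\eta}{C}\Big)\mathbb{E}[\|x(t) - \Phi(x(t))\|A(t)].
\end{aligned}$$
With $T_2 = T_1 + t_{\mathrm{DEC}}$, using Lemma F.2, we conclude
$$C_1\eta\rho\,\mathbb{E}A(T_2+1) \le \mathbb{E}[\|x(T_2+1) - \Phi(x(T_2+1))\|_2A(T_2+1)] \le \Big(1 - \frac{C_2\eta}{C}\Big)^{T_1}\|x(t_{\mathrm{DEC}}) - \Phi(x(t_{\mathrm{DEC}}))\| \le C_1\eta\rho^3.$$
This implies $A(T_2+1) = 0$ with probability $1-O(\rho)$, which indicates the existence of $t_{\mathrm{DEC2}}$.

J.2 PHASE II (PROOF OF THEOREM J.2)

Proof of Theorem J.2. We inductively prove that the following hypothesis $\mathcal{P}(t)$ holds with probability $1-O(\eta^3\rho^3t)$ for $t \le T_3/(\eta\rho^2) + 1$:
$$\begin{aligned}
&x(\tau) \in K^{h/2}, && \forall\,\tau \le t,\\
&\|x(\tau) - \Phi(x(\tau))\|_2 \le 2\|x(0) - \Phi(x(0))\|_2 = O(\eta\rho), && \forall\,\tau \le t,\\
&\|\Phi(x(\tau)) - X(\eta\rho^2\tau)\| = O(\eta^{1/2}+\rho), && \forall\,\tau \le t.
\end{aligned}$$
$\mathcal{P}(0)$ holds trivially. Now suppose $\mathcal{P}(t)$ holds; then $x(t+1) \in K^h$. By Lemma J.13, with probability $1-O(\eta^3\rho^3)$, $\|x(t+1) - \Phi(x(t+1))\|_2 \le 2\|x(0) - \Phi(x(0))\|_2 = O(\eta\rho)$. Now we have
$$\|x(\tau) - \Phi(x(\tau))\|_2 \le 2\|x(0) - \Phi(x(0))\|_2 = O(\eta\rho) \ \text{ and } \ x(\tau) \in K^h, \qquad \forall\,\tau \le t+1.$$
By Lemma J.14,
$$\Big\|\Phi(x(\tau+1)) - \Phi(x(\tau)) + \eta\rho^2P_{\Phi(x(\tau)),\Gamma}\nabla\lambda_1\big(\nabla^2L_{k_\tau}(\Phi(x(\tau)))\big)/2\Big\| \le O(\eta\rho^3 + \eta^2\rho^2).$$
Note that $\mathbb{E}_{k_t}\big[P_{\Phi(x(t)),\Gamma}\nabla\lambda_1\big(\nabla^2L_{k_t}(\Phi(x(t)))\big)\big] = P_{\Phi(x(t)),\Gamma}\nabla\operatorname{Tr}\big(\nabla^2L(\Phi(x(t)))\big)$.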
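By Theorem L.1 with $b(x) = -\partial\Phi(x)\nabla\operatorname{Tr}(\nabla^2L(x))/2$, $b_k(x) = -\partial\Phi(x)\nabla\operatorname{Tr}(\nabla^2L_k(x))/2$, $p = \eta\rho^2$ and $\epsilon = O(\eta+\rho)$, it holds with probability $1-O(\eta^3\rho^3)$ that, for all $\tau \le t+1$,
$$\|\Phi(x(\tau)) - X(\eta\rho^2\tau)\| = O\Big(\|\Phi(x(0)) - \Phi(x_{\mathrm{init}})\| + T_3\eta\rho^2 + \sqrt{\eta\rho^2T_3\log\big(2eT_3/(\eta^2\rho^4)\big)} + (\rho+\eta)T_3\Big) = O(\eta^{1/2}+\rho).$$
This implies
$$\|x(t+1) - X(\eta\rho^2(t+1))\|_2 \le \|x(t+1) - \Phi(x(t+1))\|_2 + \|\Phi(x(t+1)) - X(\eta\rho^2(t+1))\|_2 = O(\eta^{1/2}+\rho) < h/2.$$
Hence $x(t+1) \in K^{h/2}$. Combining this with the fact that $\mathcal{P}(t)$ holds with probability $1-O(\eta^3\rho^3t)$, $\mathcal{P}(t+1)$ holds with probability $1-O(\eta^3\rho^3(t+1))$. The induction is complete. Now $\mathcal{P}(\lceil T_3/(\eta\rho^2)\rceil)$ is equivalent to our theorem.

J.2.1 CONVERGENCE NEAR MANIFOLD

Lemma J.13. Under the condition of Theorem J.2, assume $x(t) \in K^h$ for $t_0 \le t \le t_0 + O(1/(\eta\rho^2))$ and $\|x(t_0) - \Phi(x(t_0))\| \le f(\eta,\rho)$ for some fixed function $f$ with $\Omega(\eta\rho\ln^2(1/\eta\rho)) \le f(\eta,\rho) \le O(\rho)$. Then with probability $1-O(\eta^3\rho^3)$, for any $t$ satisfying $t_0 \le t \le t_0 + O(1/(\eta\rho^2))$, it holds that $\|x(t) - \Phi(x(t))\| \le 2f(\eta,\rho)$.

Proof of Lemma J.13. The proof is almost identical to that of Lemma J.11 and is omitted.

J.2.2 TRACKING RIEMANNIAN GRADIENT FLOW

Lemma J.14. Under the condition of Theorem J.2, for any $t$ such that $x(t) \in K^h$ and $\|x(t) - \Phi(x(t))\| = O(\eta\rho\ln^2(1/\eta\rho))$, it holds that
$$\Big\|\Phi(x(t+1)) - \Phi(x(t)) + \eta\rho^2P_{\Phi(x(t)),\Gamma}\nabla\lambda_1\big(\nabla^2L_{k_t}(\Phi(x(t)))\big)/2\Big\| \le O(\eta\rho^3 + \eta^2\rho^2).$$

Proof of Lemma J.14. We abbreviate $k_t$ by $k$ in this proof. By Taylor expansion,
$$\begin{aligned}
x(t+1) &= x(t) - \eta\nabla L_k\Big(x(t) + \rho\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big)\\
&= x(t) - \eta\nabla L_k(x(t)) - \eta\rho\nabla^2L_k(x(t))\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|} - \frac{\eta\rho^2}{2}\partial^2(\nabla L_k)(x(t))\Big[\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}, \frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big] + O(\eta\rho^3).
\end{aligned}$$
Now, as $\|x(t) - \Phi(x(t))\|_2 = O(\eta\rho)$, Lemma F.8 implies $\|x(t+1) - x(t)\|_2 = O(\eta\rho)$. Then
$$\|\Phi(x(t+1)) - \Phi(x(t)) - \partial\Phi(x(t))(x(t+1) - x(t))\|_2 \le \xi\|x(t+1) - x(t)\|_2^2 = O(\eta^2\rho^2).$$
Using Lemma F.6, we have
$$\|\eta\,\partial\Phi(x(t))\nabla L_k(x(t))\|_2 = O(\eta\|x(t) - \Phi(x(t))\|_2^2) = O(\eta^3\rho^2 + \eta\rho^4),$$
$$\Big\|\eta\rho\,\partial\Phi(x(t))\nabla^2L_k(x(t))\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big\|_2 = O(\eta\rho\|x(t) - \Phi(x(t))\|_2) = O(\eta^2\rho^2 + \eta\rho^3).$$
Therefore
$$\Big\|\Phi(x(t+1)) - \Phi(x(t)) + \frac{\eta\rho^2}{2}\partial\Phi(x(t))\,\partial^2(\nabla L_k)(x(t))\Big[\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}, \frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big]\Big\|_2 = O(\eta^2\rho^2 + \eta\rho^3).$$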
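The final step of this proof trades $\partial^2(\nabla L_k)[w_k, w_k]$ for the gradient of the top eigenvalue of $\nabla^2 L_k$, which is the first-order eigenvalue perturbation formula of Lemma K.7. The sketch below checks that formula by central finite differences. It assumes, purely for illustration, an affine symmetric matrix function $A(x) = A_0 + \sum_i x_iA_i$ with random symmetric $A_i$; at a point with a simple top eigenvalue the formula predicts $\partial_i\lambda_1(A(x)) = v_1^\top A_i v_1$.

```python
import numpy as np

rng = np.random.default_rng(2)
n, D = 4, 3


def sym(m):
    return (m + m.T) / 2


A0 = sym(rng.normal(size=(n, n)))
A_dirs = [sym(rng.normal(size=(n, n))) for _ in range(D)]


def A(x):
    """Hypothetical C^1 symmetric matrix function A(x) = A0 + sum_i x_i A_i."""
    return A0 + sum(xi * Ai for xi, Ai in zip(x, A_dirs))


def lam1(x):
    return np.linalg.eigvalsh(A(x))[-1]  # eigvalsh returns eigenvalues in ascending order


x_star = rng.normal(size=D)
evals, evecs = np.linalg.eigh(A(x_star))
v1 = evecs[:, -1]                       # top eigenvector at x_star
spectral_gap = evals[-1] - evals[-2]    # Lemma K.7 needs lambda_1 > lambda_2

# Lemma K.7: d/dx_i lambda_1(A(x)) at x_star equals v1^T (dA/dx_i) v1 = v1^T A_i v1.
formula_grad = np.array([v1 @ Ai @ v1 for Ai in A_dirs])

h = 1e-6
fd_grad = np.array([(lam1(x_star + h * e) - lam1(x_star - h * e)) / (2 * h)
                    for e in np.eye(D)])
print(formula_grad)
print(fd_grad)
```

The two printed vectors agree to finite-difference accuracy; the check would break down if the top eigenvalue were degenerate, which is why the lemma assumes $\lambda_1 > \lambda_2$.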
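Notice finally that, by Lemma J.8,
$$\begin{aligned}
\partial\Phi(x(t))\,\partial^2(\nabla L_k)(x(t))\Big[\frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}, \frac{\nabla L_k(x(t))}{\|\nabla L_k(x(t))\|}\Big] &= \partial\Phi(\Phi(x(t)))\,\partial^2(\nabla L_k)(\Phi(x(t)))[w_k, w_k] + O(\|x(t) - \Phi(x(t))\|_2)\\
&= P_{\Phi(x(t)),\Gamma}\nabla\lambda_1\big(\nabla^2L_k(\Phi(x(t)))\big) + O(\|x(t) - \Phi(x(t))\|_2).
\end{aligned}$$
Hence
$$\Phi(x(t+1)) - \Phi(x(t)) = -\eta\rho^2P_{\Phi(x(t)),\Gamma}\nabla\lambda_1\big(\nabla^2L_{k_t}(\Phi(x(t)))\big)/2 + O(\eta^2\rho^2 + \eta\rho^3).$$
This completes the proof.

J.3 PROOF OF THEOREM 5.4

Proof of Theorem 5.4. By Theorem J.1, there exists a constant $T_1$ independent of $\eta, \rho$ such that, with probability $1-O(\rho)$, there exists $t_{\mathrm{PHASE}} \le T_1\ln(1/\eta\rho)/\eta$ with
$$\|x(t_{\mathrm{PHASE}}) - \Phi(x(t_{\mathrm{PHASE}}))\|_2 = O(\eta\rho), \qquad \|\Phi(x(t_{\mathrm{PHASE}})) - \Phi(x_{\mathrm{init}})\| = O(\eta^{1/2}+\rho).$$
Hence, by Theorem J.2 applied to the translated process $x'(t) = x(t + t_{\mathrm{PHASE}})$, for any $T_3$ such that the solution $X$ of Equation 14 is well defined, we have for $t = \lceil T_3/(\eta\rho^2)\rceil$
$$\|\Phi(x'(t)) - X(\eta\rho^2t)\|_2 = O(\eta\ln(1/\rho)).$$
This implies that for $t$ such that $X(\eta\rho^2(t - t_{\mathrm{PHASE}}))$ is well defined, $\|\Phi(x(t)) - X(\eta\rho^2(t - t_{\mathrm{PHASE}}))\|_2 = O(\eta^{1/2}+\rho)$. Finally, since
$$\|X(\eta\rho^2(t - t_{\mathrm{PHASE}})) - X(\eta\rho^2t)\|_2 = O(\eta\rho^2t_{\mathrm{PHASE}}) = O(\rho^2\ln(1/\eta\rho)) = O(\rho),$$
we have $\|\Phi(x(t)) - X(\eta\rho^2t)\|_2 = O(\eta^{1/2}+\rho)$. We also have $\|x(t) - \Phi(x(t))\|_2 = O(\eta\rho)$ by Theorem J.2.

J.4 PROOFS OF COROLLARIES 5.6 AND 5.7

Proof of Corollary 5.6. We Taylor expand $\mathbb{E}_k[L^{\mathrm{Max}}_{k,\rho}](x)$. By Theorems J.1 and J.2, $\|x(\lceil T_3/\eta\rho^2\rceil) - X(T_3)\|_2 = O(\eta^{1/2}+\rho)$ and $\|\Phi(x(\lceil T_3/\eta\rho^2\rceil)) - x(\lceil T_3/\eta\rho^2\rceil)\|_2 = O(\eta^{1/2}+\rho)$. For convenience, we denote $x(\lceil T_3/\eta\rho^2\rceil)$ by $x$. Then
$$\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}](x) = \mathbb{E}_k\Big[\max_{\|v\|\le 1}\rho v^\top\nabla L_k(x) + \rho^2v^\top\nabla^2L_k(x)v/2\Big] + O(\rho^3).$$
Since $\max_{\|v\|\le 1}|v^\top\nabla L_k(x)| = O(\|x - \Phi(x)\|) = O(\eta^{1/2}+\rho)$, it holds that
$$\begin{aligned}
\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}](x) &= \rho^2\,\mathbb{E}_k\Big[\max_{\|v\|\le 1}v^\top\nabla^2L_k(x)v/2\Big] + O\big((\eta^{1/4}+\rho^{1/4})\rho^2\big)\\
&= \rho^2\,\mathbb{E}_k\Big[\max_{\|v\|\le 1}v^\top\nabla^2L_k(X(T_3))v/2\Big] + O\big((\eta^{1/4}+\rho^{1/4})\rho^2\big)\\
&= \rho^2\operatorname{Tr}\big(\nabla^2L(X(T_3))\big)/2 + O\big((\eta^{1/4}+\rho^{1/4})\rho^2\big).
\end{aligned}$$

Proof of Corollary 5.7. We choose $T_\epsilon$ such that $X(T_\epsilon)$ is close enough to $X(\infty)$ that $\operatorname{Tr}(\nabla^2L(X(T_\epsilon))) \le \operatorname{Tr}(\nabla^2L(X(\infty))) + \epsilon/2$. By Corollary 5.6 (with $T_3 = T_\epsilon$), for all $\rho, \eta$ such that $(\eta+\rho)\ln(1/\eta\rho)$ is sufficiently small,
$$\Big|\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}]\big(x(\lceil T_\epsilon/(\eta\rho^2)\rceil)\big) - \rho^2\operatorname{Tr}\big(\nabla^2L(X(T_\epsilon))\big)/2\Big| \le o(1).$$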
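The last equality in the proof of Corollary 5.6 rests on a simple identity: at a zero-loss point each per-sample Hessian is rank one (Lemma J.15), so $\max_{\|v\|\le 1}v^\top\nabla^2L_kv = \lambda_1(\nabla^2L_k) = \operatorname{Tr}(\nabla^2L_k)$, and averaging over $k$ gives $\mathbb{E}_k[\lambda_1(\nabla^2L_k)] = \operatorname{Tr}(\nabla^2L)$. The sketch below verifies this with randomly drawn, hypothetical values for $\nabla f_k(p)$ and $\ell''(f_k(p), y_k)$, and contrasts it with $\lambda_1(\nabla^2L)$, the quantity that appears for full-batch SAM in Table 2.

```python
import numpy as np

rng = np.random.default_rng(3)
M, D = 8, 20
grads = rng.normal(size=(M, D))         # hypothetical grad f_k(p), k = 1..M
curv = rng.uniform(0.5, 2.0, size=M)    # hypothetical ell''(f_k(p), y_k) > 0

# Per-sample Hessians at a zero-loss point are rank one (Lemma J.15).
H_k = np.stack([c * np.outer(g, g) for c, g in zip(curv, grads)])
H = H_k.mean(axis=0)                    # Hessian of L = (1/M) sum_k L_k

rho = 0.05
mean_top = np.mean([np.linalg.eigvalsh(Hk)[-1] for Hk in H_k])
trace_H = np.trace(H)
top_H = np.linalg.eigvalsh(H)[-1]

print("rho^2/2 * E_k[lambda_1(H_k)]:", rho**2 / 2 * mean_top)
print("rho^2/2 * Tr(H)             :", rho**2 / 2 * trace_H)
print("rho^2/2 * lambda_1(H)       :", rho**2 / 2 * top_H)
```

The first two lines coincide, while the third is strictly smaller whenever the per-sample gradient directions are not all parallel; this gap is what separates the trace regularization of 1-SAM from the top-eigenvalue regularization of full-batch SAM.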
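This further implies
$$\mathbb{E}_k[R^{\mathrm{Max}}_{k,\rho}]\big(x(\lceil T_\epsilon/(\eta\rho^2)\rceil)\big) - \rho^2\operatorname{Tr}\big(\nabla^2L(X(\infty))\big)/2 \le \epsilon\rho^2/2 + o(1).$$
We also have $L(x(\lceil T_\epsilon/(\eta\rho^2)\rceil)) - \inf_{x\in U}L(x) = o(1)$. Then we can leverage Theorems G.6 and G.14 to get the desired bound.

J.5 OTHER OMITTED PROOFS FOR 1-SAM

We use $\ell'(y, y_k)$ and $\ell''(y, y_k)$ to denote $\frac{d\ell(y', y_k)}{dy'}\big|_{y'=y}$ and $\frac{d^2\ell(y', y_k)}{dy'^2}\big|_{y'=y}$.

Lemma J.15. Under Setting 5.1, fix $k \in [M]$. For any $p$ satisfying $\ell(f_k(p), y_k) = 0$, we have
$$\nabla^2L_k(p) = \ell''(f_k(p), y_k)\,\nabla f_k(p)(\nabla f_k(p))^\top.$$

Proof of Lemma J.15. $\ell(f_k(p), y_k) = 0$ implies $\ell'(f_k(p), y_k) = 0$. Then, by the chain rule,
$$\nabla^2L_k(p) = \nabla_p^2\ell(f_k(p), y_k) = \nabla_p[\ell'(f_k(p), y_k)\nabla f_k(p)] = \ell''(f_k(p), y_k)\nabla f_k(p)\nabla f_k(p)^\top + \ell'(f_k(p), y_k)\nabla^2f_k(p) = \ell''(f_k(p), y_k)\nabla f_k(p)\nabla f_k(p)^\top.$$
This concludes the proof.

Proof of Lemma 5.5. By Lemma J.15, as $L(p) = \frac{1}{M}\sum_{k=1}^M L_k(p)$, we have
$$\nabla^2L(p) = \frac{1}{M}\sum_{k=1}^M\frac{\partial^2\ell(y', y_k)}{(\partial y')^2}\Big|_{y'=f_k(p)}\nabla f_k(p)\nabla f_k(p)^\top.$$
By the definition of $\Gamma$ in Setting 5.1, for any $p \in \Gamma$, $\{\nabla f_k(p)\}_{k=1}^n$ are linearly independent, which implies $\nabla f_k(p) \neq 0$ for any $p \in \Gamma$. For any $p \in \Gamma$, as $\nabla f_k(p)/\|\nabla f_k(p)\|$ is well defined and continuous at $p$, there exists an open ball $V$ containing $p$ such that for any $x \in V$, $\|\nabla f_k(x)\|_2 \ge C_1 > 0$ and $\big\|\nabla[\nabla f_k(x)/\|\nabla f_k(x)\|]\big\|_2 \le C_2$ for some constants $C_1$ and $C_2$. Suppose $\nabla L_k(x) \neq 0$; then, by the chain rule, $\nabla L_k(x) = \ell'(f_k(x), y_k)\nabla f_k(x)$. We have
$$\frac{\nabla L_k(x)}{\|\nabla L_k(x)\|} = s\frac{\nabla f_k(x)}{\|\nabla f_k(x)\|}, \quad s = \operatorname{sign}(\ell'(f_k(x), y_k)), \qquad \Big\|\frac{\nabla f_k(x)}{\|\nabla f_k(x)\|} - \frac{\nabla f_k(p)}{\|\nabla f_k(p)\|}\Big\|_2 \le C_2\|x - p\|_2,$$
which completes the proof.

We note that the alignment result in Lemma 5.5 is not directly used in our proof. Instead, we use its generalized version, Lemma J.8, which holds under a more general condition than Setting 5.1, namely Condition E.1.

K TECHNICAL LEMMAS

Lemma K.1 (Corollary 4.3.15 in Horn et al. (2012)). Let $\Sigma, \hat\Sigma \in \mathbb{R}^{D\times D}$ be symmetric and non-negative with eigenvalues $\lambda_1 \ge \dots \ge \lambda_D$ and $\hat\lambda_1 \ge \dots \ge \hat\lambda_D$. Then for any $i$, $|\hat\lambda_i - \lambda_i| \le \|\Sigma - \hat\Sigma\|_2$.

Definition K.2 (Unitary invariant norms). A matrix norm $\|\cdot\|$ on the space of matrices in $\mathbb{R}^{p\times d}$ is unitary invariant if for any matrix $K \in \mathbb{R}^{p\times d}$, $\|UKW\| = \|K\|$ for any unitary matrices $U \in \mathbb{R}^{p\times p}$, $W \in \mathbb{R}^{d\times d}$.

Theorem K.3.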
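[Davis-Kahan $\sin(\theta)$ theorem (Davis et al., 1970)] Let $\Sigma, \hat\Sigma \in \mathbb{R}^{p\times p}$ be symmetric, with eigenvalues $\lambda_1 \ge \dots \ge \lambda_p$ and $\hat\lambda_1 \ge \dots \ge \hat\lambda_p$ respectively. Fix $1 \le r \le s \le p$, let $d \triangleq s - r + 1$, and let $V = (v_r, v_{r+1}, \dots, v_s) \in \mathbb{R}^{p\times d}$ and $\hat V = (\hat v_r, \hat v_{r+1}, \dots, \hat v_s) \in \mathbb{R}^{p\times d}$ have orthonormal columns satisfying $\Sigma v_j = \lambda_jv_j$ and $\hat\Sigma\hat v_j = \hat\lambda_j\hat v_j$ for $j = r, r+1, \dots, s$. Define
$$\delta \triangleq \min\big\{\max\{0, \lambda_s - \hat\lambda_{s+1}\}, \max\{0, \hat\lambda_{r-1} - \lambda_r\}\big\},$$
where $\hat\lambda_0 \triangleq \infty$ and $\hat\lambda_{p+1} \triangleq -\infty$. Then for any unitary invariant norm $\|\cdot\|$,
$$\|\sin\Theta(\hat V, V)\| \le \frac{\|\hat\Sigma - \Sigma\|}{\delta}.$$
Here $\Theta(\hat V, V) \in \mathbb{R}^{d\times d}$, with $\Theta(\hat V, V)_{j,j} = \arccos\sigma_j$ for any $j \in [d]$ and $\Theta(\hat V, V)_{i,j} = 0$ for all $i \neq j \in [d]$, where $\sigma_1 \ge \sigma_2 \ge \dots \ge \sigma_d$ denote the singular values of $\hat V^\top V$; $[\sin\Theta]_{ij}$ is defined as $\sin(\Theta_{ij})$.

Lemma K.4 (Azuma-Hoeffding Bound). Suppose $\{Z_n\}_{n\in\mathbb{N}}$ is a supermartingale with $-\alpha \le Z_{i+1} - Z_i \le \beta$. Then for all $n > 0$, $a > 0$,
$$\mathbb{P}(Z_n - Z_0 \ge a) \le 2\exp\big(-a^2/(2n(\alpha+\beta)^2)\big).$$

Lemma K.5 (Azuma-Hoeffding Bound, Vector Form, Hayes (2003)). Suppose $\{Z_n\}_{n\in\mathbb{N}}$ is an $\mathbb{R}^D$-valued martingale with $\|Z_{i+1} - Z_i\|_2 \le \sigma$. Then for all $n > 0$, $a > 0$,
$$\mathbb{P}(\|Z_n - Z_0\|_2 \ge \sigma(1+a)) \le 2\exp(1 - a^2/2n).$$
In other words, for any $0 < \delta < 1$, with probability at least $1-\delta$, we have that $\|Z_n - Z_0\|_2 \le \sigma\big(1 + \sqrt{2n\log(2e/\delta)}\big)$.

Lemma K.6 (Discrete Gronwall Inequality, Borkar (2009)). Let $\{x(t)\}_{t\in\mathbb{N}}$ be a sequence of nonnegative real numbers, $\{a_n\}_{n\in\mathbb{N}}$ a sequence of positive real numbers, and $C, L > 0$ scalars such that for all $t$,
$$x(t+1) \le C + L\sum_{n=0}^{t}a_nx(n).$$
Then for $T_t = \sum_{n=0}^t a_n$, it holds that $x(t+1) \le Ce^{LT_t}$.

Lemma K.7 (Magnus (1985)). Let $A: \mathbb{R}^D \to \mathbb{R}^{D\times D}$ be any $C^1$ symmetric matrix function and let $x^* \in \mathbb{R}^D$ satisfy $\lambda_1(A(x^*)) > \lambda_2(A(x^*))$, with $v_1$ the top eigenvector of $A(x^*)$. Then
$$\nabla\lambda_1(A(x))\big|_{x=x^*} = \nabla(v_1^\top A(x)v_1)\big|_{x=x^*}.$$

We then present some of the technical lemmas required to prove Lemma H.5.

Lemma K.8. If $0 < c < \frac{b-a}{b^2}$ and $a\sqrt{\frac{a^2+2b^2}{2(1-cb)}} \ge \frac{a^2+b^2}{2-ca-cb}$, then $a > \frac{b}{2}$.

Proof of Lemma K.8. Notice that 2(1 cb) cb2 + ca2 2 cb ca cb2 + ca2 b2 , we have 1 > 1 cb > a The above inequality implies a 1 2b. As c < b a

Lemma K.9.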
When $0 < a < b$ and $0 < c < \frac{b-a}{b^2}$, we have cb2 + ca2(2 cb 2 3ca) (1 cb) c(a2 + b2) 2 ca cb ca2(1 2a2 + b2)2 ca cb (a2 + b2) cb2

Proof of Lemma K.9. Equivalently, we are going to prove (1 cb)b2 1 2 ca cb 1 2 cb + a2 1 cb 2 ca cb + a2(1 2a2 + b2)2 ca cb (a2 + b2) a2(2 cb 2 Further simplifying, we only need to prove (2 cb)(2 ca cb) + a2 1 cb 2 ca cb 1 2(a2 + b2)(2 ca cb) We have the following auxiliary inequalities, (1 cb)b > a 1 cb 2 ca cb 1 b + b a (1 cb)b 1 a = ab a2 + b2 a2 Using the above auxiliary inequalities we have (1 cb)cab2 (2 cb)(2 ca cb) + a2 1 cb 2 ca cb 1 2(a2 + b2)(2 ca cb) ca2b (2 cb)(2 ca cb) + 1 1 2(2 ca cb) a2(1 cb) ca2b (2 cb)(2 ca cb) + ca2(a + b)(1 cb) 2(2 ca cb) 1 ca2b (2 cb)(2 ca cb) + ca2b(1 cb) 2(2 ca cb) 1 1 (2 cb)2 + 1 cb 2(2 cb) 1 3(1 cb)(2 cb) + 6 2(2 cb)2 (cb)2 cb + 4 0

Lemma K.10. When $0 < a < b$, $0 < c < \frac{b-a}{b^2}$ and $a\sqrt{\frac{a^2+2b^2}{2(1-cb)}} \ge \frac{a^2+b^2}{2-ca-cb}$, we have cb2 + ca2(2 cb 2 3ca) (1 cb)cb2 ca2(1 2a2 + b2) 1

Proof of Lemma K.10. Equivalently, we are going to prove, cb3 + a2(2 cb 2 2 cb + a2( 1 cb3 + a2(1 cb 2 We have the auxiliary inequality 1 2 cb > 1 cb3 + a2(1 cb 2 cb3 + a2(1 cb 2 1 Case 1, If 3b3 3a3 0, then 2 Case 2, If 3b3 3a3 > 0, then 3a3) (b2 a2)2 3a3) (b a)(b + a)2 2(b3 ba2) (b a)(b + a)2 b3 (b a)(2b(a + b) (a + b)2) b3 (b a)2(b + a) b3 3 Using Lemma K.8, $a > b/2$, (b a)2(b + a) = (b2 a2)(b a) b2(b a) b3

Lemma K.11. When $0 \le a \le b$, $0 \le c \le \frac{b-a}{b^2}$ and $b^2 \ge a\sqrt{\frac{a^2+2b^2}{2(1-cb)}} \ge \frac{a^2+b^2}{2-ca-cb}$, we have cb2 + ca2(2 cb 2 2a2)(1 cb) cb2

Proof of Lemma K.11. Define F(a) a2(2 cb 2 S(c, b) {a|0 a b, 0 < c b a 2(1 cb) a2 + b2, $a_{\min}(c,b) \triangleq \inf S(c,b)$, $a_{\max}(c,b) \triangleq \sup S(c,b)$, b cb2 Consider d F(a) da = 2a(2 cb 2 2a2)(1 cb) a2 s 1 cb b2 + 1 da2 = 2(2 cb 2 1 cb b2 + 1 1 cb b2 + 1 4 2cb 4ca 3a 1 cb b2 + 1

Define $u \triangleq cb$ and $v \triangleq a/b$; then $u + v \le 1$.
d2F(a) da2 4 2u 4uv 3 4 2u 4u(1 u) 3 1 2 + 1 (1 u)2 4u2 6u + 4 3 1 u (1 u) q 2 1 u, we have $4u^2 - 6u + 4 - 3(1-u) = 4u^2 - 3u + 1 > 0$.

The above inequality shows that $F(a)$ is convex in $a$ for $a_{\min}(c,b) \le a \le a_{\max}(c,b)$. Hence $F(a) \le \max\big(F(a_{\min}(c,b)), F(a_{\max}(c,b))\big)$. Below we use $a_{\min}, a_{\max}$ as shorthands for $a_{\min}(c,b), a_{\max}(c,b)$.

For $F(a_{\min})$, we have $a_{\min}\sqrt{\frac{a_{\min}^2+2b^2}{2(1-cb)}} = \frac{a_{\min}^2+b^2}{2-ca_{\min}-cb}$. This implies 2a2 min)(1 cb) = (1 cb) (a2 min + b2) 2 camin cb + a2 min(1 2a2 min + b2)2 camin cb (a2 min + b2) Hence using Lemma K.9, F(amin) = a2 min(2 cb 2 3camin) (1 cb) c(a2 min + b2) 2 camin cb ca2 min(1 2a2 min + b2)2 camin cb (a2 min + b2)

For $F(a_{\max})$, we know that $a_{\max}$ must satisfy at least one of the following three equalities; we discuss the three cases one by one.

1. $a_{\max}\sqrt{\frac{a_{\max}^2+2b^2}{2(1-cb)}} = \frac{a_{\max}^2+b^2}{2-ca_{\max}-cb}$. In this case we simply redo the calculation in Part 1.
2. $b^2 = a_{\max}\sqrt{\frac{a_{\max}^2+2b^2}{2(1-cb)}}$. This implies 2a2max)(1 cb) = (1 cb)b2 + a2 max(1 2a2 max + b2) 1 Hence using Lemma K.10, F(amax) = a2 max(2 cb 2 3camax) (1 cb)cb2 ca2 max(1 2a2 max + b2) 1
3. $cb^2 = b - a_{\max}$. Define $v \triangleq a_{\max}/b$, so $cb = 1 - v$. Note that $1 - cb = a_{\max}/b$ and $b^2 \ge a_{\max}\sqrt{\frac{b(a_{\max}^2+2b^2)}{2a_{\max}}}$. These imply $a_{\max}^3 + 2a_{\max}b^2 - 2b^3 \le 0$, hence $a_{\max} < 0.9b$, i.e. $v \le 0.9$. By Lemma K.8, $v \ge 0.5$. As $v \in [0.5, 0.9]$, it holds that v(1 + v) + 1 1 + v 2 This implies v2(2 (1 v) 2 3(1 v)v) 2v 2 )v v 1 + v = 1 2 cb 1 F(amax) = a2 max(2 cb 2 3camax) 2amax 2a2max)(1 cb) v2(2 (1 v) 2 3(1 v)v) 2v b2( 1 2 cb 1) .

In conclusion, it holds that $F(a) \le \max\big(F(a_{\min}(c,b)), F(a_{\max}(c,b))\big)$.

L OMITTED PROOFS ON CONTINUOUS APPROXIMATION

In this section we give a general approximation result (Theorem L.1) between a continuous-time flow (Equation 39) and discrete-time (stochastic) iterates (Equation 40) in some compact subset of $\mathbb{R}^D$, denoted by $K$. This result is used multiple times in our analysis of full-batch SAM and 1-SAM.$^{7}$

Let $b: K \to \mathbb{R}^D$ be a $C_1$-Lipschitz function, that is, for all $x, x' \in K$, $\|b(x) - b(x')\|_2 \le C_1\|x - x'\|_2$.
Let $b_k$ be mappings from $K$ to $\mathbb{R}^D$ for $k \in [M]$ satisfying $b(x) = \frac{1}{M}\sum_{k=1}^M b_k(x)$ for all $x \in K$. We consider the continuous-time flow $X: [0, T] \to K$, which is the unique solution of
$$\mathrm{d}X(\tau) = b(X(\tau))\,\mathrm{d}\tau, \tag{39}$$
and the discrete-time iterates $\{x(t)\}_{t\in\mathbb{N}}$, which approximately satisfy
$$x(t+1) \approx x(t) + p\,b_{k_t}(x(t)), \tag{40}$$
where $k_t$ is independently sampled from the uniform distribution over $[M]$ for each $t \in \mathbb{N}$ and $x(t)$ is a deterministic function of $k_0, \dots, k_{t-1}$. We use $\mathcal{F}_t$ to denote the $\sigma$-algebra generated by $k_0, \dots, k_{t-1}$ and $\mathcal{F}$ to denote the filtration $(\mathcal{F}_t)_{t\in\mathbb{N}}$. Thus $x(t)$ is adapted to the filtration $\mathcal{F}$. Note that $b$ is undefined outside $K$; thus, in the analysis we only consider the process stopped immediately upon leaving $K$, that is, $x_K(t) \triangleq x(\min(t, t_K))$, where $t_K \triangleq \min\{t' \in \mathbb{N} \mid x(t') \notin K\}$. If $x(t) \in K$ for all $t \ge 0$, then $t_K = \infty$. It is easy to verify that $t_K$ is a stopping time with respect to the filtration $\mathcal{F}$. For convenience, we denote by $X_K(\tau) = X(\min(\tau, pt_K))$ the stopped continuous counterpart of $x_K$.

Theorem L.1. Suppose there exist constants $C_2, C_3, \epsilon > 0$ satisfying:
1. $\|b_k(x)\|_2 \le C_2$ for any $x \in K$ and $k \in [M]$;
2. $\|b_k(x) - b(x)\|_2 \le C_3$ for any $x \in K$ and $k \in [M]$;
3. $\big\|b_{k_t}(x(t)) - \frac{x(t+1) - x(t)}{p}\big\|_2 \le \epsilon$ for all $t$.

Then for any $0 < \delta < 1$, with probability at least $1-\delta$, it holds that
$$\max_{0\le t\le\lfloor T/p\rfloor}\|x_K(t) - X_K(pt)\| \le H_{p,\delta}e^{C_1T},$$
where $H_{p,\delta} \triangleq \|x(0) - X(0)\|_2 + C_1C_2Tp + 2C_3\sqrt{2pT\log\frac{2eT}{\delta p}} + \epsilon T$.

$^{7}$Though we believe this approximation result is folklore, we cannot find a reference under the exact setting as ours. For completeness, we provide a quick proof in this section.

Proof of Theorem L.1.
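Integrating Equation 39 and summing Equation 40, for any $t \le t_K$ we have
$$X(pt) - X(0) = \int_{0}^{pt}b(X(\tau))\,\mathrm{d}\tau, \tag{41}$$
$$x(t) - x(0) = \sum_{t'=0}^{t-1}\big(x(t'+1) - x(t')\big). \tag{42}$$
Denote $\|x(t) - X(pt)\|_2$ by $E_t$. For $t \le t_K$,
$$\begin{aligned}
E_t \le{}& \|x(0) - X(0)\|_2 + \Big\|\int_0^{pt}b(X(\tau))\,\mathrm{d}\tau - \sum_{t'=0}^{t-1}\big(x(t'+1) - x(t')\big)\Big\|_2\\
\le{}& \|x(0) - X(0)\|_2 + \underbrace{\Big\|\int_0^{pt}b(X(\tau))\,\mathrm{d}\tau - p\sum_{t'=0}^{t-1}b(X(pt'))\Big\|_2}_{(A)} + \underbrace{\Big\|p\sum_{t'=0}^{t-1}b(X(pt')) - p\sum_{t'=0}^{t-1}b(x(t'))\Big\|_2}_{(B)}\\
&+ \underbrace{\Big\|p\sum_{t'=0}^{t-1}b(x(t')) - p\sum_{t'=0}^{t-1}b_{k_{t'}}(x(t'))\Big\|_2}_{(C)} + \underbrace{\Big\|p\sum_{t'=0}^{t-1}b_{k_{t'}}(x(t')) - \sum_{t'=0}^{t-1}\big(x(t'+1) - x(t')\big)\Big\|_2}_{(D)}.
\end{aligned}\tag{43}$$
Below we bound the four terms (A), (B), (C) and (D) in Equation 43.

1. For any $0 \le \tau \le \tau' \le T$, we have $\|X(\tau) - X(\tau')\|_2 = \big\|\int_{\tau}^{\tau'}b(X(s))\,\mathrm{d}s\big\|_2 \le \int_{\tau}^{\tau'}\|b(X(s))\|_2\,\mathrm{d}s \le (\tau' - \tau)C_2$. Thus, by the $C_1$-Lipschitzness of $b$,
$$(A) = \Big\|\int_0^{pt}\big(b(X(\tau)) - b(X(\lfloor\tau/p\rfloor p))\big)\mathrm{d}\tau\Big\|_2 \le \int_0^{pt}\|b(X(\tau)) - b(X(\lfloor\tau/p\rfloor p))\|_2\,\mathrm{d}\tau \le C_1C_2p^2t \le C_1C_2pT.$$
2. By the definition of $E_t$ and the $C_1$-Lipschitzness of $b$, $(B) \le C_1p\sum_{t'=0}^{t-1}E_{t'}$.
3. We claim that for any $0 < \delta < 1$, with probability at least $1-\delta$, for all $0 \le t \le \min(T/p, t_K)$,
$$(C) \le 2C_3\sqrt{2pT\log\frac{2eT}{\delta p}}. \tag{44}$$
Below we prove the claim. Denote $p\sum_{t'=0}^{\min(t,t_K)-1}b(x(t')) - p\sum_{t'=0}^{\min(t,t_K)-1}b_{k_{t'}}(x(t'))$ by $S_t$, which is a martingale with respect to the filtration $\mathcal{F}$ since $t_K$ is a stopping time. Note that $\|S_{t+1} - S_t\|_2 \le p\max_{k\in[M], x\in K}\|b(x) - b_k(x)\|_2 \le pC_3$. By the Azuma-Hoeffding inequality (vector form, Lemma K.5), for any $0 \le t \le T/p$ and $0 < \delta < 1$, with probability at least $1-\delta$,
$$\|S_t\|_2 \le pC_3\big(1 + \sqrt{2t\log(2e/\delta)}\big) \le 2pC_3\sqrt{2t\log(2e/\delta)}.$$
Applying a union bound over $t = 0, \dots, \lfloor T/p\rfloor - 1$ (with $\delta$ replaced by $\delta p/T$), we conclude that with probability at least $1-\delta$,
$$(C) \le 2C_3p\sqrt{\frac{2T}{p}\log\frac{2eT}{\delta p}} = 2C_3\sqrt{2Tp\log\frac{2eT}{\delta p}}.$$
4. By condition 3 of the theorem, $(D) \le \sum_{t'=0}^{t-1}p\big\|b_{k_{t'}}(x(t')) - \frac{x(t'+1) - x(t')}{p}\big\|_2 \le \epsilon pt \le \epsilon T$.

Combining the above upper bounds for (A), (B), (C) and (D), we conclude that for any $0 \le t \le \min(T/p, t_K)$,
$$E_t \le H_{p,\delta} + C_1p\sum_{t'=0}^{t-1}E_{t'}. \tag{45}$$
Applying the discrete Gronwall inequality (Lemma K.6) to Equation 45, we have $E_t \le H_{p,\delta}e^{C_1pt} \le H_{p,\delta}e^{C_1T}$, which completes the proof.

Corollary L.2. If $\min_{0\le\tau\le T}\mathrm{dist}(X(\tau), \mathbb{R}^D\setminus K) > H_{p,\delta}e^{C_1T}$, then with probability at least $1-\delta$, $t_K > \lfloor T/p\rfloor$ and therefore
$$\max_{0\le t\le\lfloor T/p\rfloor}\|x(t) - X(pt)\| \le H_{p,\delta}e^{C_1T}.$$

Proof of Corollary L.2.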
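By Theorem L.1, with probability at least $1-\delta$,
$$\max_{0\le t\le\lfloor T/p\rfloor}\|x_K(t) - X_K(pt)\| \le H_{p,\delta}e^{C_1T}.$$
Therefore $\mathrm{dist}(x_K(t), \mathbb{R}^D\setminus K) \ge \mathrm{dist}(X_K(pt), \mathbb{R}^D\setminus K) - \mathrm{dist}(X_K(pt), x_K(t)) > 0$ for any $0 \le t \le \lfloor T/p\rfloor$, which implies $x_K(t) \notin \mathbb{R}^D\setminus K$, or equivalently, $x_K(t) \in K$. Thus we conclude that $t_K > \lfloor T/p\rfloor$.

Corollary L.3. Suppose $M = 1$ and there exist constants $C_2, \epsilon > 0$ satisfying:
1. $\|b(x)\|_2 \le C_2$ for any $x \in K$;
2. $\big\|b(x(t)) - \frac{x(t+1) - x(t)}{p}\big\| \le \epsilon$ for all $t$.

Then it holds that
$$\max_{0\le t\le\lfloor T/p\rfloor}\|x_K(t) - X_K(pt)\| \le H_pe^{C_1T},$$
where $H_p \triangleq \|x(0) - X(0)\|_2 + C_1C_2Tp + \epsilon T$. Therefore, similar to Corollary L.2, if $\min_{0\le\tau\le T}\mathrm{dist}(X(\tau), \mathbb{R}^D\setminus K) > H_pe^{C_1T}$, then $t_K > \lfloor T/p\rfloor$ and $\max_{0\le t\le\lfloor T/p\rfloor}\|x(t) - X(pt)\| \le H_pe^{C_1T}$.

Proof of Corollary L.3. For any $\delta \in (0, 1]$, choosing $C_3 = 0$ in Theorem L.1 (so that $H_{p,\delta} = H_p$), we have
$$\mathbb{P}\Big(\max_{0\le t\le\lfloor T/p\rfloor}\|x_K(t) - X_K(pt)\| \le H_pe^{C_1T}\Big) \ge 1-\delta.$$
Since $\delta$ can be any number in $(0, 1]$, the above probability is exactly 1.

We end this section with a summary of applications of Theorem L.1 and Corollary L.3 in our proofs (Table 2).

| Setting | $p$ | $b_k$ | $\epsilon$ |
|---|---|---|---|
| Full-batch SAM, Phase I (Lemma I.4) | $\eta$ | $-\nabla L(\cdot)$ | $\rho$ |
| Full-batch SAM, Phase II (Theorem I.3) | $\eta\rho^2$ | $-\partial\Phi(\cdot)\nabla\lambda_1(\nabla^2L(\cdot))/2$ | $\rho+\eta$ |
| 1-SAM, Phase I (Lemma J.3) | $\eta$ | $-\nabla L_k(\cdot)$ | $\rho$ |
| 1-SAM, Phase II (Theorem J.2) | $\eta\rho^2$ | $-\partial\Phi(\cdot)\nabla\operatorname{Tr}(\nabla^2L_k(\cdot))/2$ | $\rho+\eta$ |

Table 2: Summary of applications of Theorem L.1 and Corollary L.3 in our analysis.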
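To see Theorem L.1 at work, the sketch below simulates the stochastic iterates (40) with exact updates ($\epsilon = 0$) for hypothetical linear fields $b_k(x) = c_k - x$, chosen only because their mean-field flow (39), $\mathrm{d}X = (\bar c - X)\,\mathrm{d}\tau$, has the closed-form solution $X(\tau) = \bar c + (X(0) - \bar c)e^{-\tau}$. For these fields $C_1 = 1$ and $C_3 = \max_k\|c_k - \bar c\|_2$, and each update is a convex combination of $x(t)$ and some $c_k$, so the iterates stay in a compact set. The leading term of $H_{p,\delta}$ is of order $\sqrt{p\log(1/p)}$, so we expect the averaged worst-case deviation to drop by roughly an order of magnitude when $p$ goes from $10^{-1}$ to $10^{-3}$.

```python
import numpy as np


def simulate(p, T, x0, centers, rng):
    """Run x(t+1) = x(t) + p * b_{k_t}(x(t)) with hypothetical fields b_k(x) = c_k - x
    and return max_t ||x(t) - X(pt)|| against the exact mean-field ODE solution."""
    c_bar = centers.mean(axis=0)
    x = x0.copy()
    worst = 0.0
    for t in range(int(T / p) + 1):
        X = c_bar + (x0 - c_bar) * np.exp(-p * t)  # exact solution of dX = (c_bar - X) dtau
        worst = max(worst, np.linalg.norm(x - X))
        k = rng.integers(len(centers))             # k_t drawn uniformly from [M]
        x = x + p * (centers[k] - x)
    return worst


rng = np.random.default_rng(4)
centers = rng.normal(size=(4, 2))  # hypothetical c_1, ..., c_4 in R^2
x0 = np.array([3.0, 3.0])          # x(0) = X(0), so the initial-gap term of H vanishes
T = 2.0

errors = {}
for p in [1e-1, 1e-2, 1e-3]:
    errors[p] = np.mean([simulate(p, T, x0, centers, rng) for _ in range(20)])
    print(f"p = {p:.0e}: mean max deviation = {errors[p]:.3f}")
```

Averaging over 20 runs smooths out the randomness of $k_t$; the single-run deviation fluctuates, which is why Theorem L.1 is stated with probability $1-\delta$.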