# IMPROVING CONVERGENCE AND GENERALIZATION USING PARAMETER SYMMETRIES

Published as a conference paper at ICLR 2024

Bo Zhao (University of California San Diego, bozhao@ucsd.edu); Robert M. Gower (Flatiron Institute, rgower@flatironinstitute.org); Robin Walters (Northeastern University, r.walters@northeastern.edu); Rose Yu (University of California San Diego, roseyu@ucsd.edu)

ABSTRACT

In many neural networks, different values of the parameters may result in the same loss value. Parameter space symmetries are loss-invariant transformations that change the model parameters. Teleportation applies such transformations to accelerate optimization. However, the exact mechanism behind this algorithm's success is not well understood. In this paper, we show that teleportation not only speeds up optimization in the short term, but also gives overall faster time to convergence. Additionally, teleporting to minima with different curvatures improves generalization, which suggests a connection between the curvature of the minimum and generalization ability. Finally, we show that integrating teleportation into a wide range of optimization algorithms and optimization-based meta-learning improves convergence. Our results showcase the versatility of teleportation and demonstrate the potential of incorporating symmetry in optimization.

1 INTRODUCTION

Given a deep neural network architecture and a dataset, there may be multiple points in the parameter space that correspond to the same loss value. Despite having the same loss, the gradients and learning dynamics originating from these points can be very different (Kunin et al., 2021; Van Laarhoven, 2017; Grigsby et al., 2022). Parameter space symmetries are transformations of the parameters that leave the loss function invariant. They allow us to teleport between points in the parameter space on the same level set of the loss function (Armenta et al., 2023).
In particular, teleporting to a steeper point in the loss landscape leads to faster optimization. Despite the empirical evidence, the exact mechanism by which teleportation improves convergence on non-convex objectives remains elusive. Previous work shows that the gradient norm increases momentarily after a teleportation, but could not show that this results in overall faster convergence (Zhao et al., 2022). In this paper, we provide theoretical guarantees on the convergence rate. In particular, we show that stochastic gradient descent (SGD) with teleportation converges to a basin of stationary points, where every point reachable by teleportation is also stationary. We also provide conditions under which one teleportation guarantees optimality of the entire gradient flow trajectory.

Previous applications of teleportation are limited to accelerating optimization. The second part of this paper explores a different objective: improving generalization. We relate properties of minima to their generalization ability and optimize them using teleportation. We empirically verify that certain sharpness metrics are correlated with generalization (Keskar et al., 2017), although teleporting towards flatter regions has negligible effects on the validation loss. Additionally, we hypothesize that generalization also depends on the curvature of minima. For fully connected networks, we derive an explicit expression for estimating curvatures and show that teleporting towards larger curvatures improves the model's generalizability.

To demonstrate the wide applicability of parameter space symmetry, we extend teleportation to standard optimization algorithms beyond SGD, including momentum, AdaGrad, RMSProp, and Adam. Experimentally, teleportation improves the convergence speed of these algorithms.
Inspired by conditional programming and optimization-based meta-learning (Andrychowicz et al., 2016), we also propose a meta-optimizer that learns where to move parameters within a loss level set. This approach avoids the computational cost of optimizing on group manifolds. It also improves upon existing meta-learning methods, which are restricted to local updates. The convergence speedup, the applications to improving generalization, and the ability to integrate with different optimizers demonstrate the potential of improving optimization using symmetry. In summary, our main contributions are:

- theoretical guarantees that teleportation accelerates the convergence rate of SGD;
- a way to quantify the curvature of a minimum, and evidence of its correlation with generalization;
- a teleportation-based algorithm to improve generalization;
- various optimization algorithms with integrated teleportation, including momentum, AdaGrad, and optimization-based meta-learning.

2 RELATED WORK

Parameter space symmetry. Continuous symmetries have been identified in the parameter space of various architectures, including homogeneous activations (Badrinarayanan et al., 2015; Du et al., 2018), radial rescaling activations (Ganev et al., 2022), and softmax and batchnorm functions (Kunin et al., 2021). Permutation symmetry has been linked to the structure of minima (Şimşek et al., 2021; Entezari et al., 2022). Quiver representation theory provides a more general framework for symmetries in neural networks with pointwise (Armenta & Jodoin, 2021) and rescaling activations (Ganev & Walters, 2022). A new class of nonlinear and data-dependent symmetries is identified by Zhao et al. (2023). Since symmetries define transformations of parameters within a level set of the loss function, these works are the basis of the teleportation method discussed in our paper. Knowledge of parameter space symmetry motivates new optimization methods.
One line of work seeks algorithms that are invariant to symmetry transformations (Neyshabur et al., 2015; Meng et al., 2019). Others search the orbit for parameters that can be optimized faster (Armenta et al., 2023; Zhao et al., 2022). We build on the latter in two ways: we provide a theoretical analysis of the improvement in convergence rate, and we augment the teleportation objective to improve generalization.

Initializations and restarts. Teleportation before training changes the initialization of parameters, which is known to affect the training dynamics. For example, imbalance between layers at initialization affects the convergence of gradient flows in two-layer models (Tarmoun et al., 2021). Different initializations, among other sources of variance, also lead to different model performance after convergence (Dodge et al., 2020; Bouthillier et al., 2021; Ramasinghe et al., 2022). Beyond initialization, teleportation can move the parameters within a level set multiple times throughout training. Teleportation during training re-initializes the parameters to a point with the same loss. Its effect can resemble a warm restart (Loshchilov & Hutter, 2017), which encourages parameters to move to more stable regions by periodically increasing the learning rate. Compared to restarts, teleportation leads to a smaller temporary increase in loss and provides more control over where the parameters move.

Sharpness of minima and generalization. The sharpness of minima has been linked to the generalization ability of models both empirically and theoretically (Hochreiter & Schmidhuber, 1997; Keskar et al., 2017; Petzka et al., 2021; Ding et al., 2022; Zhou et al., 2020). This link motivates optimization methods that find flatter minima (Chaudhari et al., 2017; Foret et al., 2021; Kwon et al., 2021; Kim et al., 2022). We employ teleportation to search for flatter points along the loss level sets.
The sharpness of a minimum is often defined using properties of the Hessian of the loss function, such as the number of small eigenvalues (Keskar et al., 2017; Chaudhari et al., 2017; Sagun et al., 2017) or the product of the top $k$ eigenvalues (Wu et al., 2017). Alternatively, sharpness can be characterized by the maximum loss within a neighborhood of a minimum (Keskar et al., 2017; Foret et al., 2021; Kim et al., 2022). It can also be approximated by the growth in the loss curve averaged over random directions (Izmailov et al., 2018). However, sharpness does not always capture generalization (Dinh et al., 2017; Andriushchenko et al., 2023): some reparametrizations do not affect generalization but can lead to minima with different sharpness.

3 THEORETICAL GUARANTEES FOR IMPROVING OPTIMIZATION

In this section, we provide a theoretical analysis of teleportation. We show that with teleportation, SGD converges to a basin of stationary points. Building on its relation to Newton's method, we show that teleportation leads to a mixture of linear and quadratic convergence. Lastly, for certain loss functions, one teleportation guarantees optimality of the entire gradient flow trajectory.

Symmetry Teleportation. We briefly review the symmetry teleportation algorithm (Zhao et al., 2022), which searches for steeper points in a loss level set to accelerate gradient descent. Consider the optimization problem
$$w^* = \arg\min_{w \in \mathbb{R}^d} L(w), \qquad L(w) \stackrel{\text{def}}{=} \mathbb{E}_{\xi \sim \mathcal{D}}\left[L(w, \xi)\right],$$
where:

- $\mathcal{D}$ is the data distribution and $\xi$ is data sampled from $\mathcal{D}$;
- $L$ is the loss;
- $w$ are the parameters of the model;
- $\mathbb{R}^d$ is the parameter space.

Let $G$ be a group acting on the parameter space such that $L(w) = L(g \cdot w)$ for all $g \in G$ and $w \in \mathbb{R}^d$. Symmetry teleportation uses gradient ascent to find the group element $g^*$ that maximizes the magnitude of the gradient. It then applies $g^*$ to the parameters, which leaves the loss value unchanged:
$$w \leftarrow g^* \cdot w, \qquad g^* = \arg\max_{g \in G} \|\nabla L(g \cdot w)\|^2.$$
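The sketch below illustrates the teleportation step above on a toy problem. It rests on assumptions that are not from the paper:

- a two-layer linear model with scalar weights, $f(x) = w_2 w_1 x$, trained with mean squared error;
- the group of positive reals acting by $g \cdot (w_1, w_2) = (g w_1, w_2 g^{-1})$, which leaves $w_2 w_1$ and hence the loss unchanged.

This group is not compact, and the gradient norm grows without bound along the orbit. The sketch therefore searches over a finite grid of group elements instead of running gradient ascent on the group, which guarantees that the maximizer exists.

```python
# Toy symmetry teleportation. Assumptions (not from the paper): scalar two-layer
# linear model f(x) = w2 * w1 * x, mean squared error, and a finite grid of
# group elements standing in for gradient ascent on the group.


def loss(w1, w2, xs, ys):
    """L(w) = mean_i 0.5 * (w2 * w1 * x_i - y_i)^2."""
    return sum(0.5 * (w2 * w1 * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(w1, w2, xs, ys):
    """Exact gradient (dL/dw1, dL/dw2) of the loss above."""
    s = sum((w2 * w1 * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    return w2 * s, w1 * s


def grad_norm_sq(w1, w2, xs, ys):
    d1, d2 = grad(w1, w2, xs, ys)
    return d1 * d1 + d2 * d2


def act(g, w1, w2):
    """Group action g . (w1, w2) = (g * w1, w2 / g); w2 * w1, hence the loss, is unchanged."""
    return g * w1, w2 / g


def teleport(w1, w2, xs, ys, group_grid):
    """Pick g* maximizing ||grad L(g . w)||^2 over the grid, then return g* . w."""
    g_star = max(group_grid, key=lambda g: grad_norm_sq(*act(g, w1, w2), xs, ys))
    return act(g_star, w1, w2)


xs, ys = [1.0, 2.0], [2.0, 4.0]
grid = [2.0 ** (k / 4) for k in range(-8, 9)]  # g in [0.25, 4]; includes g = 1
w1, w2 = 1.0, 1.0
t1, t2 = teleport(w1, w2, xs, ys, grid)
print(loss(w1, w2, xs, ys), loss(t1, t2, xs, ys))
print(grad_norm_sq(w1, w2, xs, ys), grad_norm_sq(t1, t2, xs, ys))
```

In this example the loss stays at 1.25, while the squared gradient norm rises from 12.5 to about 100.39. The teleport has moved the parameters to a steeper point on the same level set.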
3.1 TELEPORTATION AND SGD

At each iteration $t \in \mathbb{N}^+$ of SGD, we choose a group element $g_t \in G$ and teleport before each gradient step as follows:
$$w_{t+1} = g_t \cdot w_t - \eta \nabla L(g_t \cdot w_t, \xi_t). \tag{1}$$
Here:

- $\eta$ is a learning rate;
- $\nabla L(w_t, \xi_t)$ is the gradient of $L(w_t, \xi_t)$ with respect to the parameters $w$;
- $\xi_t \sim \mathcal{D}$ is a mini-batch of data sampled i.i.d. from the data distribution at each iteration.

We choose the group element that maximizes the gradient norm. With this choice, the following theorem shows that the iterates in equation 1 converge to a basin of stationary points, where all points that can be reached via teleportation are also stationary points (visualized in Figure 1).

Theorem 3.1 (Smooth non-convex). Let $L(w, \xi)$ be $\beta$-smooth and let $\sigma^2 \stackrel{\text{def}}{=} L(w^*) - \mathbb{E}\left[\inf_w L(w, \xi)\right]$. Consider the iterates $w_t$ given by equation 1 where $g_t \in \arg\max_{g \in G} \|\nabla L(g \cdot w_t)\|^2$, which we assume exists.¹ If $\eta = \frac{1}{\beta\sqrt{T-1}}$, then
$$\min_{t=0,\dots,T-1} \mathbb{E}\left[\max_{g \in G} \|\nabla L(g \cdot w_t)\|^2\right] \le \frac{2\beta}{\sqrt{T-1}}\, \mathbb{E}\left[L(w_0) - L(w^*) + \beta\sigma^2\right], \tag{2}$$
where the expectation is the total expectation with respect to the data $\xi_t$ for $t = 0, \dots, T-1$.

¹ For instance, when $G$ is compact and $\|\nabla L(g \cdot w_t)\|^2$ is continuous over $G$, or when the gradient is a coercive function and $G$ is bounded.

Figure 1: With teleportation, SGD converges to a basin where all points on the level set are stationary points.

This theorem is an improvement over vanilla SGD, for which we would instead have
$$\min_{t=0,\dots,T-1} \mathbb{E}\left[\|\nabla L(w_t)\|^2\right] \le \frac{2\beta}{\sqrt{T-1}}\, \mathbb{E}\left[L(w_0) - L(w^*) + \beta\sigma^2\right].$$
The vanilla bound only guarantees that there exists a single point $w_t$ at which the gradient norm will eventually be small. In contrast, our result in equation 2 guarantees that the gradient norm will be small at all points of the orbit $\{g \cdot w_t : g \in G\}$. For strictly convex loss functions, $\max_{g \in G} \|\nabla L(g \cdot w)\|^2$ is non-decreasing in $L(w)$. In this case, the value of $L$ after $T$ steps of SGD with teleportation is smaller than after $T$ steps of vanilla SGD (Proposition A.2).
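The sketch below runs the update in equation 1 deterministically on a toy problem. It assumes a convex quadratic $L(w) = \frac{1}{2}(a_1 w_1^2 + a_2 w_2^2)$ and uses full-batch gradients in place of mini-batches. The level sets of this loss are ellipses. The compact group $g_\theta = A^{-1/2} R_\theta A^{1/2}$ maps each level set to itself, where $A = \mathrm{diag}(a_1, a_2)$ and $R_\theta$ is a rotation, so it is a valid symmetry group for this loss. The teleport step approximates $g_t$ by a grid search over $\theta$. The objective, group, grid, and step size are all illustrative assumptions, not the paper's experimental setup.

```python
import math

A = (1.0, 10.0)  # assumed diagonal Hessian of the toy quadratic


def loss(w):
    return 0.5 * (A[0] * w[0] ** 2 + A[1] * w[1] ** 2)


def grad(w):
    return (A[0] * w[0], A[1] * w[1])


def act(theta, w):
    """g_theta . w = A^{-1/2} R_theta A^{1/2} w, which leaves loss(w) unchanged."""
    u0, u1 = math.sqrt(A[0]) * w[0], math.sqrt(A[1]) * w[1]
    c, s = math.cos(theta), math.sin(theta)
    return ((c * u0 - s * u1) / math.sqrt(A[0]), (s * u0 + c * u1) / math.sqrt(A[1]))


def teleport(w, n_grid=360):
    """Approximate g_t = argmax_theta ||grad L(g_theta . w)||^2 on a grid of angles."""
    thetas = [2.0 * math.pi * k / n_grid for k in range(n_grid)]
    best = max(thetas, key=lambda th: sum(v * v for v in grad(act(th, w))))
    return act(best, w)


def run(w, lr, steps, use_teleport):
    """w_{t+1} = g_t . w_t - lr * grad L(g_t . w_t); g_t is the identity without teleport."""
    for _ in range(steps):
        if use_teleport:
            w = teleport(w)
        gw = grad(w)
        w = (w[0] - lr * gw[0], w[1] - lr * gw[1])
    return w


w0 = (1.0, 1.0)
w_plain = run(w0, lr=0.05, steps=10, use_teleport=False)
w_tele = run(w0, lr=0.05, steps=10, use_teleport=True)
print(loss(w_plain), loss(w_tele))
```

With these settings, after 10 steps the teleporting run ends several orders of magnitude below the vanilla run: roughly 5e-6 versus 0.18. Each teleport places the iterate, up to grid resolution, on the top eigenvector of the Hessian. The following gradient step is therefore a scaled Newton step, which is the mechanism analyzed in Section 3.2.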
3.2 TELEPORTATION AND NEWTON'S METHOD

Intuitively, teleportation can speed up optimization because it behaves similarly to Newton's method. Suppose a teleportation takes the parameters to a critical point of the gradient norm on a level set. At that point, the gradient descent direction is the same as the Newton direction (Zhao et al., 2022). As a result, we can leverage the convergence of Newton's method to derive the convergence rate of teleportation in the deterministic setting.

Proposition 3.2 (Quadratic term in convergence rate). Let $L$ be strictly convex and let $w_0 \in \mathbb{R}^d$. Let
$$\hat{w} \in \arg\max_{w \in \mathbb{R}^d} \frac{1}{2}\|\nabla L(w)\|^2 \quad \text{s.t.} \quad L(w) = L(w_0).$$
Let $\nabla^2 L$ be the Hessian of $L$, and $\lambda_{\max}(\nabla^2 L(w))$ be the largest eigenvalue of $\nabla^2 L(w)$. If $\nabla L(\hat{w}) \neq 0$, then there exists $\lambda_0$ such that $0 < \lambda_0 \le \lambda_{\max}(\nabla^2 L(w_0))$, and one step of gradient descent after teleportation with learning rate $\gamma > 0$ gives
$$w_1 = \hat{w} - \gamma \nabla L(\hat{w}) = \hat{w} - \gamma\lambda_0 \nabla^2 L(\hat{w})^{-1} \nabla L(\hat{w}). \tag{3}$$
Let $\hat{w} = g_0 \cdot w_0$. Assume that $\gamma \le \frac{1}{\lambda_0}$, that $L$ is a $\mu$-strongly convex and $L$-smooth function, and that the Hessian is $G$-Lipschitz. Then
$$\|w_1 - w^*\| \le \frac{G}{2\mu}\|g_0 \cdot w_0 - w^*\|^2 + |1 - \gamma\lambda_0|\,\frac{L}{2\mu}\,\|g_0 \cdot w_0 - w^*\|.$$

More details about the assumptions and the proof are in Appendix B. Because the step size $\lambda_0$ is unknown, extra care is needed in establishing this convergence rate. The proposition shows that one teleportation step followed by one gradient step is equivalent to a dampened Newton step (equation 3). Hence, the convergence rate has a quadratically contracting term $\|g_0 \cdot w_0 - w^*\|^2$, which is typical of second-order methods. In particular, setting $\gamma = 1/\lambda_0$ gives local quadratic convergence. In contrast, without the teleportation step and under the same assumptions, gradient descent only gives the linear convergence $\|w_1 - w^*\| \le (1 - \mu\gamma)\|w_0 - w^*\|$ (for $\gamma \le 1/L$), with no quadratically contracting term.

3.3 WHEN IS ONE TELEPORTATION ENOUGH

Despite the guaranteed improvement in convergence, teleporting before every gradient descent step is computationally expensive. Hence we teleport only occasionally.
In fact, for certain optimization objectives, a single teleportation suffices: afterwards, every point on the gradient flow has the largest gradient norm in its loss level set (Zhao et al., 2022). In past work, this result is limited to convex quadratic functions. In this section, we give a sufficient condition under which one teleportation results in an optimal trajectory for general loss functions. Full proofs can be found in Appendix C.

Let $V: M \to TM$ be a vector field on the manifold $M$, where $TM$ denotes the associated tangent bundle. Here we consider the parameter space $M = \mathbb{R}^n$, although the results in this section can be extended to optimization on other manifolds. In this case, we may write $V = \sum_i v_i \frac{\partial}{\partial w_i}$ using the component functions $v_i: \mathbb{R}^n \to \mathbb{R}$ and coordinates $w_i$. Consider a smooth loss function $L: M \to \mathbb{R}$. Let $G$ be a symmetry group of $L$, i.e. $L(g \cdot w) = L(w)$ for all $w \in M$ and $g \in G$. Let $\mathfrak{X}$ be the set of all vector fields on $M$. Let $R = \sum_i r_i \frac{\partial}{\partial w_i}$, where $r_i = -\frac{\partial L}{\partial w_i}$, be the reverse gradient vector field. Let
$$\mathfrak{X}_\perp = \Big\{A = \sum_i a_i \frac{\partial}{\partial w_i} \in \mathfrak{X} \;\Big|\; a_i \in C^\infty(M) \text{ and } \sum_i a_i(w) r_i(w) = 0,\ \forall w \in M\Big\}$$
be the set of vector fields orthogonal to $R$. If $G$ is a Lie group, the infinitesimal action of its Lie algebra $\mathfrak{g}$ defines a set of vector fields $\mathfrak{X}_{\mathfrak{g}} \subseteq \mathfrak{X}_\perp$.

A gradient flow is a curve $\gamma: \mathbb{R} \to M$ whose velocity is given by the value of $R$, i.e. $\gamma'(t) = R_{\gamma(t)}$ for all $t \in \mathbb{R}$. The Lie bracket $[A, R]$ defines the derivative of $R$ with respect to $A$. The flows of $A$ and $R$ commute if and only if $[A, R] = 0$ (Theorem 9.44, Lee (2013)). That is, teleportation can affect the convergence rate only if $[A, R]L \neq 0$ for some $A \in \mathfrak{X}_{\mathfrak{g}}$. To simplify notation, for a set of vector fields $W \subseteq \mathfrak{X}$, we write $([W, R]L)(w) = 0$ when $([A, R]L)(w) = 0$ for all $A \in W$. We call a gradient flow optimal if every point on the flow is a critical point of the magnitude of the gradient in its loss level set. Note that this definition does not exclude the case where points on the flow are minimizers of the magnitude of the gradient.

Definition 3.3. Let $f: M \to \mathbb{R}$, $w \mapsto \left\|\frac{\partial L}{\partial w}\right\|_2^2$.
A point $w \in M$ is optimal with respect to a set of vector fields $W \subseteq \mathfrak{X}$ if $Af(w) = 0$ for all $A \in W$. A gradient flow $\gamma: \mathbb{R} \to M$ is optimal with respect to $W$ if $\gamma(t)$ is optimal with respect to $W$ for all $t \in \mathbb{R}$.

Proposition 3.4. A point $w \in M$ is optimal with respect to a set of vector fields $W$ if and only if $([W, R]L)(w) = 0$.

A sufficient condition for one teleportation to result in an optimal trajectory is the following: whenever the function $[A, R]L$ vanishes at $w \in M$, it vanishes along the entire gradient flow starting at $w$.

Proposition 3.5. Let $W \subseteq \mathfrak{X}$ be a set of vector fields that are orthogonal to $\frac{\partial L}{\partial w}$. Assume that for all $w \in M$ such that $([W, R]L)(w) = 0$, we have $(R[W, R]L)(w) = 0$. Then the gradient flow starting at any optimal point with respect to $W$ is optimal with respect to $W$.

To help check when the assumption in Proposition 3.5 is satisfied, we provide an alternative form of $R[W, R]L(w)$ when $[W, R]L(w) = 0$.

Proposition 3.6. Let $S = \left\{\sum_i \left(M \frac{\partial L}{\partial w}\right)_i \frac{\partial}{\partial w_i} \in \mathfrak{X} \;\middle|\; M \in \mathbb{R}^{n \times n},\ M^T = -M\right\}$. Suppose that, at all points that are optimal with respect to $S$,
$$\sum_{i,j,k,l} M_{jl}\, \frac{\partial^3 L}{\partial w_k \partial w_i \partial w_j}\, \frac{\partial L}{\partial w_i}\, \frac{\partial L}{\partial w_k}\, \frac{\partial L}{\partial w_l} = 0$$
for all anti-symmetric matrices $M \in \mathbb{R}^{n \times n}$. Then the gradient flow starting at a point that is optimal with respect to $S$ is optimal with respect to $S$.

From Proposition 3.6, we see that $R[W, R]L(w)$ is not automatically 0 when $[W, R]L(w) = 0$. Consider even a group big enough for its infinitesimal actions to cover the tangent space of the level set ($\mathfrak{X}_{\mathfrak{g}} = \mathfrak{X}_\perp$). One teleportation still does not guarantee that the gradient flow intersects all future level sets at optimal points. However, for loss functions that satisfy the condition in Proposition 3.5, teleporting once optimizes the entire trajectory. This is the case, for example, when
$$\frac{\partial^3 L}{\partial w_k \partial w_i \partial w_j}\frac{\partial L}{\partial w_\alpha} = \frac{\partial^3 L}{\partial w_k \partial w_i \partial w_\alpha}\frac{\partial L}{\partial w_j}$$
for all $i, k, j, \alpha$ (Proposition C.3). In particular, all quadratic functions meet this condition.

4 TELEPORTATION FOR IMPROVING GENERALIZATION

Teleportation was originally proposed to speed up optimization. In this section, we explore whether teleportation is also suitable for improving generalization, another important aspect of deep learning.
We first review definitions of the sharpness of minima. Then, we introduce a novel notion of the curvature of minima and discuss its implications for generalization. We observe how the sharpness and curvature of minima are correlated with generalization. We then improve generalization by incorporating sharpness and curvature into the teleportation objective.

4.1 SHARPNESS OF MINIMA

Flat minima, typically characterized by numerous small Hessian eigenvalues, tend to generalize well (Hochreiter & Schmidhuber, 1997). Hessian-based sharpness metrics are known to correlate well with generalization, but they are expensive to compute and to differentiate through. To use sharpness as an objective in teleportation, we instead consider changes in the loss averaged over random directions. Let $D$ be a set of vectors $d_i$ drawn randomly from the unit sphere $\{d \in \mathbb{R}^n : \|d\| = 1\}$, and $T$ a list of displacements $t_j \in \mathbb{R}$. Then, we have the following metric (Izmailov et al., 2018):
$$\text{Sharpness:} \quad \phi(w, T, D) = \frac{1}{|T||D|} \sum_{t \in T} \sum_{d \in D} L(w + t d). \tag{4}$$

4.2 CURVATURE OF MINIMA

Figure 2: Gradient flow ($L(w)$) and a curve on the minimum ($\gamma$). The curvature of both curves may affect generalization.

At a minimum, the loss-invariant (flat) directions are eigenvectors of the Hessian with eigenvalue zero. The curvature along these directions does not directly affect Hessian-based sharpness metrics. However, these curvatures may affect generalization, either by themselves or by correlating with the curvature along non-flat directions. The curvature of the loss (curve $L(w)$ in Figure 2) has been widely studied, but the curvature of the minima (curve $\gamma$) has not. Below, we provide a novel method to quantify the curvature of the minima.

Assume that the loss function $L$ has a $G$ symmetry. Consider the curve $\gamma_M: \mathbb{R} \times \mathbb{R}^n \to \mathbb{R}^n$, where $M \in \mathrm{Lie}(G)$ and $\gamma_M(t, w) = \exp(tM) \cdot w$. Then $\gamma_M(0, w) = w$, and every point on $\gamma_M$ is in the minimum if $w$ is a minimum. Let $\gamma' = \frac{d\gamma}{dt}$ be the derivative of a curve $\gamma$.
The curvature of $\gamma$ is $\kappa(\gamma, t) = \frac{\|T'(t)\|}{\|\gamma'(t)\|}$, where $T(t) = \frac{\gamma'(t)}{\|\gamma'(t)\|}$ is the unit tangent vector. We assume that the action map is smooth, since calculating the curvature requires second derivatives and optimizing the curvature via gradient descent requires third derivatives. For multilayer networks with element-wise activations, we derive the group action, $\gamma$, and $\kappa$ in Appendix D. The minimum can have more than one dimension. We therefore measure the curvature at a point $w$ on the minimum by averaging the curvature of $k$ curves with randomly selected Lie algebra elements $M_i \in \mathrm{Lie}(G)$. The resulting new metric is
$$\text{Curvature:} \quad \psi(w, k) = \frac{1}{k}\sum_{i=1}^{k} \kappa\big(\gamma_{M_i}(0, w), 0\big). \tag{5}$$
There are different ways to measure the curvature of a higher-dimensional manifold, such as using the Gaussian curvature of 2D subspaces of the tangent space. However, our method of approximating the mean curvature is easier to compute and suitable as a differentiable objective.

Figure 3: Illustration of the effect of sharpness (a, b) and curvature (c, d) of minima on generalization. See Figure 2 for a 3D visualization of the curves $L(w)$ and $\gamma$. When the loss landscape shifts due to a change in the data distribution, sharper minima have a larger increase in loss. In the example shown, minima with larger curvature move further away from the shifted minima.

4.3 CORRELATION WITH GENERALIZATION

Generalization reflects how the loss changes under shifts in the data distribution. The sharpness of minima is well known to be correlated with generalization. Figure 3(a)(b) visualizes an example of the shift in the loss landscape ($L(w)$): the change in loss at a minimizer $w^*$ is large when the minimum is sharp. The relation between the curvature of a minimum and generalization is less well studied. Figure 3(c)(d) shows one possible shift of the minimum ($\gamma$). Under this shift, the minimizer with a larger curvature ends up farther away from the shifted minimum.
The curve on the minimum can also shift in other directions. Appendix E.2 provides analytical examples of the correlation between curvature and the expected distance between the old and shifted minimum.

Table 1: Correlation with validation loss.

| | MNIST | Fashion-MNIST | CIFAR-10 |
|---|---|---|---|
| Sharpness ($\phi$) | 0.704 | 0.790 | 0.899 |
| Curvature ($\psi$) | -0.050 | -0.232 | -0.167 |

We verify the correlation between sharpness, curvature, and validation loss on MNIST (Deng, 2012), Fashion-MNIST (Xiao et al., 2017), and CIFAR-10 (Krizhevsky et al., 2009). On each dataset, we train 100 three-layer neural networks with Leaky ReLU using different initializations. Details of the setup can be found in Appendix E.3. Table 1 shows the Pearson correlation between validation loss and sharpness or curvature (scatter plots in Figures 9 and 10 in the appendix). On all three datasets, sharpness has a strong positive correlation with validation loss, meaning that the average change in loss under perturbations is a good indicator of test performance. For the architecture we consider, the curvature of minima is negatively correlated with the validation loss. This correlation is weaker, with coefficients between -0.050 and -0.232. We observe that the magnitudes of the curvatures are small, which suggests that the minima are relatively flat.

4.4 TELEPORTATION FOR IMPROVING GENERALIZATION

To improve the generalization ability of the minimizer and to better understand the curvature of minima, we teleport parameters to regions with different sharpness and curvature. Multi-layer neural networks have $GL(\mathbb{R})$ symmetry between layers (Appendix D.1). We parametrize group elements by Lie algebra elements $T$. We then perform gradient ascent on $T$ to maximize the gradient norm at the transformed parameters, $\|\nabla L|_{\exp(T) \cdot w}\|$. Algorithm 2 in Appendix E.4 demonstrates how to increase the curvature $\psi$ by teleporting two layers, with hidden dimension $h$, in an MLP.
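The sketch below makes equations 4 and 5 and the Lie-algebra parametrization concrete on a toy model. It rests on several assumptions:

- **Model.** A scalar two-layer linear network $L(w_1, w_2) = \frac{1}{2}(w_2 w_1 - y)^2$, evaluated at a minimum $w_2 w_1 = y$.
- **Group.** The one-parameter group $\exp(\tau)$, acting as $(w_1, w_2) \mapsto (e^{\tau} w_1, e^{-\tau} w_2)$.
- **Curvature.** The orbit's curvature at $t = 0$ is computed from $\gamma'(0)$ and $\gamma''(0)$ using the standard identity $\kappa = \sqrt{\|\gamma'\|^2\|\gamma''\|^2 - (\gamma' \cdot \gamma'')^2}/\|\gamma'\|^3$, which equals $\|T'\|/\|\gamma'\|$. Because this group has a single generator, the average over $k$ Lie algebra elements in equation 5 reduces to one term.
- **Sharpness.** Evenly spaced unit directions replace random ones so that the numbers are reproducible.
- **Ascent.** A central finite difference stands in for automatic differentiation.

This is an illustration only, not the paper's Algorithm 2.

```python
import math

Y = 4.0  # target; any (w1, w2) with w1 * w2 = Y is a minimum (assumed toy model)


def loss(w1, w2):
    return 0.5 * (w2 * w1 - Y) ** 2


def act(tau, w1, w2):
    """exp(tau) acting as (w1, w2) -> (e^tau w1, e^-tau w2); keeps w2 * w1 fixed."""
    return math.exp(tau) * w1, math.exp(-tau) * w2


def sharpness(w1, w2, ts, dirs):
    """Equation 4: average of L(w + t d) over displacements t and unit directions d."""
    total = sum(loss(w1 + t * d1, w2 + t * d2) for t in ts for d1, d2 in dirs)
    return total / (len(ts) * len(dirs))


def curvature(w1, w2):
    """Curvature at t = 0 of gamma(t) = (e^t w1, e^-t w2), from gamma'(0) and gamma''(0)."""
    v = (w1, -w2)  # gamma'(0)  = M w with generator M = diag(1, -1)
    a = (w1, w2)   # gamma''(0) = M^2 w
    vv = v[0] ** 2 + v[1] ** 2
    aa = a[0] ** 2 + a[1] ** 2
    va = v[0] * a[0] + v[1] * a[1]
    return math.sqrt(max(vv * aa - va * va, 0.0)) / vv ** 1.5


def teleport_curvature(w1, w2, lr=0.4, steps=200, h=1e-5):
    """Gradient ascent on the Lie algebra coordinate tau to increase the curvature."""
    tau = 0.0
    for _ in range(steps):
        up = curvature(*act(tau + h, w1, w2))
        down = curvature(*act(tau - h, w1, w2))
        tau += lr * (up - down) / (2 * h)  # central finite-difference derivative
    return act(tau, w1, w2)


dirs = [(math.cos(2 * math.pi * k / 8), math.sin(2 * math.pi * k / 8)) for k in range(8)]
ts = [0.05, 0.1]
w = (4.0, 1.0)  # an unbalanced minimum
w_new = teleport_curvature(*w)
print(w_new, loss(*w_new))
print(sharpness(*w, ts, dirs), sharpness(*w_new, ts, dirs))
print(curvature(*w), curvature(*w_new))
```

The curvature ascent moves the weights from the unbalanced minimum (4, 1) to the balanced minimum (2, 2) without changing the loss. On this toy model, the same move also lowers the sharpness. That coincidence is specific to this example; the text does not imply that it holds in general.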
In experiments, we use an extended version of the algorithm, which teleports all layers by optimizing a list of $T$'s concurrently. During teleportation, we perform gradient descent on the group elements to change $\phi$ or $\psi$. Results are averaged over 5 runs.

[Figure 4 plots: loss vs. epoch (0 to 40) for SGD, teleport (decrease), and teleport (increase); left panel sharpness, right panel curvature.]

Figure 4: Changing sharpness (left) or curvature (right) using teleportation and its effect on generalization on CIFAR-10. Solid lines represent average test loss, and dashed lines represent average training loss. Teleporting to decrease sharpness improves validation loss slightly. Teleportation that changes curvature has a more significant impact on generalization ability.

Figure 4 shows the training curve of SGD on CIFAR-10, with one teleportation at epoch 20. Similar results for AdaGrad can be found in Appendix E.4. Teleporting to flatter points slightly improves the validation loss, while teleporting to sharper points has no effect. The group action keeps the loss invariant only on the batch of data used in teleportation. The errors incurred in teleportation therefore have a similar effect to a warm restart, which makes the effect of changing sharpness less clear. Interestingly, by changing the curvature, teleportation is able to affect generalization. Teleporting to points with larger curvatures helps find a minimum with lower validation loss, while teleporting to points with smaller curvatures has the opposite effect. This suggests that, at least locally, curvature is correlated with generalization. Details of the experiment setup can be found in Appendix E.4.

5 APPLICATIONS TO OTHER OPTIMIZATION ALGORITHMS

Having shown teleportation's potential to improve optimization and generalization, we demonstrate its wide applicability by integrating teleportation into different optimizers and meta-learning.
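Some optimizers keep internal state. One simple way to combine them with teleportation is to teleport once before training, while that state is still empty, and then run the optimizer unchanged. The sketch below does this for heavy-ball momentum on an assumed toy quadratic $L(w) = \frac{1}{2}(w_1^2 + 10 w_2^2)$. Its level sets are preserved by the compact group $g_\theta = A^{-1/2} R_\theta A^{1/2}$ with $A = \mathrm{diag}(1, 10)$. The objective, group, grid search, and hyperparameters are illustrative choices, not the paper's MNIST setup (see Appendix F.2 for that). Teleporting mid-training would additionally require deciding what to do with the momentum buffer, which this sketch sidesteps.

```python
import math

A = (1.0, 10.0)  # assumed diagonal Hessian of the toy quadratic


def loss(w):
    return 0.5 * (A[0] * w[0] ** 2 + A[1] * w[1] ** 2)


def grad(w):
    return [A[0] * w[0], A[1] * w[1]]


def act(theta, w):
    """Loss-preserving action g_theta . w = A^{-1/2} R_theta A^{1/2} w."""
    u0, u1 = math.sqrt(A[0]) * w[0], math.sqrt(A[1]) * w[1]
    c, s = math.cos(theta), math.sin(theta)
    return [(c * u0 - s * u1) / math.sqrt(A[0]), (s * u0 + c * u1) / math.sqrt(A[1])]


def teleport(w, n_grid=360):
    """Grid-search stand-in for maximizing the gradient norm over the group."""
    thetas = [2.0 * math.pi * k / n_grid for k in range(n_grid)]
    best = max(thetas, key=lambda th: sum(v * v for v in grad(act(th, w))))
    return act(best, w)


def momentum_sgd(w, lr, beta, steps, teleport_first):
    """Heavy-ball momentum: v <- beta * v + grad, w <- w - lr * v."""
    w = list(w)
    if teleport_first:
        w = teleport(w)  # the momentum buffer is still zero, so nothing else changes
    v = [0.0, 0.0]
    for _ in range(steps):
        g = grad(w)
        v = [beta * vi + gi for vi, gi in zip(v, g)]
        w = [wi - lr * vi for wi, vi in zip(w, v)]
    return w


w0 = [1.0, 1.0]
plain = loss(momentum_sgd(w0, lr=0.05, beta=0.5, steps=30, teleport_first=False))
tele = loss(momentum_sgd(w0, lr=0.05, beta=0.5, steps=30, teleport_first=True))
print(plain, tele)
```

With these hyperparameters, the single teleport places the iterate close to the stiff $w_2$ direction, where momentum contracts quickly. The final loss ends up more than two orders of magnitude below the run without teleportation.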
5.1 STANDARD OPTIMIZERS

Teleportation is not limited to improving SGD. To show that teleportation works well with other standard optimizers, we train a 3-layer neural network on MNIST using different optimizers with and without teleportation. During training, we teleport once at the first epoch, using 8 minibatches of size 200. Details can be found in Appendix F.2. Figure 5 shows that teleportation improves the convergence rate when using AdaGrad, SGD with momentum, RMSProp, and Adam. A teleportation takes less time than one epoch of training. Teleportation therefore improves the convergence rate per epoch at almost no additional wall-clock cost (Figure 13 in the appendix).

[Figure 5 plots: training and test loss vs. epoch (0 to 40) for AdaGrad, momentum, RMSProp, and Adam, each with and without teleportation.]

Figure 5: Integrating teleportation with AdaGrad, momentum, RMSProp, and Adam improves the convergence rate on MNIST. Solid lines represent the average test loss, and dashed lines represent the average training loss. Shaded areas are 1 standard deviation of the test loss across 5 runs.

5.2 LEARNING TO TELEPORT

In optimization-based meta-learning, the parameter update rule or the hyperparameters are learned using a meta-optimizer (Andrychowicz et al., 2016; Finn et al., 2017). Teleportation introduces an additional degree of freedom in parameter updates. We augment existing meta-learning algorithms by learning both the local update and teleportation. This allows us to teleport without the additional optimization step on groups, which reduces computation time. Let $w_t \in \mathbb{R}^d$ be the parameters at time $t$, and $\nabla_t = \frac{\partial L}{\partial w}\big|_{w_t}$ be the gradient of the loss $L$.
In gradient descent, the update rule with learning rate $\eta$ is $w_{t+1} = w_t - \eta \nabla_t$. In meta-learning (Andrychowicz et al., 2016), the update on $w_t$ is learned by a meta-learning optimizer $m$, which takes $\nabla_t$ as input. Here $m$ is an LSTM model. Denote by $h_t$ the hidden state of the LSTM and by $\phi$ the parameters of $m$. The update rule is
$$w_{t+1} = w_t + f_t, \qquad f_t, h_{t+1} = m(\nabla_t, h_t, \phi).$$
Extending this approach beyond an additive update rule, we learn to teleport. Let $G$ be a group whose action on the parameter space leaves $L$ invariant. We use two meta-learning optimizers $m_1, m_2$ to learn the update direction $f_t \in \mathbb{R}^d$ and the group element $g_t \in G$:
$$w_{t+1} = g_t \cdot (w_t + f_t), \qquad f_t, h^1_{t+1} = m_1(\nabla_t, h^1_t, \phi_1), \qquad g_t, h^2_{t+1} = m_2(\nabla_t, h^2_t, \phi_2).$$

Experiment setup. We train and test on two-layer neural networks $L(W_1, W_2) = \|Y - W_2 \sigma(W_1 X)\|^2$, where $W_2, W_1, X, Y \in \mathbb{R}^{20 \times 20}$, and $\sigma$ is the Leaky ReLU function with slope coefficient 0.1. Both meta-optimizers are two-layer LSTMs with hidden dimension 300. We train the meta-optimizers on multiple trajectories created with different initializations. Each trajectory consists of 100 steps of gradient descent on $L$ with random $X, Y$ and randomly initialized $W$'s. We update the parameters of $m_1$ and $m_2$ by unrolling every 10 steps. The learning rates for the meta-optimizers are $10^{-4}$ for $m_1$ and $10^{-3}$ for $m_2$. We test the meta-optimizers on 5 trajectories not seen in training. Algorithm 1 summarizes the training procedure. We compare against three baselines:

- "GD": vanilla gradient descent with the largest learning rate that does not lead to divergence ($3 \times 10^{-4}$).
- "LSTM(update)": learns only the update $f_t$ and does not perform teleportation ($g_t = I$ for all $t$).
- "LSTM(lr,tele)": learns the group element $g_t$ and the learning rate used to perform gradient descent, instead of the update $f_t$.

We keep training until adding more training trajectories does not improve the convergence rate.
We use 700 training trajectories for our approach, 600 for the second baseline, and 30 for the third baseline.

Results. By learning both the local update $f_t$ and the non-local transformation $g_t$, our meta-optimizer successfully learns to learn faster. Figure 6 shows the improvement of our approach over the previous meta-learning method, which only learns $f_t$. Learning the two types of updates together ("LSTM(update,tele)") achieves a better convergence rate than the baselines, which learn only one of them. Additionally, learning the group element $g_t$ eliminates the need to perform gradient ascent on the group manifold and reduces hyperparameter tuning for teleportation. This toy experiment shows that teleportation can be integrated into existing optimization algorithms, and it demonstrates the flexibility and promising applications of teleportation.

Algorithm 1 Learning to teleport
Input: Loss function $L$, learning rate $\eta$, number of epochs $T$, LSTM models $m_1, m_2$ with initial parameters $\phi_1, \phi_2$, unroll step $t_{\text{unroll}}$.
Output: Trained parameters $\phi_1$ and $\phi_2$.
for each training initialization do
  for $t = 1$ to $T$ do
    $f_t, h^1_{t+1} = m_1(\nabla_t, h^1_t, \phi_1)$
    $g_t, h^2_{t+1} = m_2(\nabla_t, h^2_t, \phi_2)$
    $w \leftarrow g_t \cdot (w + f_t)$
    if $t \bmod t_{\text{unroll}} = 0$ then
      update $\phi_1, \phi_2$ by back-propagation from the accumulated loss $\sum_{i=t-t_{\text{unroll}}}^{t} L(w_i)$
    end if
  end for
end for

[Figure 6 plot: loss vs. epoch (0 to 30) for GD, LSTM(lr,tele), LSTM(update), and LSTM(update,tele).]

Figure 6: Performance of the trained meta-optimizer on the test set. Learning both the local update $f_t$ and the non-local transformation $g_t$ results in a better convergence rate than learning only local updates or only teleportation.

6 DISCUSSION

Teleportation is a powerful tool to search in the loss level sets for parameters with desired properties. We provide theoretical guarantees that teleportation accelerates the convergence rate of SGD.
Using concepts from symmetry, we propose a novel notion of curvature. We show that additional teleportation objectives, such as changing the curvature, can benefit generalization.

The close relationship between symmetry and optimization opens up a number of exciting opportunities. Exploring other objectives in teleportation appears to be an interesting future direction. Other possible applications include extending teleportation to different architectures, such as convolutional or graph neural networks, and to different algorithms, such as sampling-based optimization.

The empirical results linking sharpness and curvature to generalization are intriguing. However, the theoretical origin of their relation remains unclear. In particular, a precise description of how the loss landscape changes under distribution shifts is not known. Further investigation of the correlation between curvature and generalization will help teleportation further improve generalization and take us a step closer to understanding the loss landscape.

ACKNOWLEDGEMENTS

This work was supported in part by Army-ECASE award W911NF-23-1-0231, the U.S. Department of Energy, Office of Science under #DE-SC0022255, IARPA HAYSTAC Program, CDC-RFA-FT23-0069, and NSF Grants #2205093, #2146343, #2134274, #2107256, and #2134178.

REFERENCES

Osmar Alessio. Formulas for second curvature, third curvature, normal curvature, first geodesic curvature and first geodesic torsion of implicit curve in n-dimensions. Computer Aided Geometric Design, 29(3-4):189-201, 2012.

Maksym Andriushchenko, Francesco Croce, Maximilian Müller, Matthias Hein, and Nicolas Flammarion. A modern look at the relationship between sharpness and generalization. International Conference on Machine Learning, 2023.

Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul, Brendan Shillingford, and Nando De Freitas.
Learning to learn by gradient descent by gradient descent. Advances in Neural Information Processing Systems, 29, 2016. Marco Armenta and Pierre-Marc Jodoin. The representation theory of neural networks. Mathematics, 9(24), 2021. ISSN 2227-7390. Marco Armenta, Thierry Judge, Nathan Painchaud, Youssef Skandarani, Carl Lemaire, Gabriel Gibeau Sanchez, Philippe Spino, and Pierre-Marc Jodoin. Neural teleportation. Mathematics, 11(2):480, 2023. Vijay Badrinarayanan, Bamdev Mishra, and Roberto Cipolla. Symmetry-invariant optimization in deep networks. ar Xiv preprint ar Xiv:1511.01754, 2015. Xavier Bouthillier, Pierre Delaunay, Mirko Bronzi, Assya Trofimov, Brennan Nichyporuk, Justin Szeto, Nazanin Mohammadi Sepahvand, Edward Raff, Kanika Madan, Vikram Voleti, et al. Accounting for variance in machine learning benchmarks. Proceedings of Machine Learning and Systems, 3:747 769, 2021. Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann Le Cun, Carlo Baldassi, Christian Borgs, Jennifer Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-sgd: Biasing gradient descent into wide valleys. International Conference on Learning Representations, 2017. Li Deng. The mnist database of handwritten digit images for machine learning research [best of the web]. IEEE signal processing magazine, 29(6):141 142, 2012. Lijun Ding, Dmitriy Drusvyatskiy, Maryam Fazel, and Zaid Harchaoui. Flat minima generalize for low-rank matrix recovery. ar Xiv preprint ar Xiv:2203.03756, 2022. Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In International Conference on Machine Learning, pp. 1019 1028. PMLR, 2017. Jesse Dodge, Gabriel Ilharco, Roy Schwartz, Ali Farhadi, Hannaneh Hajishirzi, and Noah Smith. Fine-tuning pretrained language models: Weight initializations, data orders, and early stopping. ar Xiv preprint ar Xiv:2002.06305, 2020. Simon S Du, Wei Hu, and Jason D Lee. 
Algorithmic regularization in learning deep homogeneous models: Layers are automatically balanced. Neural Information Processing Systems, 2018. Rahim Entezari, Hanie Sedghi, Olga Saukh, and Behnam Neyshabur. The role of permutation invariance in linear mode connectivity of neural networks. International Conference on Learning Representations, 2022. Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International conference on machine learning, pp. 1126 1135. PMLR, 2017. Published as a conference paper at ICLR 2024 Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021. Iordan Ganev and Robin Walters. Quiver neural networks. ar Xiv preprint ar Xiv:2207.12773, 2022. Iordan Ganev, Twan van Laarhoven, and Robin Walters. Universal approximation and model compression for radial neural networks. ar Xiv preprint ar Xiv:2107.02550v2, 2022. J Elisenda Grigsby, Kathryn Lindsey, Robert Meyerhoff, and Chenxi Wu. Functional dimension of feedforward relu neural networks. ar Xiv preprint ar Xiv:2209.04036, 2022. Sepp Hochreiter and J urgen Schmidhuber. Flat minima. Neural computation, 9(1):1 42, 1997. Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. Conference on Uncertainty in Artificial Intelligence, 2018. Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. International Conference on Learning Representations, 2017. Minyoung Kim, Da Li, Shell X Hu, and Timothy Hospedales. Fisher sam: Information geometry and sharpness aware minimisation. In International Conference on Machine Learning, pp. 11148 11161. PMLR, 2022. 
Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
Daniel Kunin, Javier Sagastuy-Brena, Surya Ganguli, Daniel LK Yamins, and Hidenori Tanaka. Neural mechanics: Symmetry and broken conservation laws in deep learning dynamics. In International Conference on Learning Representations, 2021.
Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning, pp. 5905-5914. PMLR, 2021.
John M Lee. Introduction to Smooth Manifolds. Graduate Texts in Mathematics, vol 218. Springer, New York, NY, 2013.
Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with warm restarts. International Conference on Learning Representations, 2017.
Qi Meng, Shuxin Zheng, Huishuai Zhang, Wei Chen, Zhi-Ming Ma, and Tie-Yan Liu. G-SGD: Optimizing ReLU neural networks in its positively scale-invariant space. International Conference on Learning Representations, 2019.
Behnam Neyshabur, Russ R Salakhutdinov, and Nati Srebro. Path-SGD: Path-normalized optimization in deep neural networks. In Advances in Neural Information Processing Systems, 2015.
Henning Petzka, Michael Kamp, Linara Adilova, Cristian Sminchisescu, and Mario Boley. Relative flatness and generalization. 35th Conference on Neural Information Processing Systems, 2021.
Sameera Ramasinghe, Lachlan MacDonald, Moshiur Farazi, Hemanth Sartachandran, and Simon Lucey. How you start matters for generalization. arXiv preprint arXiv:2206.08558, 2022.
Levent Sagun, Utku Evci, V Ugur Guney, Yann Dauphin, and Leon Bottou. Empirical analysis of the Hessian of over-parametrized neural networks. arXiv preprint arXiv:1706.04454, 2017.
Aleksandr Mikhailovich Shelekhov. On the curvatures of a curve in n-dimensional Euclidean space. Russian Mathematics, 65(11):46-58, 2021.
Berfin Şimşek, François Ged, Arthur Jacot, Francesco Spadaro, Clément Hongler, Wulfram Gerstner, and Johanni Brea. Geometry of the loss landscape in overparameterized neural networks: Symmetries and invariances. In International Conference on Machine Learning, pp. 9722-9732. PMLR, 2021.
Sebastian U. Stich. Unified optimal analysis of the (stochastic) gradient method. CoRR, 2019.
Salma Tarmoun, Guilherme Franca, Benjamin D Haeffele, and Rene Vidal. Understanding the dynamics of gradient flow in overparameterized linear models. In International Conference on Machine Learning, pp. 10153-10161. PMLR, 2021.
Twan Van Laarhoven. L2 regularization versus batch and weight normalization. Advances in Neural Information Processing Systems, 2017.
Lei Wu, Zhanxing Zhu, et al. Towards understanding generalization of deep learning: Perspective of loss landscapes. arXiv preprint arXiv:1706.10239, 2017.
Han Xiao, Kashif Rasul, and Roland Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
Bo Zhao, Nima Dehmamy, Robin Walters, and Rose Yu. Symmetry teleportation for accelerated optimization. Advances in Neural Information Processing Systems, 2022.
Bo Zhao, Iordan Ganev, Robin Walters, Rose Yu, and Nima Dehmamy. Symmetries, flat minima, and the conserved quantities of gradient flow. International Conference on Learning Representations, 2023.
Pan Zhou, Jiashi Feng, Chao Ma, Caiming Xiong, Steven Chu Hong Hoi, et al. Towards theoretically understanding why SGD generalizes better than Adam in deep learning. Advances in Neural Information Processing Systems, 33:21285-21296, 2020.

This appendix contains proofs, experiment setups, as well as additional results and discussions. Appendices A through C contain proofs of the theoretical results in Section 3.
Appendix D provides details about curves induced by symmetry and the curvature of the minimum. Appendix E discusses possible theoretical approaches to relate curvatures and generalization; it also contains experiment details on computing correlations and the algorithm that uses teleportation to change curvature. Appendix F describes experiment setups and different strategies for integrating teleportation into various optimization algorithms. The code used for our experiments is available at: https://github.com/Rose-STL-Lab/Teleportation-Optimization.

A TELEPORTATION AND SGD
This section includes a proof of Theorem 3.1. Additionally, we discuss the theorem's implication when the loss function is strictly convex.

Lemma A.1 (Descent Lemma). Let $L(w,\xi)$ be a $\beta$-smooth function. It follows that
$$\mathbb{E}\|\nabla L(w,\xi)\|^2 \le 2\beta\left(L(w) - L(w^*)\right) + 2\beta\left(L(w^*) - \mathbb{E}\left[\inf_w L(w,\xi)\right]\right). \quad (6)$$

Proof. Since $L(w,\xi)$ is smooth we have that
$$L(z,\xi) \le L(w,\xi) + \langle\nabla L(w,\xi), z - w\rangle + \frac{\beta}{2}\|z - w\|^2, \quad \forall z, w\in\mathbb{R}^d. \quad (7)$$
By inserting $z = w - \frac{1}{\beta}\nabla L(w,\xi)$ into equation 7 we have that
$$L\left(w - \tfrac{1}{\beta}\nabla L(w,\xi), \xi\right) \le L(w,\xi) - \frac{1}{2\beta}\|\nabla L(w,\xi)\|^2. \quad (8)$$
Re-arranging we have that
$$\begin{aligned} L(w^*,\xi) - L(w,\xi) &= L(w^*,\xi) - \inf_w L(w,\xi) + \inf_w L(w,\xi) - L(w,\xi)\\ &\le L(w^*,\xi) - \inf_w L(w,\xi) + L\left(w - \tfrac{1}{\beta}\nabla L(w,\xi), \xi\right) - L(w,\xi)\\ &\overset{\text{equation 8}}{\le} L(w^*,\xi) - \inf_w L(w,\xi) - \frac{1}{2\beta}\|\nabla L(w,\xi)\|^2, \end{aligned}$$
where the first inequality follows because $\inf_w L(w,\xi) \le L(w,\xi)$ for all $w$. Re-arranging the above and taking expectation gives
$$\begin{aligned} \mathbb{E}\|\nabla L(w,\xi)\|^2 &\le 2\mathbb{E}\left[\beta\left(L(w^*,\xi) - \inf_w L(w,\xi) + L(w,\xi) - L(w^*,\xi)\right)\right]\\ &= 2\beta\,\mathbb{E}\left[L(w^*,\xi) - \inf_w L(w,\xi) + L(w,\xi) - L(w^*,\xi)\right]\\ &= 2\beta\left(L(w) - L(w^*)\right) + 2\beta\left(L(w^*) - \mathbb{E}\left[\inf_w L(w,\xi)\right]\right). \end{aligned}$$

At each iteration $t\in\mathbb{N}^+$ in SGD, we choose a group element $g_t\in G$ and use teleportation before each gradient step as follows
$$w_{t+1} = g_t\cdot w_t - \eta\nabla L(g_t\cdot w_t, \xi_t). \quad (9)$$
Here $\eta$ is a learning rate, $\nabla L(w_t,\xi_t)$ is a gradient of $L(w_t,\xi_t)$ with respect to the parameters $w$, and $\xi_t\sim\mathcal{D}$ is a mini-batch of data sampled i.i.d. at each iteration.

Theorem 3.1. Let $L(w,\xi)$ be $\beta$-smooth and let $\sigma^2 \overset{\text{def}}{=} L(w^*) - \mathbb{E}\left[\inf_w L(w,\xi)\right]$.
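Consider the iterates $w_t$ given by equation 1 where
$$g_t \in \arg\max_{g\in G}\|\nabla L(g\cdot w_t)\|^2. \quad (10)$$
If $\eta = \frac{1}{\beta\sqrt{T-1}}$ then
$$\min_{t=0,\dots,T-1}\mathbb{E}\left[\max_{g\in G}\|\nabla L(g\cdot w_t)\|^2\right] \le \frac{2\beta}{\sqrt{T-1}}\mathbb{E}\left[L(w_0) - L(w^*)\right] + \frac{\beta\sigma^2}{\sqrt{T-1}}. \quad (11)$$

Proof. First note that if $L(w,\xi)$ is $\beta$-smooth, then $L(w)$ is also a $\beta$-smooth function, that is
$$L(z) - L(w) - \langle\nabla L(w), z - w\rangle \le \frac{\beta}{2}\|z - w\|^2. \quad (12)$$
Using equation 1 with $z = w_{t+1}$ and $w = g_t\cdot w_t$, together with equation 12 and the fact that the group action preserves loss, we have that
$$\begin{aligned} L(w_{t+1}) &\le L(g_t\cdot w_t) + \langle\nabla L(g_t\cdot w_t), w_{t+1} - g_t\cdot w_t\rangle + \frac{\beta}{2}\|w_{t+1} - g_t\cdot w_t\|^2 \quad (13)\\ &= L(w_t) - \eta_t\langle\nabla L(g_t\cdot w_t), \nabla L(g_t\cdot w_t,\xi_t)\rangle + \frac{\beta\eta_t^2}{2}\|\nabla L(g_t\cdot w_t,\xi_t)\|^2. \quad (14) \end{aligned}$$
Taking expectation conditioned on $w_t$, we have that
$$\mathbb{E}_t\left[L(w_{t+1})\right] \le L(w_t) - \eta_t\|\nabla L(g_t\cdot w_t)\|^2 + \frac{\beta\eta_t^2}{2}\mathbb{E}_t\left[\|\nabla L(g_t\cdot w_t,\xi_t)\|^2\right]. \quad (15)$$
Now since $L(w,\xi)$ is $\beta$-smooth, from Lemma A.1 above we have that
$$\mathbb{E}\|\nabla L(w,\xi)\|^2 \le 2\beta\left(L(w) - L(w^*)\right) + 2\beta\left(L(w^*) - \mathbb{E}\left[\inf_w L(w,\xi)\right]\right). \quad (16)$$
Using equation 16 with $w = g_t\cdot w_t$ we have that
$$\mathbb{E}_t\left[L(w_{t+1})\right] \le L(w_t) - \eta_t\|\nabla L(g_t\cdot w_t)\|^2 + \beta^2\eta_t^2\left(L(g_t\cdot w_t) - L(w^*) + L(w^*) - \mathbb{E}\left[\inf_w L(w,\xi)\right]\right). \quad (17)$$
Using that $L(g_t\cdot w_t) = L(w_t)$, taking full expectation and re-arranging terms gives
$$\eta_t\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le (1 + \beta^2\eta_t^2)\,\mathbb{E}\left[L(w_t) - L^*\right] - \mathbb{E}\left[L(w_{t+1}) - L^*\right] + \beta^2\eta_t^2\sigma^2. \quad (18)$$
Now we use a re-weighting trick introduced in Stich (2019). Let $\alpha_t > 0$ be a sequence such that $\alpha_t(1 + \beta^2\eta_t^2) = \alpha_{t-1}$. Consequently, for a constant step size $\eta_t \equiv \eta$, if $\alpha_{-1} = 1$ then $\alpha_t = (1 + \beta^2\eta^2)^{-(t+1)}$. Multiplying both sides of equation 18 by $\alpha_t$ thus gives
$$\alpha_t\eta_t\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le \alpha_{t-1}\mathbb{E}\left[L(w_t) - L^*\right] - \alpha_t\mathbb{E}\left[L(w_{t+1}) - L^*\right] + \alpha_t\beta^2\eta_t^2\sigma^2. \quad (19)$$
Summing up from $t = 0,\dots,T-1$, and using telescopic cancellation, gives
$$\sum_{t=0}^{T-1}\alpha_t\eta_t\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le \mathbb{E}\left[L(w_0) - L^*\right] + \beta^2\sigma^2\sum_{t=0}^{T-1}\alpha_t\eta_t^2. \quad (20)$$
Let $A = \sum_{t=0}^{T-1}\alpha_t\eta_t$. Dividing both sides by $A$ gives
$$\min_{t=0,\dots,T-1}\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le \frac{1}{\sum_{t=0}^{T-1}\alpha_t\eta_t}\sum_{t=0}^{T-1}\alpha_t\eta_t\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le \frac{\mathbb{E}\left[L(w_0) - L^*\right] + \beta^2\sigma^2\sum_{t=0}^{T-1}\alpha_t\eta_t^2}{\sum_{t=0}^{T-1}\alpha_t\eta_t}. \quad (21)$$
Finally, if $\eta_t \equiv \eta$ then
$$\sum_{t=0}^{T-1}\alpha_t\eta_t = \eta\sum_{t=0}^{T-1}(1 + \beta^2\eta^2)^{-(t+1)} = \frac{\eta}{1 + \beta^2\eta^2}\cdot\frac{1 - (1 + \beta^2\eta^2)^{-T}}{1 - (1 + \beta^2\eta^2)^{-1}} \quad (22)$$
$$= \frac{1 - (1 + \beta^2\eta^2)^{-T}}{\beta^2\eta}.$$
To bound the term with the $T$ power, we use that
$$(1 + \beta^2\eta^2)^{-T} \le \frac{1}{2} \iff \frac{\log(2)}{\log(1 + \beta^2\eta^2)} \le T.$$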
To simplify the above expression we can use $\frac{x}{1+x} \le \log(1+x) \le x$ for $x > -1$, thus
$$\frac{\log(2)}{\log(1 + \beta^2\eta^2)} \le \frac{1 + \beta^2\eta^2}{\beta^2\eta^2}\log(2).$$
Using the above (and $\log(2) < 1$) we have that
$$\sum_{t=0}^{T-1}\alpha_t\eta_t \ge \frac{1}{2\beta^2\eta}, \quad \text{for } T \ge \frac{1 + \beta^2\eta^2}{\beta^2\eta^2}.$$
Using this lower bound in equation 21 gives
$$\min_{t=0,\dots,T-1}\mathbb{E}\|\nabla L(g_t\cdot w_t)\|^2 \le 2\beta^2\eta\,\mathbb{E}\left[L(w_0) - L^*\right] + \eta\beta^2\sigma^2, \quad \text{for } T \ge \frac{1 + \beta^2\eta^2}{\beta^2\eta^2}.$$
Now note that
$$T \ge \frac{1 + \beta^2\eta^2}{\beta^2\eta^2} \iff \beta^2\eta^2(T-1) \ge 1 \iff \eta \ge \frac{1}{\beta\sqrt{T-1}}.$$
Thus finally setting $\eta = \frac{1}{\beta\sqrt{T-1}}$ gives the result equation 2.

Proposition A.2. Assume that $L:\mathbb{R}^n\to\mathbb{R}$ is strictly convex and twice continuously differentiable. Assume also that for any two points $w_a, w_b\in\mathbb{R}^n$ such that $L(w_a) = L(w_b)$, there exists a $g\in G$ such that $w_a = g\cdot w_b$. At two points $w_1, w_2\in\mathbb{R}^n$, if $\max_{g\in G}\|\nabla L(g\cdot w_1)\|^2 = \|\nabla L(w_2)\|^2$, then $L(w_1) \le L(w_2)$.

Proof. Let $S(x) = \{w : L(w) = x\}$ be the level sets of $L$, and $X = \{L(w) : w\in\mathbb{R}^n\}$ be the image of $L$. Since $G$ acts transitively on the level sets of $L$, $\max_{g\in G}\|\nabla L(g\cdot w)\|^2 = \max_{w'\in S(L(w))}\|\nabla L(w')\|^2$. To simplify notation, we define a function $F: X\to\mathbb{R}$, $F(x) = \max_{w\in S(x)}\|\nabla L(w)\|^2$. Since $L(w)$ is continuously differentiable, the directional derivative of $F$ is defined. Additionally, since $L$ is continuous and its domain $\mathbb{R}^n$ is connected, its image $X$ is also connected. This means that for any $w_1, w_2\in\mathbb{R}^n$ and $\min(L(w_1), L(w_2)) \le y \le \max(L(w_1), L(w_2))$, there exists a $w_3\in\mathbb{R}^n$ such that $L(w_3) = y$.

Next, we show that $F(\cdot)$ is strictly increasing by contradiction. Suppose that $L(w_1) < L(w_2)$ and $F(L(w_1)) \ge F(L(w_2))$. By the mean value theorem, there exists a $w_3$ such that $L(w_1) < L(w_3) < L(w_2)$ and the directional derivative of $F$ in the direction towards $L(w_2)$ is non-positive: $\nabla_{L(w_2) - L(w_3)}F(L(w_3)) \le 0$. Let $w_3^*\in\arg\max_{w\in S(L(w_3))}\|\nabla L(w)\|^2$ be a point that has the largest gradient norm in $S(L(w_3))$. Then at $w_3^*$, $\|\nabla L\|^2$ cannot increase along the gradient direction. However, this means
$$\nabla_{\nabla L(w_3^*)}\|\nabla L(w_3^*)\|^2 = 2\nabla L(w_3^*)^T H\,\nabla L(w_3^*) \le 0. \quad (24)$$
Since we assumed that $L$ is convex and $w_3^*$ is not a minimum ($L(w_3^*) > L(w_1)$), we have that $\nabla L(w_3^*) \ne 0$. Therefore, equation 24 contradicts $L$ being strictly convex, and we have $F(L(w_1)) < F(L(w_2))$.
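We have shown that $L(w_1) < L(w_2)$ implies $F(L(w_1)) < F(L(w_2))$. Taking the contrapositive and switching $w_1$ and $w_2$, $F(L(w_1)) \le F(L(w_2))$ implies $L(w_1) \le L(w_2)$. Equivalently, $\max_{g\in G}\|\nabla L(g\cdot w_1)\|^2 \le \max_{g\in G}\|\nabla L(g\cdot w_2)\|^2$ implies that $L(w_1) \le L(w_2)$. Finally, since
$$\max_{g\in G}\|\nabla L(g\cdot w_1)\|^2 = \|\nabla L(w_2)\|^2 \le \max_{g\in G}\|\nabla L(g\cdot w_2)\|^2, \quad (25)$$
we have $L(w_1) \le L(w_2)$.

B TELEPORTATION AND NEWTON'S METHOD
Lemma B.1 (One step of Newton's method). Let $f(x)$ be a $\mu$-strongly convex and $L$-smooth function, that is, we have a global lower bound on the Hessian given by
$$LI \succeq \nabla^2 f(x) \succeq \mu I, \quad \forall x\in\mathbb{R}^n. \quad (26)$$
Furthermore, if the Hessian is also $G$-Lipschitz
$$\|\nabla^2 f(x) - \nabla^2 f(y)\| \le G\|x - y\| \quad (27)$$
then Newton's method $x_{k+1} = x_k - \lambda_k\nabla^2 f(x_k)^{-1}\nabla f(x_k)$ has a mixed linear and quadratic convergence according to
$$\|x_{k+1} - x^*\| \le \frac{G}{2\mu}\|x_k - x^*\|^2 + |1 - \lambda_k|\frac{L}{\mu}\|x_k - x^*\|. \quad (28)$$

Proof.
$$\begin{aligned} x_{k+1} - x^* &= x_k - x^* - \lambda_k\nabla^2 f(x_k)^{-1}\left(\nabla f(x_k) - \nabla f(x^*)\right)\\ &= x_k - x^* - \lambda_k\nabla^2 f(x_k)^{-1}\int_{s=0}^{1}\nabla^2 f(x_k + s(x^* - x_k))(x_k - x^*)\,ds \quad \text{(mean value theorem)}\\ &= \nabla^2 f(x_k)^{-1}\int_{s=0}^{1}\left(\nabla^2 f(x_k) - \lambda_k\nabla^2 f(x_k + s(x^* - x_k))\right)(x_k - x^*)\,ds\\ &= \nabla^2 f(x_k)^{-1}\int_{s=0}^{1}\left(\nabla^2 f(x_k) - \nabla^2 f(x_k + s(x^* - x_k)) + (1 - \lambda_k)\nabla^2 f(x_k + s(x^* - x_k))\right)(x_k - x^*)\,ds. \end{aligned}$$
Let $\delta_k := \|x_k - x^*\|$. Taking norms we have that
$$\begin{aligned} \delta_{k+1} &\le \|\nabla^2 f(x_k)^{-1}\|\int_{s=0}^{1}\left(\|\nabla^2 f(x_k) - \nabla^2 f(x_k + s(x^* - x_k))\| + |1 - \lambda_k|\,\|\nabla^2 f(x_k + s(x^* - x_k))\|\right)\delta_k\,ds\\ &\overset{\text{equation 27 + equation 26}}{\le} \frac{G}{\mu}\int_{s=0}^{1}s\|x_k - x^*\|^2\,ds + |1 - \lambda_k|\frac{L}{\mu}\int_{s=0}^{1}\|x_k - x^*\|\,ds\\ &= \frac{G}{2\mu}\|x_k - x^*\|^2 + |1 - \lambda_k|\frac{L}{\mu}\|x_k - x^*\|. \end{aligned}$$
The assumptions of this proof can be relaxed, since we only require that the Hessian is Lipschitz and lower bounded in a $\frac{\mu}{2L}$ ball around $x^*$.

Proposition 3.2 (Quadratic term in convergence rate). Let $L$ be strictly convex and let $w_0\in\mathbb{R}^d$. Let
$$w^* \in \arg\max_{w\in\mathbb{R}^d}\frac{1}{2}\|\nabla L(w)\|^2 \quad \text{subject to } L(w) = L(w_0). \quad (29)$$
If $\nabla L(w^*) \ne 0$ then there exists $\lambda_0$ such that $0 \le \lambda_0 \le \lambda_{\max}(\nabla^2 L(w_0))$ and one step of gradient descent with learning rate $\gamma > 0$ gives
$$w_1 = w^* - \gamma\nabla L(w^*) = w^* - \gamma\lambda_0\nabla^2 L(w^*)^{-1}\nabla L(w^*). \quad (30)$$
Consequently, letting $w^* = g_0\cdot w_0$, and if $\gamma \le \frac{1}{\lambda_0}$, then under the assumptions of Lemma B.1 we have that
$$\|w_1 - x^*\| \le \frac{G}{2\mu}\|g_0\cdot w_0 - x^*\|^2 + |1 - \gamma\lambda_0|\frac{L}{\mu}\|g_0\cdot w_0 - x^*\|.$$

Proof. The Lagrangian associated with equation 29 is given by
$$\mathcal{L}(w,\lambda) = \frac{1}{2}\|\nabla L(w)\|^2 + \lambda\left(L(w_0) - L(w)\right).$$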
Taking the derivative in $w$ and setting it to zero gives
$$\nabla_w\mathcal{L}(w,\lambda_0) = \nabla^2 L(w)\nabla L(w) - \lambda_0\nabla L(w) = 0. \quad (31)$$
Re-arranging, we have that $\nabla L(w) = \lambda_0\nabla^2 L(w)^{-1}\nabla L(w)$. If $\nabla L(w^*) \ne 0$, then from the above we have that $\|\nabla L(w)\|^2 = \lambda_0\nabla L(w)^T\nabla^2 L(w)^{-1}\nabla L(w) > 0$. Since $\nabla^2 L(w)^{-1}$ is positive definite, we have that $\nabla L(w)^T\nabla^2 L(w)^{-1}\nabla L(w) > 0$, and consequently $\lambda_0 > 0$. Finally, from equation 31 we have that $\lambda_0$ is an eigenvalue of $\nabla^2 L(w)$ and thus it must be smaller than or equal to the largest eigenvalue of $\nabla^2 L(w)$.

C IS ONE TELEPORTATION ENOUGH TO FIND THE OPTIMAL TRAJECTORY?
This section contains proofs for the results in Section 3.3. For readability, we repeat some of the definitions here.

Consider the parameter space $M = \mathbb{R}^n$. Let $V:\mathbb{R}^n\to T\mathbb{R}^n$ be a vector field on $\mathbb{R}^n$, where $T\mathbb{R}^n$ denotes the associated tangent bundle. We will write $V = v^i\frac{\partial}{\partial w^i}$ using the component functions $v^i:\mathbb{R}^n\to\mathbb{R}$ and coordinates $w^i$. Let $L: M\to\mathbb{R}$ be a smooth loss function. Let $G$ be a symmetry group of $L$, i.e. $L(g\cdot w) = L(w)$ for all $w\in M$ and $g\in G$. Let $\mathfrak{X}$ be the set of all vector fields on $M$. Let $R = r^i\frac{\partial}{\partial w^i}$, where $r^i = -\frac{\partial L}{\partial w^i}$, be the reverse gradient vector field. Let $\mathfrak{X}^\perp = \{A = a^i\frac{\partial}{\partial w^i}\in\mathfrak{X} \mid a^i\in C^\infty(M) \text{ and } \sum_i a^i(w)r^i(w) = 0, \forall w\in M\}$ be the set of vector fields orthogonal to $R$. If $G$ is a Lie group, the infinitesimal action of its Lie algebra $\mathfrak{g}$ defines a set of vector fields $\mathfrak{X}_{\mathfrak{g}} \subseteq \mathfrak{X}^\perp$. A gradient flow is a curve $\gamma:\mathbb{R}\to M$ where the velocity is the value of $R$ at each point, i.e. $\gamma'(t) = R_{\gamma(t)}$ for all $t\in\mathbb{R}$. The Lie bracket $[A,R]$ defines the derivative of $R$ with respect to $A$. To simplify notation, we write $([W,R]L)(w) = 0$ for a set of vector fields $W\subseteq\mathfrak{X}^\perp$ when $([A,R]L)(w) = 0$ for all $A\in W$.

Proposition 3.4. A point $w\in M$ is optimal in a set of vector fields $W\subseteq\mathfrak{X}^\perp$ if and only if $[A,R]L(w) = 0$ for all $A\in W$.

Proof. Note that $AL = a^i\frac{\partial L}{\partial w^i} = 0$. We have
$$[A,R]L = ARL - RAL = A\left(r^i\frac{\partial L}{\partial w^i}\right) = -A\left\|\frac{\partial L}{\partial w}\right\|_2^2. \quad (32)$$
The result then follows from Definition 3.3.

Proposition 3.5. Let $W\subseteq\mathfrak{X}^\perp$ be a set of vector fields that are orthogonal to the gradient of $L$.
If $[A,R]L(w) = 0$ for all $A\in W$ implies that $R([A,R]L)(w) = 0$ for all $A\in W$, then the gradient flow starting at an optimal point in $W$ is optimal in $W$.

Proof. Consider the gradient flow $\gamma$ that starts at an optimal point in $W$. The derivative of $[A,R]L$ along $\gamma$ is
$$\frac{d}{dt}[A,R]L(\gamma(t)) = \gamma'(t)\left([A,R]L\right)(\gamma(t)) = R[A,R]L(\gamma(t)). \quad (33)$$
Since $\gamma(0)$ is an optimal point, $[A,R]L(\gamma(0)) = 0$ for all $A\in W$ by Proposition 3.4. By assumption, if $[A,R]L(\gamma(t)) = 0$ for all $A\in W$, then $R([A,R]L)(\gamma(t)) = 0$ for all $A\in W$. Therefore, both the value and the derivative of $[A,R]L$ stay $0$ along $\gamma$. Since $[A,R]L(\gamma(t)) = 0$ for all $t\in\mathbb{R}$, $\gamma$ is optimal in $W$.

To help check when Proposition 3.5 is satisfied, we provide an alternative form of $R[A,R]L(w)$ under the assumption that $[A,R]L(w) = 0$. We will use the following lemmas in the proof.

Lemma C.1. For two vectors $v, w\in\mathbb{R}^n$, if $v^Tw = 0$ and $w \ne 0$, then there exists an antisymmetric matrix $M\in\mathbb{R}^{n\times n}$ such that $v = Mw$.

Proof. Let $w_0 = [1, 0, \dots, 0]^T\in\mathbb{R}^n$. Consider a list of $n-1$ antisymmetric matrices $M_i\in\mathbb{R}^{n\times n}$, where
$$(M_i)_{jk} = \begin{cases} -1, & \text{if } j = 1 \text{ and } k = i+1\\ 1, & \text{if } j = i+1 \text{ and } k = 1\\ 0, & \text{otherwise.} \end{cases} \quad (34)$$
In matrix form, the $M_i$ are
$$M_1 = \begin{pmatrix} 0 & -1 & 0 & \cdots & 0\\ 1 & 0 & 0 & \cdots & 0\\ 0 & 0 & 0 & \cdots & 0\\ \vdots & & & \ddots & \vdots\\ 0 & 0 & 0 & \cdots & 0 \end{pmatrix},\quad M_2 = \begin{pmatrix} 0 & 0 & -1 & \cdots & 0\\ 0 & 0 & 0 & \cdots & 0\\ 1 & 0 & 0 & \cdots & 0\\ \vdots & & & \ddots & \vdots\\ 0 & 0 & 0 & \cdots & 0 \end{pmatrix},\quad \dots,\quad M_{n-1} = \begin{pmatrix} 0 & 0 & 0 & \cdots & -1\\ 0 & 0 & 0 & \cdots & 0\\ 0 & 0 & 0 & \cdots & 0\\ \vdots & & & \ddots & \vdots\\ 1 & 0 & 0 & \cdots & 0 \end{pmatrix}. \quad (35)$$
Since the $M_i$ are antisymmetric, $M_iw_0$ is orthogonal to $w_0$. The norm of $M_iw_0 = e_{i+1}$ is 1. Additionally, $M_iw_0$ is orthogonal to $M_jw_0$ for $i \ne j$:
$$(M_iw_0)^T(M_jw_0) = e_{i+1}^Te_{j+1} = \delta_{ij}. \quad (36)$$
Denote $w_0^\perp = \{x\in\mathbb{R}^n : x^Tw_0 = 0\}$ as the orthogonal complement of $w_0$. Then the $M_iw_0$ form a basis of $w_0^\perp$.

Next, we extend this to an arbitrary $w\in\mathbb{R}^n$. Let $\hat w = \frac{w}{\|w\|_2}$. Since $\hat w$ has norm 1, there exists an orthogonal matrix $R$ such that $\hat w = Rw_0$. Let $M_i' = RM_iR^T$. Then $M_i'$ is antisymmetric:
$$(RM_iR^T)^T = RM_i^TR^T = -RM_iR^T. \quad (37)$$
It follows that $M_i'\hat w$ is orthogonal to $\hat w$. The norm of $M_i'\hat w$ is $\|(RM_iR^T)(Rw_0)\| = \|RM_iw_0\| = \|M_iw_0\| = 1$.
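Additionally, $M_i'\hat w$ is orthogonal to $M_j'\hat w$ for $i \ne j$:
$$(M_i'\hat w)^T(M_j'\hat w) = (RM_iR^TRw_0)^T(RM_jR^TRw_0) = w_0^TR^TRM_i^TR^TRM_jR^TRw_0 = w_0^TM_i^TM_jw_0 = \delta_{ij}. \quad (38)$$
Therefore, the $M_i'\hat w$ span $\hat w^\perp = w^\perp$. This means that any vector $v\in w^\perp$ can be written as a linear combination of the $M_i'\hat w$. That is, there exist $k_1,\dots,k_{n-1}\in\mathbb{R}$ such that $v = \sum_i k_i(M_i'\hat w)$. To find the antisymmetric $M$ that takes $w$ to $v$, note that
$$v = \sum_i k_iM_i'\hat w = \left(\frac{1}{\|w\|_2}\sum_i k_iM_i'\right)w. \quad (39)$$
Since the sum of antisymmetric matrices is antisymmetric, and the product of an antisymmetric matrix and a scalar is also antisymmetric, $M = \frac{1}{\|w\|_2}\sum_i k_iM_i'$ is antisymmetric.

Lemma C.2. Let $v\in\mathbb{R}^n$ be a nonzero vector. Then the two sets $\{Mv : M\in\mathbb{R}^{n\times n}, M^T = -M\}$ and $\{w\in\mathbb{R}^n : w^Tv = 0\}$ are equal.

Proof. Let $A = \{Mv : M\in\mathbb{R}^{n\times n}, M^T = -M\}$ and $B = \{w\in\mathbb{R}^n : w^Tv = 0\}$. Since $(Mv)^Tv = 0$ for all antisymmetric $M$, every element in $A$ is in $B$. By Lemma C.1, every element in $B$ is in $A$. Therefore $A = B$.

Let $S = \{(M\frac{\partial L}{\partial w})^i\frac{\partial}{\partial w^i}\in\mathfrak{X} \mid M\in\mathbb{R}^{n\times n}, M^T = -M\}$ be the set of vector fields constructed by multiplying the gradient by an antisymmetric matrix. Recall that $R = -\frac{\partial L}{\partial w^i}\frac{\partial}{\partial w^i}$ is the reverse gradient vector field, and $\mathfrak{X}^\perp = \{a^i\frac{\partial}{\partial w^i} \mid \sum_i a^i(w)\frac{\partial L(w)}{\partial w^i} = 0, \forall w\in M\}$ is the set of all vector fields orthogonal to $R$. From Lemma C.2, we have $S = \mathfrak{X}^\perp$. Therefore, a point $w$ is an optimal point in $S$ if and only if $w$ is an optimal point in $\mathfrak{X}^\perp$.

We are now ready to prove the following proposition, which provides another way to check the condition in Proposition 3.5. Below, repeated indices are summed over.

Proposition 3.6. If at all optimal points in $S$,
$$M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^k}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} = 0 \quad (40)$$
for all antisymmetric matrices $M\in\mathbb{R}^{n\times n}$, then the gradient flow starting at an optimal point in $S$ is optimal in $S$.

Proof. Expanding $R[A,R]L$ for $A = a^j\frac{\partial}{\partial w^j}\in S$, we have
$$\begin{aligned} R[A,R]L &= R\left(A\left(r^i\frac{\partial L}{\partial w^i}\right)\right) = -2R\left(a^j\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}\right) = 2\frac{\partial L}{\partial w^k}\frac{\partial}{\partial w^k}\left(a^j\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}\right)\\ &= 2\frac{\partial L}{\partial w^k}\frac{\partial a^j}{\partial w^k}\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} + 2\frac{\partial L}{\partial w^k}a^j\frac{\partial}{\partial w^k}\left(\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}\right). \end{aligned} \quad (41)$$
Assume that $w$ is an optimal point in $S$. By Lemma C.2, $w$ is also an optimal point in $\mathfrak{X}^\perp$. By Lemma C.4 in Zhao et al. (2022), $\frac{\partial L}{\partial w}$ is an eigenvector of $\frac{\partial^2 L}{\partial w^i\partial w^j}$.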
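Therefore, $\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^j} = \lambda\frac{\partial L}{\partial w^i}$ for some $\lambda\in\mathbb{C}$. Additionally, $a^j = M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}$ and $\frac{\partial a^j}{\partial w^k} = M^j_{\ \alpha}\frac{\partial^2 L}{\partial w^\alpha\partial w^k}$. We are now ready to simplify both terms in equation 41.

For the first term in equation 41,
$$\frac{\partial L}{\partial w^k}\frac{\partial a^j}{\partial w^k}\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} = M^j_{\ \alpha}\left(\frac{\partial^2 L}{\partial w^\alpha\partial w^k}\frac{\partial L}{\partial w^k}\right)\left(\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}\right) = \lambda^2M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^j} = 0. \quad (42)$$
The last equality holds because $M$ is antisymmetric.

For the second term in equation 41,
$$\begin{aligned} \frac{\partial L}{\partial w^k}a^j\frac{\partial}{\partial w^k}\left(\frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}\right) &= \frac{\partial L}{\partial w^k}a^j\left(\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} + \frac{\partial^2 L}{\partial w^i\partial w^j}\frac{\partial^2 L}{\partial w^k\partial w^i}\right)\\ &= M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^k}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} + \lambda^2M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^j}\\ &= M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^k}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}. \end{aligned} \quad (43)$$
In summary,
$$R[A,R]L = 2M^j_{\ \alpha}\frac{\partial L}{\partial w^\alpha}\frac{\partial L}{\partial w^k}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}. \quad (44)$$
Since we assumed that $[A,R]L(w) = 0$, when $R[A,R]L(w) = 0$ for all $A\in S$, the gradient flow starting at an optimal point in $S$ is optimal in $S$ by Proposition 3.5.

Proposition C.3. If $\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^\alpha} = \frac{\partial^3 L}{\partial w^k\partial w^i\partial w^\alpha}\frac{\partial L}{\partial w^j}$ holds for all $i, k, j, \alpha$, then $M^j_{\ \alpha}\frac{\partial L}{\partial w^k}\frac{\partial L}{\partial w^\alpha}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} = 0$ holds for all antisymmetric matrices $M\in\mathbb{R}^{n\times n}$.

Proof. Using the assumption, then relabeling $j \leftrightarrow \alpha$, and finally $M^\alpha_{\ j} = -M^j_{\ \alpha}$,
$$\sum_{i,k,j,\alpha}M^j_{\ \alpha}\frac{\partial L}{\partial w^k}\frac{\partial L}{\partial w^\alpha}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} = \sum_{i,k,j,\alpha}M^j_{\ \alpha}\frac{\partial L}{\partial w^k}\frac{\partial L}{\partial w^j}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^\alpha}\frac{\partial L}{\partial w^i} = \sum_{i,k,j,\alpha}M^\alpha_{\ j}\frac{\partial L}{\partial w^k}\frac{\partial L}{\partial w^\alpha}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i} = -\sum_{i,k,j,\alpha}M^j_{\ \alpha}\frac{\partial L}{\partial w^k}\frac{\partial L}{\partial w^\alpha}\frac{\partial^3 L}{\partial w^k\partial w^i\partial w^j}\frac{\partial L}{\partial w^i}.$$
The sum equals its own negative and is therefore zero.

[...] $\varepsilon > 0$:
$$\phi_1(w,\varepsilon) = \left|\{\lambda_i(H)(w) : \lambda_i > \varepsilon\}\right|. \quad (61)$$
A related sharpness metric uses the logarithm of the product of the $k$ largest eigenvalues (Wu et al., 2017),
$$\sum_{i=1}^{k}\log\lambda_i(H)(w). \quad (62)$$
Both metrics require computing the eigenvalues of the Hessian. As a result, optimizing these metrics during teleportation is prohibitively expensive. Hence, in this paper we use the change in loss averaged over random directions ($\phi$) as the objective in generalization experiments.

E.2 MORE INTUITION ON CURVATURES AND GENERALIZATION
E.2.1 EXAMPLE: CURVATURE AFFECTS AVERAGE DISPLACEMENT OF MINIMA
Consider an optimization problem with two variables $w_1, w_2\in\mathbb{R}$. Assume that the minimum is a curve $\gamma:\mathbb{R}\to\mathbb{R}^2$ in the two-dimensional parameter space.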
For a point $w_0$ on $\gamma$, we estimate its generalization ability by computing the expected distance between $w_0$ and the new minimum obtained by shifting $\gamma$. We consider the following two curves as examples:
$$\begin{aligned} \gamma_1 &: \mathbb{R}\to\mathbb{R}^2, \quad t\mapsto(t, k_1t^2)\\ \gamma_2 &: [0, 2\pi]\to\mathbb{R}^2, \quad \theta\mapsto(k_2\cos\theta, k_2\sin\theta + k_2), \end{aligned} \quad (63)$$
with $k_1, k_2\in\mathbb{R}_{\ne 0}$. The curve $\gamma_1$ is a parabola with curvature $\kappa_1 = 2k_1$ at $w_0 = (0, 0)$. The curve $\gamma_2$ is a circle with curvature $\kappa_2 = \frac{1}{k_2}$ at $w_0$. Note that $\gamma_1$ is the only polynomial approximation with integer power ($\gamma(t) = (t, k|t|^n)$, $n\in\mathbb{Z}^+$) where the curvature at $w_0$ depends on $k$. When $n < 1$, the value of $\gamma$ at $w_0$ is undefined. When $n = 1$, the first derivative at $w_0$ is undefined. When $n > 2$, $\kappa(w_0) = 0$.

Assume that a distribution shift in data causes $\gamma$ to shift by a distance $r$, and that the direction of the shift is chosen uniformly at random over all possible directions. From the perspective of the curve, this is equivalent to shifting $w_0$ by distance $r$. The distance between a point $w$ and a curve $\gamma$ is
$$\mathrm{dist}(w,\gamma) = \min_{w'\in\gamma}\|w - w'\|_2. \quad (64)$$
Let $S_r$ be the circle centered at the origin with radius $r$. The expected distance between the old solution $w_0$ and the shifted curve is
$$\mathbb{E}_{w\in S_r}[\mathrm{dist}(w,\gamma)] = \frac{\oint_{S_r}\mathrm{dist}(w,\gamma)\,ds}{\oint_{S_r}ds} = \frac{\int_0^{2\pi}\mathrm{dist}((r\cos\theta, r\sin\theta),\gamma)\,r\,d\theta}{\int_0^{2\pi}r\,d\theta}. \quad (65)$$
In the limit of zero curvature, $\gamma$ is a straight line $\gamma(t) = (t, 0)$. In this case, the expected distance is
$$\mathbb{E}_{w\in S_r}[\mathrm{dist}(w,\gamma)] = \frac{\int_0^{2\pi}|r\sin\theta|\,r\,d\theta}{2\pi r} = \frac{2r}{\pi} \approx 0.637r. \quad (66)$$
Figure 7(b)(c) shows the expected distance's dependence on $\kappa$. For both curves $\gamma_1$ and $\gamma_2$, the generalization ability of $w_0$ depends on the curvature at $w_0$. However, the type of dependence is affected by the type of curve used. In other words, the curvatures at points around $w_0$ affect how the curvature at $w_0$ affects generalization. Therefore, from these results alone, one cannot deduce whether minima with sharper curvatures generalize better or worse.
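Equation 65 can be evaluated numerically for the curves in equation 63. The sketch below is illustrative only and is not the code used to produce Figure 7: it averages $\mathrm{dist}(w,\gamma)$ over equally spaced points of $S_r$ (a midpoint rule in $\theta$), uses the closed-form point-to-circle distance for $\gamma_2$, and uses a brute-force grid search over $t$ for the parabola $\gamma_1$. All function names and the grid sizes `n` and `m` are hypothetical choices.

```python
import math

def expected_distance(dist_fn, r, n=20000):
    """Approximate eq. 65: average dist(w, gamma) over n equally spaced points w on S_r."""
    total = 0.0
    for j in range(n):
        theta = 2.0 * math.pi * (j + 0.5) / n  # midpoint rule in theta
        total += dist_fn(r * math.cos(theta), r * math.sin(theta))
    return total / n

def dist_line(x, y):
    """Zero-curvature minimum gamma(t) = (t, 0): the distance is |y|."""
    return abs(y)

def make_dist_circle(k2):
    """gamma_2 of eq. 63: circle of radius |k2| centred at (0, k2), passing through w0 = (0, 0)."""
    def dist(x, y):
        return abs(math.hypot(x, y - k2) - abs(k2))
    return dist

def make_dist_parabola(k1, r, m=1001):
    """gamma_1 of eq. 63, t -> (t, k1 t^2), with dist approximated by a grid search over t.

    Every w in S_r lies within r of w0, so its nearest point on gamma_1 has |t| <= 2r.
    """
    ts = [-2.0 * r + 4.0 * r * i / (m - 1) for i in range(m)]
    def dist(x, y):
        return min(math.hypot(x - t, y - k1 * t * t) for t in ts)
    return dist

if __name__ == "__main__":
    r = 0.1
    print("line:", expected_distance(dist_line, r) / r, "vs 2/pi =", 2 / math.pi)
    for kappa in (0.5, 2.0, 8.0):
        circle = expected_distance(make_dist_circle(1.0 / kappa), r) / r  # kappa = 1/k2
        parabola = expected_distance(make_dist_parabola(kappa / 2.0, r), r, n=360) / r  # kappa = 2 k1
        print(f"kappa={kappa}: circle {circle:.4f}, parabola {parabola:.4f}")
```

The line case reproduces equation 66, and letting $k_2$ grow (curvature going to zero) recovers the same value, while letting $k_2 \to 0$ drives the expected distance toward $r$. Dividing by $r$ mirrors the scaling used in Figure 7.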
To find a more definitive relationship between curvature and generalization, further investigation of the type of curves on the minimum is required. We emphasize that this example only serves as an intuition for connecting curvature to generalization. As a future direction, it would be interesting to consider different families of parametric curves, higher-dimensional parameter spaces, and deforming in addition to shifting the minima.

[Figure 7 panels: (a) the parameter space with the minimum $\gamma$, the point $w_0$, and the circle $S_r$; (b) $\frac{1}{r}\mathbb{E}_{w\in S_r}[\mathrm{dist}(w,\gamma_1)]$ versus $\kappa = 2k_1$ (0 to 80) for $r = 2.0, 1.0, 0.5, 0.2, 0.1$; (c) $\frac{1}{r}\mathbb{E}_{w\in S_r}[\mathrm{dist}(w,\gamma_2)]$ versus $\kappa = k_2^{-1}$ (0 to 10) for $r = 0.1, 0.05, 0.02, 0.01, 0.005$.]
Figure 7: (a) Illustration of the parameter space, the minimum ($\gamma$), and all shifts with distance $r$ ($S_r$). (b) Expected distance between $w_0$ and the new minimum as a function of $\kappa$, for the quadratic approximation $\gamma_1$. (c) Expected distance between $w_0$ and the new minimum as a function of $\kappa$, for the constant-curvature approximation $\gamma_2$. The expected distance is scaled by $r$ so that the curves can be plotted together.

E.2.2 HIGHER DIMENSIONS
Figure 8 visualizes a curve obtained from a 2D minimum. However, it is not immediately clear what curves look like on a higher-dimensional minimum. A possible way to extend the previous analysis is to consider sectional curvatures.

Figure 8: Left: a 2D minimum in a 3D parameter space. Right: a 2D subspace of the parameter space and a curve on the minimum (the intersection of the minimum and the subspace).

E.3 COMPUTING CORRELATION TO GENERALIZATION
We generate the 100 different models used in Section 4.3 by training randomly initialized models. For all three datasets (MNIST, Fashion-MNIST, and CIFAR-10), we train on 50,000 samples and test on a different set of 10,000 samples. The labels for the classification tasks belong to 1 of 10 classes.
For a batch of flattened input data $X\in\mathbb{R}^{d\times 20}$ and labels $Y\in\mathbb{R}^{20}$, the loss function is
$$L(W_1, W_2, W_3, X, Y) = \mathrm{CrossEntropy}\left(W_3\sigma(W_2\sigma(W_1X)), Y\right),$$
where $W_3\in\mathbb{R}^{10\times h_2}$, $W_2\in\mathbb{R}^{h_2\times h_1}$, $W_1\in\mathbb{R}^{h_1\times d}$ are the weight matrices, and $\sigma$ is the LeakyReLU activation with slope coefficient 0.1. For MNIST and Fashion-MNIST, $d = 28^2$, $h_1 = 16$, and $h_2 = 10$. For CIFAR-10, $d = 32^2\times 3$, $h_1 = 128$, and $h_2 = 32$. The learning rate for stochastic gradient descent is 0.01 for MNIST and Fashion-MNIST, and 0.02 for CIFAR-10. We train each model using mini-batches of size 20 for 40 epochs.

When computing the sharpness $\phi$, we choose the displacement list $T$ that gives the highest correlation. The displacements used in this paper are $T = \{0.001, 0.011, 0.021, \dots, 0.191\}$ for MNIST, and $T = \{0.001, 0.011, 0.021, \dots, 0.191\}$ for Fashion-MNIST and CIFAR-10. We evaluate the change in loss over $|D| = 200$ random directions. For curvature $\psi$, we average over $k = 1$ curves generated by random Lie algebra elements (invertible matrices in this case).

Figures 9 and 10 visualize the correlation results in Table 1. Each point represents one model.

[Figure 9 panels: validation loss versus sharpness $\phi$ on (a) MNIST, (b) Fashion-MNIST, and (c) CIFAR-10.]
Figure 9: Correlation between sharpness and validation loss on MNIST (left), Fashion-MNIST (middle), and CIFAR-10 (right). Sharpness and generalization are strongly correlated.

[Figure 10 panels: validation loss versus curvature $\psi$ on (a) MNIST, Corr = -0.050; (b) Fashion-MNIST, Corr = -0.232; (c) CIFAR-10, Corr = -0.167.]
Figure 10: Correlation between curvature and validation loss on MNIST (left), Fashion-MNIST (middle), and CIFAR-10 (right). There is a weak negative correlation in all three datasets.
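The sharpness $\phi$ is described here only as the change in loss averaged over a set $D$ of random directions and a displacement list $T$. The sketch below assumes one concrete reading, $\phi(w) = \frac{1}{|D||T|}\sum_{d\in D}\sum_{t\in T}\left(L(w + td) - L(w)\right)$ with $d$ drawn uniformly from the unit sphere; the precise definition in Section 4 may normalize differently. The helper `sharpness_phi` and the quadratic toy loss are hypothetical.

```python
import numpy as np

def sharpness_phi(loss_fn, w, displacements, num_directions=200, seed=0):
    """Assumed form of phi: mean of L(w + t*d) - L(w) over random unit directions d and t in T."""
    rng = np.random.default_rng(seed)
    base = loss_fn(w)
    total = 0.0
    for _ in range(num_directions):
        d = rng.standard_normal(w.shape)
        d /= np.linalg.norm(d)  # normalized Gaussian sample: uniform direction on the unit sphere
        for t in displacements:
            total += loss_fn(w + t * d) - base
    return total / (num_directions * len(displacements))

# Displacement list T = {0.001, 0.011, ..., 0.191}: 20 values spaced by 0.01.
T_LIST = [0.001 + 0.01 * i for i in range(20)]

if __name__ == "__main__":
    # Toy check: for L(w) = 0.5 * a * ||w||^2 at w = 0, phi = 0.5 * a * mean(t^2) for any direction,
    # so a sharper quadratic (larger a) yields a larger phi.
    for a in (1.0, 10.0):
        quad = lambda w, a=a: 0.5 * a * float(w @ w)
        print(a, sharpness_phi(quad, np.zeros(10), T_LIST, num_directions=20))
```

Because $\phi$ needs only forward evaluations of the loss, it avoids the Hessian eigenvalue computations required by equations 61 and 62.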
E.4 ADDITIONAL DETAILS FOR GENERALIZATION EXPERIMENTS
Algorithm 2 shows an example of how to perform a teleportation with an MLP.

Algorithm 2 Changing curvature using teleportation
Input: loss function $L(w)$, parameters before teleportation $w_0$, teleportation learning rate $\eta_{teleport}$, number of teleportation steps $n_{teleport}$.
Output: parameters after teleportation $w_{n_{teleport}}$.
for $t = 0$ to $n_{teleport} - 1$ do
    initialize $T = 0_{h\times h}$
    set $w_t' = (I_{h\times h} + T)\cdot w_t$
    compute $grad = \frac{d|\psi(w_t')|}{dT}$
    set $T_t = \eta_{teleport}\cdot grad$
    set $w_{t+1} = (I + T_t)\cdot w_t$
end for
Return $w_{n_{teleport}}$

On CIFAR-10, we run SGD using the same three-layer architecture as in Section E.3, but with smaller hidden sizes $h_1 = 32$ and $h_2 = 10$. At epoch 20, which is close to convergence, we teleport using 5 batches of data, each of size 2000. During each teleportation for $\phi$, we perform 10 gradient ascent (or descent) steps on the group element. During each teleportation for $\psi$, we perform 1 gradient ascent (or descent) step on the group element. The learning rate for the optimization on group elements is $5\times 10^{-2}$.

To investigate how teleportation affects generalization for other optimizers, we repeat the same experiment but replace SGD with AdaGrad. Figure 11 shows the training curve of AdaGrad on CIFAR-10, averaged across 5 runs. Similar to SGD, changing curvature via teleportation affects the validation loss, while changing sharpness has negligible effects. Teleporting to points with larger curvatures helps find minima with slightly lower validation loss. Teleporting to points with smaller curvatures increases the gap between training and validation loss.

[Figure 11 panels: loss versus epoch (0 to 40) for AdaGrad, teleport(decrease $\phi$), and teleport(increase $\phi$) (left); and for AdaGrad, teleport(decrease $\psi$), and teleport(increase $\psi$) (right).]
Figure 11: Changing sharpness (left) or curvature (right) using teleportation and its effect on the generalizability of AdaGrad solutions on CIFAR-10.
Solid lines represent average test loss, and dashed lines represent average training loss.

F INTEGRATING TELEPORTATION WITH OTHER GRADIENT-BASED ALGORITHMS
F.1 DIFFERENT METHODS OF INTEGRATING TELEPORTATION WITH MOMENTUM AND ADAGRAD
Setup. We test teleportation with various algorithms using a 3-layer neural network and mean squared error:
$$\min_{W_1, W_2, W_3}\|Y - W_3\sigma(W_2\sigma(W_1X))\|^2,$$
with data $X\in\mathbb{R}^{5\times 4}$, target $Y\in\mathbb{R}^{8\times 4}$, and weight matrices $W_3\in\mathbb{R}^{8\times 7}$, $W_2\in\mathbb{R}^{7\times 6}$, and $W_1\in\mathbb{R}^{6\times 5}$. The activation function $\sigma$ is LeakyReLU with slope coefficient 0.1. Each element in the weight matrices is initialized uniformly at random over $[0, 1]$. Data $X, Y$ are also randomly generated from $[0, 1]$.

Momentum. We compare three strategies of integrating teleportation with momentum: teleporting both parameters and momentum, teleporting parameters but not momentum, and resetting momentum to 0 after a teleportation. In each run, we teleport once at epoch 5. Each strategy is repeated 5 times. The training curves of teleporting momentum in different ways are similar (Figure 12a), possibly because the accumulated momentum is small compared to the gradient right after teleportations. All methods of teleporting momentum improve convergence, which means teleportation works well with momentum.

AdaGrad. In AdaGrad, the rate of change in loss is
$$\frac{dL}{dt} = -\eta\left\|\frac{\partial L}{\partial w}\right\|_A^2, \quad (67)$$
where $\eta\in\mathbb{R}$ is the learning rate, and $\|\cdot\|_A$ is the Mahalanobis norm with $A = (\varepsilon I + \mathrm{diag}(G_{t+1}))^{-\frac{1}{2}}$. Previously, we optimized $\|\frac{\partial L}{\partial w}\|_2$ in teleportation. We compare that to optimizing $\|\frac{\partial L}{\partial w}\|_A$. Since the magnitude of $A$ differs from 1, a different learning rate for the gradient ascent in teleportation is required. We choose the largest learning rate (with two significant figures) that does not lead to divergence. The teleportation learning rates used are $1.2\times 10^{-5}$ for the objective $\max_g\|\frac{\partial L}{\partial w}\|_2$ and $7.5\times 10^{-3}$ for the objective $\max_g\|\frac{\partial L}{\partial w}\|_A$. Teleporting using the group element that optimizes $\|\frac{\partial L}{\partial w}\|_A$ has a slight advantage (Figure 12b).
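Equation 67 is a first-order statement: with the preconditioner held fixed, an AdaGrad step $w \leftarrow w - \eta A\nabla L$ changes the loss by approximately $-\eta\|\nabla L\|_A^2$. The sketch below checks this on a quadratic and shows why the two teleportation objectives can disagree: gradients with the same 2-norm can have very different $A$-norms when the gradient history is anisotropic. It assumes, as in standard AdaGrad, that $G_{t+1}$ accumulates squared past gradients; all function names are hypothetical.

```python
import numpy as np

def adagrad_A_diag(grad_history, eps=1e-8):
    """Diagonal of A = (eps*I + diag(G))^(-1/2), with G assumed to be the sum of squared past gradients."""
    G = np.sum(np.square(np.asarray(grad_history, dtype=float)), axis=0)
    return 1.0 / np.sqrt(eps + G)

def mahalanobis_sq(g, a_diag):
    """||g||_A^2 = g^T A g for a diagonal, positive definite A."""
    return float(np.sum(a_diag * g * g))

def adagrad_step(w, g, a_diag, lr):
    """AdaGrad-style update w <- w - lr * A g, with the preconditioner held fixed for this step."""
    return w - lr * a_diag * g

if __name__ == "__main__":
    # Anisotropic history: large past gradients along coordinate 0, small along coordinate 1.
    history = [np.array([3.0, 0.1]), np.array([4.0, 0.2])]
    a = adagrad_A_diag(history)
    g1, g2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])  # equal 2-norms
    print(mahalanobis_sq(g1, a), mahalanobis_sq(g2, a))  # g2 has the larger A-norm
```

Under this reading, maximizing $\|\frac{\partial L}{\partial w}\|_A$ favors points whose gradient lies along coordinates where AdaGrad still takes large steps, which is consistent with it being a better match to equation 67 than the 2-norm objective.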
Similar to the observations in Zhao et al. (2022), teleportation can be integrated into adaptive gradient descent methods.

[Figure 12 panels: (a) loss versus epoch (0 to 300) for Momentum, Momentum+teleport(none), Momentum+teleport(teleport), and Momentum+teleport(reset); (b) loss versus epoch (0 to 300) for AdaGrad, AdaGrad+teleport_||dL/dw||_2, and AdaGrad+teleport_||dL/dw||_A.]
Figure 12: Comparison of different methods of integrating teleportation with momentum and AdaGrad.

F.2 ADDITIONAL DETAILS FOR EXPERIMENTS ON MNIST
We use a three-layer model and cross-entropy loss for classification with mini-batches of size 20. For a batch of flattened input data $X\in\mathbb{R}^{28^2\times 20}$ and labels $Y\in\mathbb{R}^{20}$, the loss function is
$$L(W_1, W_2, W_3, X, Y) = \mathrm{CrossEntropy}\left(W_3\sigma(W_2\sigma(W_1X)), Y\right),$$
where $W_3\in\mathbb{R}^{10\times 10}$, $W_2\in\mathbb{R}^{10\times 16}$, $W_1\in\mathbb{R}^{16\times 28^2}$ are the weight matrices, and $\sigma$ is the LeakyReLU activation with slope coefficient 0.1. The learning rates are $10^{-4}$ for AdaGrad, and $5\times 10^{-2}$ for SGD with momentum, RMSProp, and Adam. The learning rate for optimizing the group element in teleportation is $5\times 10^{-2}$, and we perform 10 gradient ascent steps when teleporting using each mini-batch. We use 50,000 samples from the training set for training, and 10,000 samples in the test set for testing.

[Figure 13 panels: (a) AdaGrad, (b) momentum, (c) RMSprop, (d) Adam; each panel shows train and test loss with and without teleportation.]
Figure 13: Runtime comparison for integrating teleportation into various algorithms. Solid lines represent average training loss, and dashed lines represent average test loss. Shaded areas are 1 standard deviation of the test loss across 5 runs. The plots look almost identical to Figure 5, indicating that the cost of teleportation is negligible compared to gradient descent.
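Each teleportation above runs a few gradient ascent steps on the group element. The following is a minimal sketch of that inner loop under simplifying assumptions, not the experimental code. The model is a hypothetical two-layer linear network $L = \frac{1}{2}\|W_2W_1X - Y\|_F^2$, for which $g = I + T\in GL(h)$ acting as $(W_1, W_2)\mapsto(gW_1, W_2g^{-1})$ is an exact symmetry. For the LeakyReLU networks used in the experiments, this simple action is not a symmetry (only positive rescalings and permutations of hidden units are), so this sketch does not reproduce the construction used there. The objective is $\|\frac{\partial L}{\partial w}\|^2$ from equation 10; its gradient with respect to $T$ at $T = 0$ follows from $\frac{\partial L}{\partial W_2}\mapsto\frac{\partial L}{\partial W_2}g^T$ and $\frac{\partial L}{\partial W_1}\mapsto g^{-T}\frac{\partial L}{\partial W_1}$. The normalized step size is our choice for numerical robustness.

```python
import numpy as np

def loss_and_grads(W1, W2, X, Y):
    """Hypothetical stand-in model: L = 0.5 * ||W2 W1 X - Y||_F^2 (two-layer linear network)."""
    H = W1 @ X
    R = W2 @ H - Y
    loss = 0.5 * float(np.sum(R * R))
    return loss, W2.T @ R @ X.T, R @ H.T  # loss, dL/dW1, dL/dW2

def grad_norm_sq(W1, W2, X, Y):
    """Teleportation objective ||dL/dw||^2 summed over all weights."""
    _, dW1, dW2 = loss_and_grads(W1, W2, X, Y)
    return float(np.sum(dW1 * dW1) + np.sum(dW2 * dW2))

def act(T, W1, W2):
    """Action of g = I + T: (W1, W2) -> (g W1, W2 g^{-1}); the product W2 W1, hence the loss, is unchanged."""
    g = np.eye(T.shape[0]) + T
    return g @ W1, W2 @ np.linalg.inv(g)

def objective_grad(W1, W2, X, Y):
    """Gradient of grad_norm_sq(*act(T, W1, W2), X, Y) with respect to T at T = 0."""
    _, dW1, dW2 = loss_and_grads(W1, W2, X, Y)
    return 2.0 * (dW2.T @ dW2 - dW1 @ dW1.T)

def teleport(W1, W2, X, Y, lr=1e-2, steps=10):
    """Gradient ascent on the group element, using a normalized step T = lr * grad / ||grad||."""
    for _ in range(steps):
        grad = objective_grad(W1, W2, X, Y)
        T = lr * grad / np.linalg.norm(grad)
        W1, W2 = act(T, W1, W2)
    return W1, W2

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Shapes loosely follow the F.1 setup, without the nonlinearity and the third layer.
    X, Y = rng.uniform(0, 1, (5, 4)), rng.uniform(0, 1, (8, 4))
    W1, W2 = rng.uniform(0, 1, (6, 5)), rng.uniform(0, 1, (8, 6))
    W1t, W2t = teleport(W1, W2, X, Y)
    print("loss:", loss_and_grads(W1, W2, X, Y)[0], "->", loss_and_grads(W1t, W2t, X, Y)[0])
    print("||dL/dw||^2:", grad_norm_sq(W1, W2, X, Y), "->", grad_norm_sq(W1t, W2t, X, Y))
```

The loss stays fixed up to floating-point error while the squared gradient norm grows, which is the property the convergence argument in Appendix A relies on.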