# Generalized Federated Learning via Sharpness Aware Minimization

Zhe Qu\*¹, Xingyu Li\*², Rui Duan¹, Yao Liu³, Bo Tang², Zhuo Lu¹

\*Equal contribution. ¹Department of Electrical Engineering, University of South Florida, Tampa, USA. ²Department of Electrical and Computer Engineering, Mississippi State University, Starkville, USA. ³Department of Computer Science and Engineering, University of South Florida, Tampa, USA. Correspondence to: Zhe Qu. Proceedings of the 39th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022. Copyright 2022 by the author(s).

Federated Learning (FL) is a promising framework for performing privacy-preserving, distributed learning with a set of clients. However, the data distribution among clients is often non-IID, i.e., it exhibits distribution shift, which makes efficient optimization difficult. To tackle this problem, many FL algorithms focus on mitigating the effects of data heterogeneity across clients by increasing the performance of the global model. However, almost all of these algorithms use Empirical Risk Minimization (ERM) as the local optimizer, which easily drives the global model into a sharp valley and causes a large performance deviation for some local clients. Therefore, in this paper, we revisit the solutions to the distribution shift problem in FL with a focus on local learning generality. To this end, we propose a general, effective algorithm, FedSAM, based on a Sharpness Aware Minimization (SAM) local optimizer, and develop a momentum FL algorithm, MoFedSAM, to bridge local and global models. Theoretically, we provide the convergence analysis of these two algorithms and a generalization bound for FedSAM. Empirically, our proposed algorithms substantially outperform existing FL studies and significantly decrease the learning deviation.

1. Introduction

Federated Learning (FL) (McMahan et al., 2017) is a collaborative training framework involving a large number of clients, which can be phones, network sensors, or alternative local information sources (Kairouz et al., 2019; Mohri et al., 2019). FL trains machine learning models without transmitting client data over the network, and thus it provides a basic level of data privacy. Two important settings are introduced in FL (Kairouz et al., 2019): cross-device FL and cross-silo FL. Cross-silo FL involves a small number of reliable clients, e.g., medical or financial institutions. By contrast, cross-device FL includes a large number of clients, e.g., billions of Android phones (Hard et al., 2018). In cross-device FL, clients are usually deployed in various environments. It is unavoidable that the distribution of the local dataset varies considerably from client to client, which incurs a distribution shift problem and severely degrades the learning performance.

Many existing FL studies address the distribution shift problem along the following three directions: (i) The most popular solution is to tune the number of local training epochs performed between communication rounds (Li et al., 2020b; Yang et al., 2021).
(ii) Many algorithmic solutions (Li et al., 2018b; Karimireddy et al., 2020; Acar et al., 2021) mitigate the influence of heterogeneity across clients by adding various proximal terms that keep the local model updates close to the global model. (iii) Knowledge distillation based techniques (Lin et al., 2020; Gong et al., 2021; Zhu et al., 2021) aggregate locally computed logits to build global models, which removes the need for each local model to share the architecture of the global model.

Motivation. In centralized learning, network generalization techniques have been well studied to overcome the overfitting problem (Lakshminarayanan et al., 2017; Woodworth et al., 2020). Even in standard settings where the training and test data are drawn from a similar distribution, models still overfit the training data, and a model trained with Empirical Risk Minimization (ERM) tends to fall into a sharp valley of the loss surface (Chaudhari et al., 2019). This effect is further intensified when the training and test data follow different distributions. Similarly, in FL, overfitting the local training data of each client is detrimental to the performance of the global model, as the distribution shift problem creates conflicting objectives among local models. The main strategy to improve FL performance is to pull the local models toward the global model in an average sense (Karimireddy et al., 2020; Yang et al., 2021; Li et al., 2018b), which has been demonstrated to accelerate the convergence of FL. However, few existing FL studies focus on protecting the learning performance of poorly performing clients, and hence some clients may lose their unique properties and incur a large performance deviation. Therefore, improving the generality of the global model should be of primary concern in the presence of the distribution shift problem.
Improving local training generality would inherently position the objective of the clients closer to the global model objective. Recently, an efficient algorithm, Sharpness Aware Minimization (SAM) (Foret et al., 2021), has been developed to smooth the loss surface and improve generalization. It does not need to solve the min-max objectives used in adversarial learning (Goodfellow et al., 2014; Shafahi et al., 2020); instead, it leverages a linear approximation to improve efficiency. As discussed above, applying SAM as the local optimizer should be an effective way to generalize the global model in FL. We first introduce a basic algorithm adopting SAM in FL settings, called FedSAM, where each local client trains the local model with the same perturbation bound. Although FedSAM helps the global model generalize and improves training performance, the local SAM updates do not affect the global model directly. In order to bridge the smoothness information of both local and global models without accessing others' private data, we develop our second and more important algorithm, termed Momentum FedSAM (MoFedSAM): clients additionally download the global model update of the previous round and then perform local SAM training on both the local dataset and the global model update.

Contributions. We summarize our contributions as follows: (1) We approach one of the most troublesome cross-device FL challenges, i.e., distribution shift caused by data heterogeneity. To generalize the global model, we first propose a simple algorithm, FedSAM, which uses SAM as the local optimizer. (2) We prove a convergence rate of $\mathcal{O}(L/\sqrt{RS})$ for the FedSAM algorithm, which matches the best convergence rate of existing FL studies. For the local training part of the convergence rate, our proposed algorithms show a speedup. Moreover, a generalization bound for FedSAM is also presented.
(3) To directly smooth the global model, we develop the MoFedSAM algorithm, which performs local training on both the local dataset and the global model update with the SAM optimizer. We then show that its convergence rates are $\mathcal{O}(\beta L/\sqrt{RKN})$ and $\mathcal{O}(\beta\sqrt{K}/\sqrt{RS})$ under the full and partial client participation strategies, respectively, which achieves a speedup and implies that MoFedSAM is a more efficient algorithm for addressing the distribution shift problem.

Related work. In this paper, we aim to evaluate and distinguish the generalization performance of clients. Throughout this paper, we only focus on the classic cross-device FL setting (McMahan et al., 2017; Li et al., 2018b; Karimireddy et al., 2020), in which a single global model is learned from and served to all clients. In the Personalized FL (PFL) setting (T Dinh et al., 2020; Fallah et al., 2020; Singhal et al., 2021), the goal is to learn and serve different models for different clients. While related, our focus and contribution are orthogonal to personalization. In fact, our proposed algorithms are easy to extend to the PFL setting. For example, by tuning a hyperparameter that controls the interpolation between local and global models (Deng et al., 2020; Li et al., 2021), the participating clients can be defined as the clients that contribute to the training of the global model. We can use SAM to train the global model and generate the local models by ERM to improve the performance.

Momentum FL is an effective way to address the distribution shift problem and accelerate convergence; it is based on injecting global information directly into the local training. Momentum can be applied on the server (Wang et al., 2019; Reddi et al., 2020), on the client (Karimireddy et al., 2021; Xu et al., 2021), or on both (Khanduri et al., 2021). As discussed above, while these algorithms accelerate convergence, the global model tends to land in a sharp valley and overfit.
As such, the global model may not be effective for all clients and may produce a large deviation. We propose to train global models using a set of participating clients and examine their performance on both training and validation datasets.

In centralized learning, some studies (Keskar et al., 2016; Lakshminarayanan et al., 2017; Woodworth et al., 2020) consider the out-of-distribution generalization problem and show, on centrally trained models, that even small deviations in the morphology of deployment examples can lead to severe performance degradation. Sharpness minimization is an efficient way to deal with this problem (Foret et al., 2021; Kwon et al., 2021; Zhuang et al., 2022; Du et al., 2021a). The FL setting differs from these other settings in that our problem assumes data is drawn from a distribution of client distributions, even if the union of these distributions is stationary. Therefore, in FL settings, we consider both training and validation performance, which is more challenging than in centralized learning. Although some studies develop algorithms to generalize the global model in FL (Mendieta et al., 2021; Yuan et al., 2021; Yoon et al., 2021), they lack theoretical analysis of how the proposed algorithm improves generalization, and they may incur privacy issues. A recent study (Caldarola et al., 2022) shows via empirical experiments that using SAM as the local optimizer can improve the generalization of FL.

2. Preliminaries and Proposed Algorithms

2.1. FedAvg Algorithm

Consider an FL setting with a network of $N$ clients connected to one aggregator. We assume that for every $i \in [N]$ the $i$-th client holds $m$ training data samples $\xi_i = (X_i, Y_i)$ drawn from distribution $D_i$. Note that $D_i$ may differ across clients, which corresponds to client heterogeneity.
Let $F_i(w)$ be the training loss function of client $i$, i.e., $F_i(w) \triangleq \mathbb{E}_{\xi_i \sim D_i}[L_i(w, \xi_i)]$, where $L_i(w, \xi_i)$ is the per-sample loss function. The classical FL problem (McMahan et al., 2017; Li et al., 2020b; Karimireddy et al., 2020) is to fit the best model $w$ to all samples by solving the following empirical risk minimization (ERM) problem on each client:

$$\min_{w} \Big\{ F(w) := \frac{1}{N} \sum_{i \in [N]} F_i(w) \Big\}, \tag{1}$$

where $F(w)$ is the loss function of the global model. FedAvg (McMahan et al., 2017) is one of the most popular algorithms to address (1). In communication round $r$, the server randomly samples a subset $S^r$ of $S$ clients and sends the global model $w^r$ to them. After receiving the global model, these sampled clients run $K$ local Stochastic Gradient Descent (SGD) epochs on their local datasets in parallel, and upload the local model updates $w^r_{i,K}$ to the server. When the server has received all the local model updates, it averages them to obtain the new global model $w^{r+1}$. The pseudocode of FedAvg is shown in Algorithm 1.

2.2. FedSAM Algorithm

Statistical heterogeneity of the local training datasets across clients is one of the most important problems in FL studies. To capture the non-IID nature of local datasets in FL, the common assumption in existing FL studies (Mohri et al., 2019; Li et al., 2020a; Karimireddy et al., 2020; Reisizadeh et al., 2020) is that the data samples of each client have a local distribution shift from a common unknown mixture distribution $D$, i.e., $D_i \neq D$. While training via ERM with SGD searches for a single point $w$ with a low loss, which can perfectly fit the distribution $D$, it often falls into a sharp valley of the loss surface (Chaudhari et al., 2019). As a result, the global model $w$ may be biased toward some clients (i.e., those with low heterogeneity relative to the mixture distribution $D$) and cannot guarantee enough generalization for all clients to perform well.
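Returning to the FedAvg procedure of Section 2.1, one communication round can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: each client loss is a toy quadratic $F_i(w) = \frac{1}{2}\|w - c_i\|^2$ (the centers `centers` are hypothetical), local steps use the full gradient instead of mini-batches, and the names `grad_quadratic` and `fedavg_round` are introduced here only for illustration.

```python
import random

def grad_quadratic(w, center):
    # Gradient of the toy client loss F_i(w) = 0.5 * ||w - center||^2 (hypothetical).
    return [wj - cj for wj, cj in zip(w, center)]

def fedavg_round(w_global, centers, num_sampled, local_steps, lr_local, lr_global, rng):
    # Server samples a subset S^r of clients and sends them the global model w^r.
    sampled = rng.sample(range(len(centers)), num_sampled)
    deltas = []
    for i in sampled:
        w_local = list(w_global)
        for _ in range(local_steps):  # K local steps (full gradient in this sketch)
            g = grad_quadratic(w_local, centers[i])
            w_local = [wj - lr_local * gj for wj, gj in zip(w_local, g)]
        # Local update Delta_i^r = w_{i,K}^r - w^r.
        deltas.append([wl - wg for wl, wg in zip(w_local, w_global)])
    # Server averages the local updates and applies them with the global learning rate.
    dim = len(w_global)
    avg = [sum(d[j] for d in deltas) / len(deltas) for j in range(dim)]
    return [wg + lr_global * aj for wg, aj in zip(w_global, avg)]

if __name__ == "__main__":
    rng = random.Random(0)
    centers = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
    w = [5.0, 5.0]
    for _ in range(50):
        w = fedavg_round(w, centers, 4, 5, 0.1, 1.0, rng)
    print(w)  # approaches the minimizer of the averaged toy loss, the origin
```

With heterogeneous centers, each local model drifts toward its own $c_i$; the averaged update is what pulls the global model toward the minimizer of the averaged loss.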
Moreover, since the training dataset distribution of each client differs from the validation dataset distribution with high probability, i.e., $D_i^{\mathrm{tra}} \neq D_i^{\mathrm{val}}$, and the validation dataset is not accessible during training, the global model $w$ may not guarantee the learning performance of every client, even for clients that perform well during training. To address this problem, some FL algorithms with fairness guarantees have been developed (Li et al., 2020a; Du et al., 2021b), but they only consider the learning performance from the average perspective and do not protect the worse clients. In order to address the average and the deviation over all clients at the same time, it is necessary to create a more general global model that serves all clients.

Instead of searching for a single point solution as ERM does, the state-of-the-art algorithm Sharpness Aware Minimization (SAM) (Foret et al., 2021) seeks a region with low loss values by adding a small perturbation to the model, i.e., a point $w$ whose perturbed version $w + \delta$ suffers little performance degradation. Due to the linear property of the FL optimization in (1), it is not difficult to observe that training the perturbed loss via SAM, i.e., at $\tilde{w} = w + \delta_i$, on each client should reduce the impact of the distribution shift and improve the generalization of the global model. Based on this observation, we design a more general FL algorithm called FedSAM in this paper. The optimization of FedSAM is formulated as follows:

$$\min_{w} \max_{\|\delta_i\|_2^2 \le \rho} \Big\{ \frac{1}{N} \sum_{i \in [N]} f_i(\tilde{w}) \Big\}, \tag{2}$$

where $f(\tilde{w}) \triangleq \max_{\|\delta\| \le \rho} F(w + \delta)$, $f_i(\tilde{w}) \triangleq \max_{\|\delta_i\| \le \rho} F_i(w + \delta_i)$, $\rho$ is a predefined constant controlling the radius of the perturbation, and $\|\cdot\|_2$ is the $\ell_2$-norm, which is abbreviated to $\|\cdot\|$ in the rest of the paper. Next, we take a close look at the local perturbed loss function $F_i(w + \delta_i)$ and introduce how to use SAM to approach it.
For a small value of $\rho$, using a first-order Taylor expansion around $w$, the inner maximization in (2) turns into the following linearly constrained optimization:

$$\delta_i = \arg\max_{\|\delta_i\| \le \rho} F_i(w + \delta_i) \approx \arg\max_{\|\delta_i\| \le \rho} F_i(w) + \delta_i^{\top} \nabla F_i(w) + \mathcal{O}(\rho^2) = \rho\, \mathrm{sign}(\nabla F_i(w)) \frac{|\nabla F_i(w)|}{\|\nabla F_i(w)\|}, \tag{3}$$

where $\mathrm{sign}(\cdot)$ denotes the element-wise signum function. Therefore, the local optimizer of FedSAM changes to $\min_{w} F_i(\tilde{w}) = \min_{\tilde{w}} f_i(\tilde{w})$, where $\tilde{w} \triangleq w + \rho \frac{\nabla F_i(w)}{\|\nabla F_i(w)\|}$. We call $\tilde{w}$ the perturbed model with the highest loss within the neighborhood. The local SAM optimizer solves the min-max problem by iteratively applying the following two-step procedure for epoch $k = 0, \ldots, K-1$ in communication round $r$:

$$\tilde{w}^r_{i,k} = w^r_{i,k} + \rho \frac{g^r_{i,k}}{\|g^r_{i,k}\|}, \qquad w^r_{i,k+1} = w^r_{i,k} - \eta_l\, \tilde{g}^r_{i,k}, \tag{4}$$

where $\eta_l$ is the learning rate of local model updates on each client, $g^r_{i,k} = \nabla F_i(w^r_{i,k}, \xi^r_{i,k})$ is a stochastic estimate of $\nabla F_i(w^r_{i,k})$, and $\tilde{g}^r_{i,k} = \nabla f_i(\tilde{w}^r_{i,k}, \xi^r_{i,k})$ is a stochastic estimate of $\nabla f_i(\tilde{w}^r_{i,k})$. We can see from (4) that the local training of each client first estimates, approximately and by gradient ascent, the point $w^r_{i,k} + \delta^r_i$ at which the local loss is maximized within a region of fixed radius around $w^r_{i,k}$, and then performs gradient descent at $w^r_{i,k}$ using the gradient evaluated at that maximum point $w^r_{i,k} + \delta^r_i$. To present the difference between FedAvg and FedSAM, we summarize the training procedures in Algorithm 1.

Algorithm 1 FedAvg and FedSAM

    Initialization: w^0, \rho, \Delta^0 = 0, learning rates \eta_l, \eta_g and the number of epochs K.
    for r = 0, ..., R-1 do
        Sample subset S^r \subseteq [N] of clients.
        w^r_{i,0} = w^r.
        for each client i \in S^r in parallel do
            for k = 0, ..., K-1 do
                Compute a local training estimate g^r_{i,k} = \nabla F_i(w^r_{i,k}, \xi^r_{i,k}) of \nabla F_i(w^r_{i,k}).
                (FedAvg) w^r_{i,k+1} = w^r_{i,k} - \eta_l g^r_{i,k}.
                (FedSAM) Compute local model w^r_{i,k+1} from (4).
            end for
            \Delta^r_i = w^r_{i,K} - w^r.
        end for
        \Delta^{r+1} = (1/S) \sum_{i \in S^r} \Delta^r_i.
        w^{r+1} = w^r + \eta_g \Delta^{r+1}.
    end for

The SAM optimizer stems from an idea similar to adversarial training, which has been used in FL by Reisizadeh et al. (2020) in an algorithm called FedRobust.
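The two-step procedure in (4) can be sketched in a few lines. This is a minimal sketch under stated assumptions rather than the authors' code: `grad_fn` is a hypothetical callable returning the (stochastic) gradient at a given point, the model is a plain list of floats, and the handling of a zero gradient (no update) is a choice made here for illustration, since the source does not specify it.

```python
import math

def sam_step(w, grad_fn, rho, lr):
    # Step 1: gradient ascent to the approximate worst-case point w_tilde.
    g = grad_fn(w)
    norm = math.sqrt(sum(gj * gj for gj in g))
    if norm == 0.0:
        # Assumption of this sketch: at a stationary point, leave w unchanged.
        return list(w)
    w_tilde = [wj + rho * gj / norm for wj, gj in zip(w, g)]
    # Step 2: descend from the ORIGINAL point w using the gradient taken at w_tilde.
    g_tilde = grad_fn(w_tilde)
    return [wj - lr * gj for wj, gj in zip(w, g_tilde)]

if __name__ == "__main__":
    # Toy loss F(w) = 0.5 * ||w||^2, whose gradient is w itself (hypothetical example).
    print(sam_step([3.0, 4.0], lambda w: list(w), rho=0.5, lr=0.1))
```

Note that each SAM step costs two gradient evaluations, one at $w^r_{i,k}$ and one at $\tilde{w}^r_{i,k}$, compared with one for plain SGD.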
FedRobust is based on solving min-max objectives, which incurs more computational cost for local training and worse convergence than our proposed algorithms. We compare them from both theoretical and empirical perspectives.

Remark 2.1 Here, we briefly explain, from the smoothness perspective, why the SAM local optimizer can improve generalization and help convergence. We assume that the local loss function $f(w)$ is $L$-smooth. Clearly, the loss function $f$ is smoother when $L$ is smaller. Assume that $f(w)$ is $G$-Lipschitz continuous and $\delta \sim \mathcal{N}(0, \epsilon^2 I)$; by leveraging (Nesterov & Spokoiny, 2017), we obtain that the perturbed loss function $f(\tilde{w})$ of FedSAM is $\frac{2G}{\epsilon}$-smooth. Based on the analysis in (Lian et al., 2017; Goyal et al., 2017), the best convergence rate should be $\frac{1}{L}$. For SGD based FL with the original loss surface, $L$ can be very high (even close to $+\infty$ due to the non-smooth nature of the ReLU activation). The constant $L$ of the perturbed loss $f(\tilde{w})$ in FedSAM should be much smaller, because the loss is smoothed over a region. This gives the intuition for why increasing smoothness can significantly improve the convergence of FL.

3. Theoretical Analysis

In what follows, we show the convergence results of the FedSAM algorithm for general non-convex FL settings. We first state our assumptions.

Assumption 1 (Smoothness). $f_i$ is $L$-smooth for all $i \in [N]$, i.e., $\|\nabla f_i(w) - \nabla f_i(v)\| \le L \|w - v\|$ for all $w, v$ in its domain and $i \in [N]$.

Assumption 2 (Bounded variance of global gradient without perturbation). The global variability of the local gradient of the loss function without perturbation $\delta_i$ is bounded by $\sigma_g^2$, i.e., $\|\nabla F_i(w^r) - \nabla F(w^r)\|^2 \le \sigma_g^2$ for all $i \in [N]$ and $r$.

Assumption 3 (Bounded variance of stochastic gradient).
The stochastic gradient $\nabla F_i(w, \xi_i)$, computed by the $i$-th client at model parameter $w$ using mini-batch $\xi_i$, is an unbiased estimator of $\nabla F_i(w)$ with variance bounded by $\sigma_l^2$, i.e.,

$$\mathbb{E}_{\xi_i} \left\| \frac{\nabla F_i(w, \xi_i)}{\|\nabla F_i(w, \xi_i)\|} - \frac{\nabla F_i(w)}{\|\nabla F_i(w)\|} \right\|^2 \le \sigma_l^2, \quad \forall i \in [N],$$

where the expectation is over all local datasets.

Assumptions 1 and 2 are standard in general non-convex FL studies (Li et al., 2020b; Karimireddy et al., 2020; Reddi et al., 2020; Karimireddy et al., 2021; Yang et al., 2021); they assume that the loss function is continuous and bound the heterogeneity of FL systems. Note that we consider $\sigma_g^2$ to depend mainly on the data heterogeneity, whereas the effect of the perturbation has to be derived. Hence, we only bound the variability without perturbation. We present the upper bound of $\|\nabla f_i(\tilde{w}^r) - \nabla f(\tilde{w}^r)\|^2$ in Appendix A. Assumption 3 bounds the variance of the stochastic gradient. Although many FL studies use a similar assumption to bound the stochastic gradient variance (Li et al., 2020b; Karimireddy et al., 2020), their definition is $\mathbb{E}_{\xi_i} \|\nabla F_i(w, \xi_i) - \nabla F_i(w)\|^2 \le \sigma_l^2$, under which the value of $\sigma_l^2$ is not easy to measure and its upper bound may be close to $+\infty$. In this paper, Assumption 3 instead considers the norm of the difference of unit vectors, which can be upper bounded by the arc length on a unit circle. Therefore, under our definition $\sigma_l^2$ should be less than $\pi^2$. Clearly, this assumption is tighter than those in existing FL studies.

3.1. Convergence Analysis of FedSAM

We now state our convergence results for the FedSAM algorithm. The detailed proof is in Appendix B.

Theorem 3.1 Let the learning rates be chosen as $\eta_l = \mathcal{O}\big(\frac{1}{\sqrt{R} K L}\big)$, $\eta_g = \sqrt{KN}$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \mathcal{O}\big(\frac{1}{\sqrt{R}}\big)$. Under Assumptions 1-3 and full client participation, the sequence of iterates generated by FedSAM in Algorithm 1 satisfies:

$$\mathcal{O}\left( \frac{F L}{\sqrt{RKN}} + \frac{\sigma_g^2}{R} + \frac{L^2 \sigma_l^2}{R^{3/2} K} + \frac{L^2}{R^2} \right),$$

where $F = f(\tilde{w}^0) - f(\tilde{w}^*)$ and $f(\tilde{w}^*) = \min_{\tilde{w}} f(\tilde{w})$.
For the partial client participation strategy and $S \ge K$, if we choose the learning rates $\eta_l = \mathcal{O}\big(\frac{1}{\sqrt{R} K L}\big)$, $\eta_g = \sqrt{KS}$ and $\rho = \mathcal{O}\big(\frac{1}{\sqrt{R}}\big)$, the sequence of iterates generated by FedSAM in Algorithm 1 satisfies: […]

Remark 3.2 For both the full and partial client participation strategies of the FedSAM algorithm in this theorem, the dominant terms of the convergence rate are $\mathcal{O}(L/\sqrt{SR})$ when the learning rates $\eta_l$ and $\eta_g$ are chosen properly, which matches the best convergence rate in existing general non-convex FL studies (Karimireddy et al., 2020; Yang et al., 2021; Acar et al., 2021). Since the convergence rate structures of the two strategies in this theorem are similar, uniform sampling does not fundamentally change the convergence. In addition, both convergence rates include four main terms, one more than in (Karimireddy et al., 2020; Yang et al., 2021; Acar et al., 2021). Note that we only show the dominant part of each term in the main paper. The detailed proof can be found in the Appendix.

Remark 3.3 The additional term $\mathcal{O}\big(\frac{L^2}{R^2}\big)$ comes from the additional SGD step for smoothness performed by the SAM local optimizer in (4). However, this term is negligible due to its higher order. More specifically, since the smoothness is due to the local training, we can combine it with the local training term, i.e., $\mathcal{O}\big(\frac{\sigma_l^2}{R^{3/2} K} + \frac{1}{R^2}\big)$. Clearly, this term also decays faster than the existing best rate, i.e., $\mathcal{O}\big(\frac{1}{R}\big)$. For the partial participation strategy, the dominant term is due to fewer clients participating and random sampling (heterogeneity), i.e., $\mathcal{O}\big(\frac{\sqrt{K} G^2}{\sqrt{RS}}\big)$. The convergence rate improves substantially as the number of clients increases, which matches the results of partial client participation FL (Karimireddy et al., 2020; Yang et al., 2021; Acar et al., 2021). Intuitively, this term improves because the SAM optimizer makes the global model generalize better and reduces the distribution shift.
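As a rough numerical illustration of the orders compared in Remark 3.3, consider the additional sharpness term against a term of order $1/R$. The values $L = 1$ and $R = 100$ are assumptions chosen only for arithmetic convenience; they are not settings taken from the paper.

```latex
\frac{L^2}{R^2} = \frac{1}{100^2} = 10^{-4},
\qquad
\frac{1}{R} = \frac{1}{100} = 10^{-2},
\qquad
\frac{L^2/R^2}{1/R} = \frac{L^2}{R} = 10^{-2}.
```

Under these assumed values the additional term is two orders of magnitude smaller, and the ratio $L^2/R$ keeps shrinking as the number of rounds $R$ grows.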
Remark 3.4 The FedRobust (Reisizadeh et al., 2020) algorithm is an adversarial learning framework in the FL setting, based on an idea similar to FedSAM. It has a convergence rate of $\mathcal{O}\big(\frac{Lf}{(RN)^{1/3}} + \frac{L^2}{R^{1/3} N^{2/3}}\big)$, which is worse than that of FedSAM. Since multiple gradient descent steps must be computed in each local training epoch, FedRobust also spends more running time and computation on local training.

3.2. Generalization Bounds of FedSAM

Based on the margin-based generalization bounds in (Neyshabur et al., 2018; Bartlett et al., 2017; Farnia et al., 2018; Reisizadeh et al., 2020), we define the generalization error of the FedSAM algorithm with a general neural network as follows:

$$L^{\mathrm{SAM}}_{\gamma}(F(w)) := \frac{1}{N} \sum_{i \in [N]} P_i\Big( F_i(w + \delta_i, X)[Y] - \max_{j \neq Y} F_i(w + \delta_i, X)[j] \le \gamma \Big). \tag{5}$$

Here, $F_i(w + \delta_i, X)$ is the loss function solved by the SAM local optimizer for client $i$ in (2), $X$ is an input, $P_i$ is the probability under the underlying distribution of client $i$, and $F_i(w + \delta_i, X)[j]$ is the output of the last softmax layer of the trained neural network for label $j$. It is worth noting that $\gamma$ is a constant, and for $\gamma = 0$, (5) reduces to the average misclassification rate under the distribution shift, which is denoted by $L^{\mathrm{SAM}}$. In addition, we write $\hat{L}^{\mathrm{SAM}}_{\gamma}(w)$ for the above margin risk under the empirical distribution of training samples, obtained by replacing the underlying $P_i$ with the empirical probability $\hat{P}_i$ computed from the $m$ training samples on client $i$. The following theorem bounds the difference between the empirical and the margin-based error defined in (5) for a general deep neural network. We use the spectral norm based generalization bound framework (Neyshabur et al., 2018; Farnia et al., 2018; Chatterji et al., 2019) to prove the next theorem.
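The empirical version of the margin risk in (5) can be illustrated with a short sketch. It is only an illustration under assumptions: `scores` holds hypothetical softmax outputs of an already perturbed model $w + \delta_i$ for one client's samples, `labels` holds the true labels, and the function names are introduced here rather than taken from the paper.

```python
def client_margin_error(scores, labels, gamma):
    # Empirical P_i-hat: fraction of samples whose margin
    # score[Y] - max_{j != Y} score[j] is at most gamma.
    violations = 0
    for s, y in zip(scores, labels):
        best_other = max(v for j, v in enumerate(s) if j != y)
        if s[y] - best_other <= gamma:
            violations += 1
    return violations / len(labels)

def federated_margin_error(per_client_scores, per_client_labels, gamma):
    # Average of the per-client margin errors over the N clients, as in (5).
    errors = [
        client_margin_error(s, y, gamma)
        for s, y in zip(per_client_scores, per_client_labels)
    ]
    return sum(errors) / len(errors)

if __name__ == "__main__":
    # Three hypothetical samples of one client, all with true label 0.
    scores = [[0.7, 0.2, 0.1], [0.4, 0.5, 0.1], [0.45, 0.35, 0.2]]
    labels = [0, 0, 0]
    print(client_margin_error(scores, labels, gamma=0.0))  # misclassification rate
    print(client_margin_error(scores, labels, gamma=0.2))  # stricter margin
```

With `gamma=0.0` only misclassified samples count, which corresponds to the misclassification rate mentioned for $\gamma = 0$; a larger margin additionally counts correctly classified but low-confidence samples.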
In order to derive the margin-based error bounds, we assume that the neural network uses smooth ReLU activation functions $\theta$, which are 1-Lipschitz. The detailed proof is shown in Appendix C.

Theorem 3.5 Let the input $X$ be an $n \times n$ image whose norm is bounded by $A$, and let $f(w)$ be the classification function of a neural network with $d$ hidden layers and $h$ units per hidden layer, with 1-Lipschitz activation $\theta$ satisfying $\theta(0) = 0$. We assume that there is a constant $M \ge 1$ such that each layer $W_j$ satisfies $\frac{1}{M} \le \frac{\|W_j\|}{\varphi_w} \le M$, where $\varphi_w := \big(\prod_{j=1}^{d} \|W_j\|\big)^{1/d}$ denotes the geometric mean of the spectral norms of $f(w)$ across all layers. Then, for any margin value $\gamma$, size $m$ of the local training dataset on each client, and $\zeta > 0$, with probability $1 - \zeta$ over the training set, for any parameter of the SAM local optimizer $\tilde{w} = w + \delta$ such that $\max_{X \in D_i} \|F_i(w) - f(\tilde{w})\| \le \frac{\gamma}{8}$, we obtain the following generalization bound:

$$L^{\mathrm{SAM}}(F(w)) \le \hat{L}^{\mathrm{SAM}}_{\gamma}(F(w + \delta)) + \mathcal{O}\left( \sqrt{ \frac{32 A d^2 h \log(dh)\, Q(F(w)) + d \log \frac{N m d \log(M)}{\zeta}}{\gamma^2 m} } \right),$$

where $Q(F(w)) := \prod_{j=1}^{d} \|W_j\| \sum_{j=1}^{d} \frac{\|W_j\|_F^2}{\|W_j\|}$ and $\|W_j\|_F$ is the Frobenius norm.

Algorithm 2 MoFedSAM algorithm.

    1: Initialization: w^0, \Delta^0 = 0, \rho, momentum parameter \beta, the number of local updates K.
    2: for r = 0, ..., R-1 do
    3:     Sample subset S^r \subseteq [N] of clients.
    4:     w^r_{i,0} = w^r.
    5:     for each client i \in S^r in parallel do
    6:         for k = 0, ..., K-1 do
    7:             Compute a local training estimate g^r_{i,k} = \nabla F_i(w^r_{i,k}, \xi^r_{i,k}) of \nabla F_i(w^r_{i,k}).
    8:             Compute local model w^r_{i,k+1} from (6).
    9:         end for
    10:        \Delta^r_i = w^r_{i,K} - w^r.
    11:    end for
    12:    \Delta^{r+1} = -(1 / (\eta_l K S)) \sum_{i \in S^r} \Delta^r_i.
    13:    w^{r+1} = w^r - \eta_g \Delta^{r+1}.
    14: end for

Theorem 3.5 gives a non-asymptotic bound on the generalization risk of FedSAM for general neural networks. The PAC-Bayesian bounds of SAM (Foret et al., 2021; Kwon et al., 2021; Zhuang et al., 2022; Du et al., 2021a) do not provide insight into the underlying reason for generalization, i.e., how to choose the value of $\lambda$ in the Gaussian noise $\mathcal{N}(0, \lambda I)$ used as the perturbation.
In Theorem 3.5, we present the dependence between the perturbation $\delta$ and the different neural network parameters, under which we can control the loss surface around a point in order to guarantee smoothness.

4. Momentum FedSAM (MoFedSAM)

4.1. Algorithm of MoFedSAM

Since $\Delta^r$ serves as the update direction of the global model, the local optimizer of FedSAM cannot directly affect the global model, even though FedSAM achieves an efficient convergence rate theoretically; this is reflected in the term including $\sigma_g^2$ in the convergence rate. Note that $\Delta^r$ aggregates the global model information of the participating clients, and reusing this information should help guide the local training of the clients participating in the next communication round, which is similar to momentum FL (Wang et al., 2019; Reddi et al., 2020; Karimireddy et al., 2021; Khanduri et al., 2021; Xu et al., 2021). Inspired by this motivation, we now provide our second algorithm, termed MoFedSAM, which aims to smooth and generalize the global model directly. The training procedure of the $k$-th local training epoch in round $r$ is formulated as follows:

$$\tilde{w}^r_{i,k} = w^r_{i,k} + \rho \frac{g^r_{i,k}}{\|g^r_{i,k}\|}, \qquad v^r_{i,k} = \beta\, \tilde{g}^r_{i,k} + (1 - \beta) \Delta^r, \qquad w^r_{i,k+1} = w^r_{i,k} - \eta_l\, v^r_{i,k}, \tag{6}$$

where $\beta$ is the momentum rate. If $\beta = 1$, MoFedSAM is equivalent to FedSAM. From (6), we can see that the global model information $\Delta^r$ contributes directly to each local training epoch, since $w^r_{i,k+1}$ depends on both $\tilde{g}^r_{i,k}$ and $\Delta^r$. Therefore, MoFedSAM smooths the local and global models at the same time. In particular, even if only a subset of clients is sampled in each communication round, the gradient information of previous local model updates is still contained in $\Delta^r$. Therefore, MoFedSAM also works well for partial client participation FL. More specifically, the global model information term $\Delta^r$ is considered as an approximation to the gradient of the global model $f(\tilde{w})$, i.e., $\Delta^r \approx \nabla f(\tilde{w}^r)$.
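One local step of (6) can be sketched as follows. As before, this is a hedged illustration and not the authors' implementation: `grad_fn` is a hypothetical gradient oracle, `global_delta` stands for $\Delta^r$ received from the server, and skipping the perturbation when the gradient is exactly zero is an assumption of this sketch.

```python
import math

def mofedsam_local_step(w, grad_fn, global_delta, rho, beta, lr):
    # SAM ascent step: move to the perturbed point w_tilde.
    g = grad_fn(w)
    norm = math.sqrt(sum(gj * gj for gj in g))
    scale = rho / norm if norm > 0.0 else 0.0  # assumption: no perturbation if g == 0
    w_tilde = [wj + scale * gj for wj, gj in zip(w, g)]
    g_tilde = grad_fn(w_tilde)
    # Momentum mix of the local SAM gradient and the global direction Delta^r.
    v = [beta * gt + (1.0 - beta) * dj for gt, dj in zip(g_tilde, global_delta)]
    # Descent step from the original point w.
    return [wj - lr * vj for wj, vj in zip(w, v)]

if __name__ == "__main__":
    quad_grad = lambda w: list(w)  # gradient of the toy loss 0.5 * ||w||^2
    print(mofedsam_local_step([3.0, 4.0], quad_grad, [1.0, 1.0], rho=0.5, beta=0.1, lr=0.1))
    # With beta = 1 the global direction is ignored and the step reduces to FedSAM.
    print(mofedsam_local_step([3.0, 4.0], quad_grad, [1.0, 1.0], rho=0.5, beta=1.0, lr=0.1))
```

With a small $\beta$ such as 0.1, most of the step direction comes from $\Delta^r$, so every client moves largely along the shared global direction and only partly along its own perturbed gradient.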
One advantage is that MoFedSAM adds a correction term to the local gradient direction, and this term asymptotically aligns with the difference between the global and local gradients. It is worth noting that we use (G, B)-BGD in Assumption 2 to prove the convergence rate, which is tighter than (G, 0)-BGD in (Xu et al., 2021).

4.2. Convergence Analysis of MoFedSAM

The next theorem gives the convergence rate of the MoFedSAM algorithm; the detailed proof is in Appendix D.

Theorem 4.1 Let the learning rates be chosen as $\eta_l = \mathcal{O}\big(\frac{1}{\sqrt{R} \beta K L}\big)$, $\eta_g = \mathcal{O}([\ldots])$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \mathcal{O}\big(\frac{1}{\sqrt{R}}\big)$. Under Assumptions 1-3, any momentum parameter $\beta \le \frac{1}{2}$ and the full client participation strategy, the sequence $\{\tilde{w}^r\}$ generated by MoFedSAM in Algorithm 2 satisfies:

$$\mathcal{O}\left( \frac{\beta F L}{\sqrt{RKN}} + \frac{\beta \sigma_g^2}{R L^2} + \cdots \right),$$

where $F = f(\tilde{w}^0) - f(\tilde{w}^*)$ and $f(\tilde{w}^*) = \min_{\tilde{w}} f(\tilde{w})$. For the partial client participation strategy, if we choose the learning rates $\eta_l = \mathcal{O}\big(\frac{1}{\sqrt{R} \beta K L}\big)$ and $\eta_g = \mathcal{O}([\ldots])$, the following convergence holds: […]

Figure 1. Testing accuracy on different datasets: (a) EMNIST dataset, (b) CIFAR-10 dataset, (c) CIFAR-100 dataset. Each panel plots test accuracy against communication round for FedAvg, SCAFFOLD, FedRobust, FedCM, MimeLite, FedSAM, FedGSAM and MoFedSAM.

Table 1. Average (standard deviation) training accuracy and testing accuracy. Communication round to achieve the targeted testing accuracy: EMNIST 80%, CIFAR-10 80% and CIFAR-100 50%.
| Algorithm | EMNIST Train | EMNIST Validation | EMNIST Round | CIFAR-10 Train | CIFAR-10 Validation | CIFAR-10 Round | CIFAR-100 Train | CIFAR-100 Validation | CIFAR-100 Round |
|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 95.07 (0.94) | 84.38 (4.03) | 43 | 93.15 (1.44) | 81.87 (5.09) | 307 | 79.57 (1.84) | 53.57 (5.40) | 302 |
| SCAFFOLD | 93.85 (1.31) | 84.09 (4.56) | 69 | 91.76 (1.89) | 80.61 (5.64) | 546 | 78.49 (2.02) | 51.49 (5.87) | 551 |
| FedRobust | 93.17 (0.62) | 83.70 (3.37) | 91 | 90.82 (1.27) | 79.63 (4.21) | 847 | 76.80 (1.70) | 49.06 (4.75) | 893 |
| FedCM | 96.16 (1.14) | 84.85 (4.11) | 28 | 95.61 (1.50) | 83.30 (4.77) | 136 | 82.13 (1.96) | 55.50 (5.04) | 182 |
| MimeLite | 96.22 (1.16) | 84.88 (4.22) | 25 | 95.73 (1.56) | 83.18 (4.65) | 152 | 82.46 (2.00) | 55.73 (5.11) | 189 |
| FedSAM | 95.73 (0.49) | 84.75 (3.04) | 38 | 94.20 (1.08) | 83.06 (3.87) | 269 | 81.04 (1.59) | 54.69 (4.36) | 245 |
| MoFedSAM | 96.42 (0.42) | 85.07 (2.95) | 24 | 95.67 (1.16) | 83.92 (3.65) | 124 | 82.62 (1.53) | 56.60 (4.42) | 124 |

Remark 4.2 When $R$ is sufficiently large compared to $K$, the convergence rates of the MoFedSAM algorithm under the full and partial client participation strategies are $\mathcal{O}\big(\frac{\beta L}{\sqrt{RKN}} + \frac{\beta}{R L^2}\big)$ and $\mathcal{O}\big(\frac{\beta \sqrt{K}}{\sqrt{RS}}\big)$, respectively. The momentum parameter $\beta$ is usually small, e.g., 0.1 (Karimireddy et al., 2021; Xu et al., 2021), and its effect on convergence is significant because the number of local epochs is usually set below 20 (Reddi et al., 2020; Yang et al., 2021; Acar et al., 2021). Therefore, our convergence results achieve a speedup compared with FedSAM. We also note that the convergence terms related to the local training are $\mathcal{O}\big(\frac{L}{R^2 \beta} + \frac{\beta L^2}{R^2}\big)$ and $\mathcal{O}\big(\frac{L^2}{S}\big)$, where the second part comes from sharpness and is negligible. Compared with the convergence analysis of FedCM (Xu et al., 2021), i.e., $\mathcal{O}\big(\cdots + \frac{L}{\beta^{2/3} R^{2/3}}\big)$, MoFedSAM achieves a speedup on both the dominant part and the local training part. The analysis indicates the benefit of bridging the sharpness between local and global models.

5. Experiments

We evaluate our proposed algorithms on extensive and representative datasets and learning models.
To accomplish this, we conduct experiments on three learning models across three datasets, comparing against five FL benchmarks under varying parameters.

5.1. Experimental Setup

Benchmarks and hyper-parameters. We consider five FL benchmarks: without momentum, FedAvg (McMahan et al., 2017), SCAFFOLD (Karimireddy et al., 2020) and FedRobust (Reisizadeh et al., 2020); with momentum, MimeLite (Karimireddy et al., 2021) and FedCM (Xu et al., 2021). The learning rates are individually tuned, and the other optimizer hyper-parameters are set to $\rho = 0.5$ for SAM and $\beta = 0.1$ for momentum, unless explicitly stated otherwise. We refer to Appendices E-F for the detailed experimental setup and additional ablation studies.

Datasets and models. We use three image datasets: EMNIST (Cohen et al., 2017), CIFAR-10, and CIFAR-100 (Krizhevsky et al., 2009). Our cross-device FL setting includes 100 clients in total with a participation rate of 20%. In each communication round, each client is sampled independently of the others with probability 0.2. We simulate the data heterogeneity by sampling the label ratios from a Dirichlet distribution with parameter 0.6 (Acar et al., 2021), and the number of local epochs is set to $K = 10$ by default. We adopt two learning models: (i) a CNN on EMNIST with batch size 32 and (ii) ResNet-18 (He et al., 2016) on CIFAR-10 and CIFAR-100 with batch size 128.

Table 2. Impact of the heterogeneity on CIFAR-10 dataset (IID, Dirichlet 0.6 and Dirichlet 0.3).
| Algorithm | IID Train | IID Validation | IID Round | Dirichlet 0.6 Train | Dirichlet 0.6 Validation | Dirichlet 0.6 Round | Dirichlet 0.3 Train | Dirichlet 0.3 Validation | Dirichlet 0.3 Round |
|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 94.95 (1.01) | 85.97 (3.53) | 238 | 93.15 (1.44) | 81.87 (5.09) | 307 | 91.89 (1.63) | 77.39 (5.62) | - |
| SCAFFOLD | 93.04 (1.13) | 83.82 (3.72) | 290 | 91.76 (1.89) | 80.61 (5.64) | 546 | 90.02 (2.08) | 75.67 (5.93) | - |
| FedRobust | 91.63 (0.91) | 82.44 (3.15) | 361 | 90.82 (1.27) | 79.63 (4.21) | 847 | 89.72 (1.42) | 73.11 (5.11) | - |
| FedCM | 97.02 (1.10) | 88.14 (3.33) | 87 | 95.61 (1.50) | 83.30 (4.77) | 136 | 93.88 (1.67) | 81.34 (5.50) | 583 |
| MimeLite | 97.16 (1.08) | 88.53 (3.53) | 82 | 95.73 (1.56) | 83.18 (4.65) | 152 | 93.97 (1.72) | 81.83 (5.53) | 548 |
| FedSAM | 95.42 (0.81) | 87.36 (2.85) | 205 | 94.20 (1.08) | 83.06 (3.87) | 269 | 92.90 (1.26) | 79.82 (4.98) | 816 |
| MoFedSAM | 97.22 (0.88) | 88.96 (2.94) | 75 | 95.67 (1.16) | 83.92 (3.65) | 124 | 94.12 (1.31) | 83.35 (5.06) | 490 |

[Figure 2 panels: (a) FedAvg, (b) FedSAM, (c) MoFedSAM.]

Figure 2. Loss surface of FedAvg, FedSAM and MoFedSAM algorithm with ResNet-18 on CIFAR-10 dataset.

[Figure 3 panels: (a) Impact of S (test accuracy versus number of participated clients S), (b) Impact of K (test accuracy versus local epochs K), (c) Impact of ρ (test accuracy), (d) Impact of β (validation accuracy). Panels (a) and (b) compare FedAvg, SCAFFOLD, FedRobust, FedCM, MimeLite, FedSAM, FedGSAM and MoFedSAM; panel (c) compares FedSAM, FedGSAM and MoFedSAM; panel (d) compares FedCM, MimeLite and MoFedSAM.]

Figure 3. Impacts of different parameters on CIFAR-10 dataset.

5.2. Performance Evaluation

(1) Performance with compared benchmarks.
We first compare our proposed algorithms with the benchmarks on different datasets in Figure 1 and Table 1. These results show that, among the FL algorithms without momentum, the performance ranks as FedSAM > FedAvg > SCAFFOLD > FedRobust, and among the momentum FL algorithms as MoFedSAM > MimeLite > FedCM. Our proposed algorithms outperform the other benchmarks in both accuracy and convergence. We do not compare the FL algorithms without momentum against momentum FL, since momentum FL is required to transmit more information than plain FL, e.g., $\Delta^{r+1}$; this is the reason why momentum FL outperforms the FL benchmarks without momentum. More specifically, to present the generalization performance, we show the deviation, i.e., the best and worst local accuracy. In addition, the performance improvement on the CIFAR-100 dataset is more pronounced than on the others, since SAM optimizers perform more efficiently on more complicated datasets.

(2) Impact of Non-IID levels. In Tables 2, 3, 4 and 5, we can see that our proposed algorithms outperform the benchmarks of the same FL category across different client distribution levels. We consider heterogeneous client distributions by varying the balanced/unbalanced setting, the number of clients and the participation level on various datasets. Client distributions become more non-IID as we go from IID to Dirichlet 0.6 to Dirichlet 0.3 splits, which makes global optimization more difficult. For example, as the non-IID level increases, MoFedSAM achieves 0.43%, 1.24% and 1.52% higher test accuracy and saves 7, 40 and 59 communication rounds compared with MimeLite on the CIFAR-10 dataset. In summary, although almost all the algorithms perform well enough on the training dataset, the testing accuracy usually degrades significantly, especially in terms of the deviation across local clients.
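The Dirichlet splits and the independent client sampling used in these experiments can be sketched as follows. This is a minimal illustration, not the authors' code: it uses the common per-class variant (for every class, a Dirichlet draw decides how that class is shared among clients), which may differ in detail from the procedure of Acar et al. (2021). The function names `dirichlet_label_split` and `sample_clients` are hypothetical.

```python
import numpy as np


def dirichlet_label_split(labels, num_clients=100, alpha=0.6, seed=0):
    """Partition sample indices among clients with Dirichlet(alpha) label skew.

    Assumption: per-class splitting; smaller alpha gives more non-IID clients.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        # Shuffle the indices of class c.
        idx = rng.permutation(np.flatnonzero(labels == c))
        # Fraction of class c that each client receives.
        shares = rng.dirichlet(alpha * np.ones(num_clients))
        # Convert cumulative fractions into split positions.
        cuts = (np.cumsum(shares)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices


def sample_clients(num_clients=100, prob=0.2, rng=None):
    """Sample each client independently with probability `prob`."""
    rng = np.random.default_rng() if rng is None else rng
    return np.flatnonzero(rng.random(num_clients) < prob)
```

Calling `dirichlet_label_split(labels, alpha=0.3)` instead of `alpha=0.6` concentrates each class on fewer clients, which corresponds to the stronger heterogeneity level in Table 2.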
In Table 1, we can see that our proposed algorithms significantly decrease the deviation of local clients, which indicates that the global model they produce generalizes well.

(3) Loss surface visualization. To visualize the sharpness of the minima obtained by FedAvg, FedSAM and MoFedSAM, we show the loss surfaces of models trained with ResNet-18 on the CIFAR-10 dataset. We display the loss surfaces in Figure 2, following the plotting algorithm in (Li et al., 2018a). The x- and y-axes are two randomly sampled orthogonal Gaussian perturbations. We can clearly see that both FedSAM and MoFedSAM reach significantly flatter minima than FedAvg, which indicates that our proposed algorithms generalize better.

(4) Impact of other parameters. Here, we show the impact of different parameters, e.g., the number of participated clients S, the number of epochs K, the perturbation radius ρ for our proposed algorithms and the momentum value β, in Figures 3, 7, 8 and 9. Our proposed algorithms outperform the benchmarks of the same FL category, i.e., with or without momentum. Similar to existing FL studies, increasing the batch size and the number of participated clients can improve the learning performance. Increasing the number of epochs K does not guarantee substantially better accuracy; however, all the benchmarks perform worst when K = 1. The best ρ differs across datasets: the best performance is obtained with ρ = 0.2 for EMNIST, 0.5 for CIFAR-10 and 0.6 for CIFAR-100.

6. Conclusion

In this paper, we study the distribution shift caused by data heterogeneity in cross-device FL from a simple yet unique perspective, namely improving the generality of the global model. To this end, we propose two algorithms, FedSAM and MoFedSAM, which do not incur more communication costs than existing FL studies. Our convergence analysis for general non-convex FL settings shows that these algorithms achieve competitive performance.
Furthermore, we also provide the generalization bound of the FedSAM algorithm. The extensive experiments strongly support that our proposed algorithms significantly decrease the performance deviation among all local clients.

Acknowledgements

The work at the University of South Florida was supported in part by NSF under Grant CNS-2044516. The work at the Mississippi State University was supported in part by NSF under Grant IIS-2047570.

References

Acar, D. A. E., Zhao, Y., Matas, R., Mattina, M., Whatmough, P., and Saligrama, V. Federated learning based on dynamic regularization. In International Conference on Learning Representations, 2021.

Bartlett, P. L., Foster, D. J., and Telgarsky, M. J. Spectrally-normalized margin bounds for neural networks. Advances in Neural Information Processing Systems, 30:6240–6249, 2017.

Caldarola, D., Caputo, B., and Ciccone, M. Improving generalization in federated learning by seeking flat minima. arXiv preprint arXiv:2203.11834, 2022.

Chatterji, N., Neyshabur, B., and Sedghi, H. The intriguing role of module criticality in the generalization of deep networks. In International Conference on Learning Representations, 2019.

Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-sgd: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124018, 2019.

Cohen, G., Afshar, S., Tapson, J., and Van Schaik, A. Emnist: Extending mnist to handwritten letters. In 2017 International Joint Conference on Neural Networks (IJCNN), pp. 2921–2926. IEEE, 2017.

Deng, Y., Kamani, M. M., and Mahdavi, M. Adaptive personalized federated learning. arXiv preprint arXiv:2003.13461, 2020.

Dieuleveut, A., Fort, G., Moulines, E., and Robin, G. Federated-em with heterogeneity mitigation and variance reduction. Advances in Neural Information Processing Systems, 34, 2021.

Du, J., Yan, H., Feng, J., Zhou, J. T., Zhen, L., Goh, R. S. M., and Tan, V. Y.
Efficient sharpness-aware minimization for improved training of neural networks. arXiv preprint arXiv:2110.03141, 2021a.

Du, W., Xu, D., Wu, X., and Tong, H. Fairness-aware agnostic federated learning. In Proceedings of the 2021 SIAM International Conference on Data Mining (SDM), pp. 181–189. SIAM, 2021b.

Fallah, A., Mokhtari, A., and Ozdaglar, A. Personalized federated learning with theoretical guarantees: A model-agnostic meta-learning approach. Advances in Neural Information Processing Systems, 33:3557–3568, 2020.

Farnia, F., Zhang, J., and Tse, D. Generalizable adversarial training via spectral normalization. In International Conference on Learning Representations, 2018.

Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Gong, X., Sharma, A., Karanam, S., Wu, Z., Chen, T., Doermann, D., and Innanje, A. Ensemble attention distillation for privacy-preserving federated learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 15076–15086, 2021.

Goodfellow, I. J., Shlens, J., and Szegedy, C. Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572, 2014.

Goyal, P., Dollár, P., Girshick, R., Noordhuis, P., Wesolowski, L., Kyrola, A., Tulloch, A., Jia, Y., and He, K. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.

Hard, A., Rao, K., Mathews, R., Ramaswamy, S., Beaufays, F., Augenstein, S., Eichner, H., Kiddon, C., and Ramage, D. Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.

Kairouz, P., McMahan, H.
B., Avent, B., Bellet, A., Bennis, M., Bhagoji, A. N., Bonawitz, K., Charles, Z., Cormode, G., Cummings, R., et al. Advances and open problems in federated learning. arXiv preprint arXiv:1912.04977, 2019.

Karimireddy, S. P., Kale, S., Mohri, M., Reddi, S., Stich, S., and Suresh, A. T. Scaffold: Stochastic controlled averaging for federated learning. In International Conference on Machine Learning, pp. 5132–5143. PMLR, 2020.

Karimireddy, S. P., Jaggi, M., Kale, S., Mohri, M., Reddi, S. J., Stich, S. U., and Suresh, A. T. Breaking the centralized barrier for cross-device federated learning. In Thirty-Fifth Conference on Neural Information Processing Systems, 2021.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khanduri, P., Sharma, P., Yang, H., Hong, M., Liu, J., Rajawat, K., and Varshney, P. STEM: A stochastic two-sided momentum algorithm achieving near-optimal sample and communication complexities for federated learning. In Advances in Neural Information Processing Systems, 2021.

Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009.

Kwon, J., Kim, J., Park, H., and Choi, I. K. Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In Proceedings of the 38th International Conference on Machine Learning, pp. 5905–5914. PMLR, 2021.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in Neural Information Processing Systems, 30, 2017.

Li, H., Xu, Z., Taylor, G., Studer, C., and Goldstein, T. Visualizing the loss landscape of neural nets. Advances in Neural Information Processing Systems, 31, 2018a.

Li, T., Sahu, A. K., Zaheer, M., Sanjabi, M., Talwalkar, A., and Smith, V. Federated optimization in heterogeneous networks.
arXiv preprint arXiv:1812.06127, 2018b.

Li, T., Sanjabi, M., Beirami, A., and Smith, V. Fair resource allocation in federated learning. In International Conference on Learning Representations, 2020a.

Li, T., Hu, S., Beirami, A., and Smith, V. Ditto: Fair and robust federated learning through personalization. In International Conference on Machine Learning, pp. 6357–6368. PMLR, 2021.

Li, X., Huang, K., Yang, W., Wang, S., and Zhang, Z. On the convergence of fedavg on non-iid data. In International Conference on Learning Representations, 2020b.

Lian, X., Zhang, C., Zhang, H., Hsieh, C.-J., Zhang, W., and Liu, J. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. arXiv preprint arXiv:1705.09056, 2017.

Lin, T., Kong, L., Stich, S. U., and Jaggi, M. Ensemble distillation for robust model fusion in federated learning. In Advances in Neural Information Processing Systems, 2020.

McMahan, B., Moore, E., Ramage, D., Hampson, S., and y Arcas, B. A. Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics, pp. 1273–1282. PMLR, 2017.

Mendieta, M., Yang, T., Wang, P., Lee, M., Ding, Z., and Chen, C. Local learning matters: Rethinking data heterogeneity in federated learning. arXiv preprint arXiv:2111.14213, 2021.

Mohri, M., Sivek, G., and Suresh, A. T. Agnostic federated learning. In International Conference on Machine Learning, pp. 4615–4625. PMLR, 2019.

Nesterov, Y. and Spokoiny, V. Random gradient-free minimization of convex functions. Foundations of Computational Mathematics, 17(2):527–566, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A pac-bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018.
Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al. Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems, 32:8026–8037, 2019.

Reddi, S. J., Charles, Z., Zaheer, M., Garrett, Z., Rush, K., Konečný, J., Kumar, S., and McMahan, H. B. Adaptive federated optimization. In International Conference on Learning Representations, 2020.

Reisizadeh, A., Farnia, F., Pedarsani, R., and Jadbabaie, A. Robust federated learning: The case of affine distribution shifts. In NeurIPS, 2020.

Shafahi, A., Najibi, M., Xu, Z., Dickerson, J., Davis, L. S., and Goldstein, T. Universal adversarial training. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 5636–5643, 2020.

Singhal, K., Sidahmed, H., Garrett, Z., Wu, S., Rush, J. K., and Prakash, S. Federated reconstruction: Partially local federated learning. In Advances in Neural Information Processing Systems, 2021.

T Dinh, C., Tran, N., and Nguyen, T. D. Personalized federated learning with moreau envelopes. Advances in Neural Information Processing Systems, 33, 2020.

Tropp, J. A. User-friendly tail bounds for sums of random matrices. Foundations of computational mathematics, 12(4):389–434, 2012.

Wang, J., Tantia, V., Ballas, N., and Rabbat, M. Slowmo: Improving communication-efficient distributed sgd with slow momentum. In International Conference on Learning Representations, 2019.

Woodworth, B., Gunasekar, S., Lee, J. D., Moroshko, E., Savarese, P., Golan, I., Soudry, D., and Srebro, N. Kernel and rich regimes in overparametrized models. In Conference on Learning Theory, pp. 3635–3673. PMLR, 2020.

Xu, J., Wang, S., Wang, L., and Yao, A. C.-C. Fedcm: Federated learning with client-level momentum. arXiv preprint arXiv:2106.10874, 2021.

Yang, H., Fang, M., and Liu, J. Achieving linear speedup with partial worker participation in non-IID federated learning.
In International Conference on Learning Representations, 2021.

Yoon, T., Shin, S., Hwang, S. J., and Yang, E. Fedmix: Approximation of mixup under mean augmented federated learning. In International Conference on Learning Representations, 2021.

Yuan, H., Morningstar, W., Ning, L., and Singhal, K. What do we mean by generalization in federated learning? arXiv preprint arXiv:2110.14216, 2021.

Zhu, Z., Hong, J., and Zhou, J. Data-free knowledge distillation for heterogeneous federated learning. arXiv preprint arXiv:2105.10056, 2021.

Zhuang, J., Gong, B., Yuan, L., Cui, Y., Adam, H., Dvornek, N. C., Tatikonda, S., Duncan, J. S., and Liu, T. Surrogate gap minimization improves sharpness-aware training. In International Conference on Learning Representations, 2022.

A. Preliminary Lemmas

To analyze the convergence rate of all proposed algorithms, we first state some preliminary lemmas:

Lemma A.1 (Relaxed triangle inequality). Let $\{v_1, \ldots, v_\tau\}$ be $\tau$ vectors in $\mathbb{R}^d$. Then, the following are true: (1) $\|v_i + v_j\|^2 \le (1 + a)\|v_i\|^2 + (1 + \frac{1}{a})\|v_j\|^2$ for any $a > 0$, and (2) $\|\sum_{i=1}^{\tau} v_i\|^2 \le \tau \sum_{i=1}^{\tau} \|v_i\|^2$.

Lemma A.2 For random variables $x_1, \ldots, x_n$, we have

$$\mathbb{E}[\|x_1 + \cdots + x_n\|^2] \le n\,\mathbb{E}[\|x_1\|^2 + \cdots + \|x_n\|^2].$$

Lemma A.3 For independent, mean 0 random variables $x_1, \ldots, x_n$, we have

$$\mathbb{E}[\|x_1 + \cdots + x_n\|^2] = \mathbb{E}[\|x_1\|^2 + \cdots + \|x_n\|^2].$$

Lemma A.4 (Separating mean and variance for SAM). The stochastic gradient $\nabla F_i(w, \xi_i)$ computed by the $i$-th client at model parameter $w$ using minibatch $\xi_i$ is an unbiased estimator of $\nabla F_i(w)$ with variance bounded by $\sigma^2$. The gradient of SAM is formulated by k=0 E[ Fi(wi r,k) 2] + KL2ρ2 k=0 E[ Fi(wi r,k) 2] + KL2ρ2σ2 l .

Proof. For the first inequality, we can bound as follows k=0 (gr i,k F(wr i,k)) k=0 E[ gr i,k 2] + L2 K 1 X i [N] (wr i,k + δr i,k( wr i,k; ξr i,k) wr i,k δr i,k( wr i,k)) k=0 E[ gr i,k 2] + KL2ρ2σ2 l N .
where (a) is from Assumption 1 and (b) is from Assumption 3 and Lemma A.3. The second inequality can be obtained similarly, and hence we omit its proof here.

Lemma A.5 (Bounded global variance of $\|\nabla F_i(w + \delta_i) - \nabla F(w + \delta)\|^2$). As an immediate implication of Assumptions 1 and 2, the variance of the local and global gradients with perturbation can be bounded as follows:

$$\|\nabla F_i(w + \delta_i) - \nabla F(w + \delta)\|^2 \le 3\sigma_g^2 + 6L^2\rho^2.$$

Proof.

$$\begin{aligned}
\|\nabla f_i(\tilde{w}) - \nabla f(\tilde{w})\|^2 &= \|\nabla F_i(w + \delta_i) - \nabla F(w + \delta)\|^2 \\
&= \|\nabla F_i(w + \delta_i) - \nabla F_i(w) + \nabla F_i(w) - \nabla F(w) + \nabla F(w) - \nabla F(w + \delta)\|^2 \\
&\overset{(a)}{\le} 3\|\nabla F_i(w + \delta_i) - \nabla F_i(w)\|^2 + 3\|\nabla F_i(w) - \nabla F(w)\|^2 + 3\|\nabla F(w) - \nabla F(w + \delta)\|^2 \\
&\overset{(b)}{\le} 3\sigma_g^2 + 6L^2\rho^2,
\end{aligned}$$

where (a) is from Lemma A.2 and (b) is from Assumptions 1 and 2 and the fact that the perturbation is bounded by $\rho$.

Algorithm 3 FedSAM: Federated Sharpness Aware Minimization

1: Initialization: $w^0$, $\rho_0$, $\gamma$, the number of local updates $K$, batch size $b$, local learning rate $\eta_l$ and global learning rate $\eta_g$.

2: for each round $r = 0, \ldots, R - 1$ do

3: Sample subset $\mathcal{S}^r \subseteq [N]$ of clients.

4: Communicate $w^r$ to all clients $i \in \mathcal{S}^r$.

5: for each client $i \in \mathcal{S}^r$ in parallel do

6: Initialize local model $w^r_{i,0} \leftarrow w^r$.

7: for $k = 1, \ldots, K$ do

8: Compute $g^r_{i,k-1}$ by taking an estimation $\nabla F_i(w^r_{i,k-1}, \xi^r_i)$ of $\nabla F_i(w^r_{i,k-1})$.

9: $\tilde{w}^r_{i,k-1} = w^r_{i,k-1} + \rho \frac{g^r_{i,k-1}}{\|g^r_{i,k-1}\|}$.

10: Compute $\tilde{g}^r_{i,k-1}$ by taking an estimation $\nabla f_i(\tilde{w}^r_{i,k-1}, \xi^r_i)$ of $\nabla f_i(\tilde{w}^r_{i,k-1})$.

11: $w^r_{i,k} = w^r_{i,k-1} - \eta_l \tilde{g}^r_{i,k-1}$.

12: end for

13: $\Delta^r_i = w^r_{i,K} - w^r$.

14: end for

15: $\Delta^{r+1} = \frac{1}{S} \sum_{i \in \mathcal{S}^r} \Delta^r_i$.

16: $w^{r+1} = w^r + \eta_g \Delta^r$.

17: end for

B. Convergence Analysis for FedSAM

B.1. Description of FedSAM Algorithm and Key Lemmas

We outline the FedSAM algorithm in Algorithm 3. In round $r$, we sample $\mathcal{S}^r \subseteq [N]$ clients with $|\mathcal{S}^r| = S$ and then perform the following updates. Starting from the shared global parameters $w^r_{i,0} = w^r$, we update the local parameters for $k \in [K]$:

$$\tilde{w}^r_{i,k-1} = w^r_{i,k-1} + \rho \frac{g^r_{i,k-1}}{\|g^r_{i,k-1}\|}, \qquad w^r_{i,k} = w^r_{i,k-1} - \eta_l \tilde{g}^r_{i,k-1}.$$

After $K$ local epochs, we obtain the following:

$$\Delta^r_i = w^r_{i,K} - w^r.$$
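One communication round of Algorithm 3 can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions, not the authors' implementation: each client is represented by a hypothetical callable `grad(w)` that returns its gradient estimate, the function name `fedsam_round` is made up, the caller is assumed to pass only the sampled clients, and a small constant is added to the norm to avoid division by zero.

```python
import numpy as np


def fedsam_round(w, client_grads, rho=0.5, eta_l=0.05, eta_g=1.0, K=10):
    """Run one FedSAM round and return the updated global parameters.

    client_grads: list of callables grad(w) -> gradient estimate
                  (hypothetical interface standing in for the sampled clients).
    """
    deltas = []
    for grad in client_grads:
        w_local = w.copy()                      # initialize local model from w^r
        for _ in range(K):
            g = grad(w_local)                   # gradient at the current local model
            # SAM ascent step: perturb along the normalized gradient.
            w_tilde = w_local + rho * g / (np.linalg.norm(g) + 1e-12)
            g_tilde = grad(w_tilde)             # gradient at the perturbed model
            w_local = w_local - eta_l * g_tilde  # descent step with the SAM gradient
        deltas.append(w_local - w)              # local change of this client
    delta = np.mean(deltas, axis=0)             # average over the sampled clients
    return w + eta_g * delta                    # global update
```

On a toy problem with quadratic client losses $F_i(w) = \frac{1}{2}\|w - c_i\|^2$, repeated calls drive $w$ to within roughly $\rho$ of the mean of the $c_i$, because the SAM perturbation adds a bias of norm at most $\rho$ to every local gradient.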
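With the local change $\Delta^r_i$ defined in (7), the server computes the new global parameters using only the updates from the clients $i \in \mathcal{S}^r$ and a global step size $\eta_g$:

$$w^{r+1} = w^r + \eta_g \Delta^r.$$

Lemma B.1 (Bounded $\mathcal{E}_\delta$ of FedSAM). Suppose our functions satisfy Assumptions 1-2. Then, the updates of FedSAM for any learning rate satisfying $\eta_l \le \frac{1}{4KL}$ have the drift due to $\delta_{i,k} - \delta$:

$$\mathcal{E}_\delta = \frac{1}{N} \sum_i \mathbb{E}[\|\delta_{i,k} - \delta\|^2] \le 2K^2L^2\eta_l^2\rho^2.$$

Proof. Recall the definitions of $\delta$ and $\delta_{i,k}$ as follows:

$$\delta = \rho \frac{\nabla F(w)}{\|\nabla F(w)\|}, \qquad \delta_{i,k} = \rho \frac{\nabla F_i(w_{i,k}, \xi_i)}{\|\nabla F_i(w_{i,k}, \xi_i)\|}.$$

If the local learning rate $\eta_l$ is small, the gradient of one epoch $\nabla F_i(w_{i,k}, \xi_i)$ is small. Based on the first-order Hessian approximation, the expected gradient is

$$\nabla F_i(w_{i,k}) = \nabla F_i(w_{i,k-1} + \eta_l g_{i,k-1}) = \nabla F_i(w_{i,k-1}) + H \eta_l g_{i,k-1} + O(\|\eta_l g_{i,k-1}\|^2),$$

where $H$ is the Hessian at $w_{i,k-1}$. Therefore, we have

$$\mathbb{E}[\|\delta_{i,k} - \delta\|^2] = \rho^2\, \mathbb{E}\left\| \frac{\nabla F_i(w_{i,k})}{\|\nabla F_i(w_{i,k})\|} - \frac{\nabla F_i(w)}{\|\nabla F_i(w)\|} \right\|^2 \le \rho^2 \phi_{i,k}, \tag{8}$$

where $\phi_{i,k}$ is the square of the angle between the unit vectors in the directions of $\nabla F_i(w_{i,k})$ and $\nabla F_i(w_{i,0})$. The inequality follows from the facts that (1) $\left\|\frac{\nabla F_i(\cdot)}{\|\nabla F_i(\cdot)\|}\right\|^2 < 1$, and hence we replace $\delta$ with a unit vector in the corresponding direction multiplied by $\rho^2$ and obtain the upper bound, and (2) the norm of the difference of unit vectors can be upper bounded by the square of the arc length on a unit circle. When the learning rate $\eta_l$ and the local model update of one epoch $\nabla F_i(w_{i,k})$ are small, $\phi_{i,k}$ is also small. Based on the first-order Taylor series, i.e., $\tan x = x + O(x^2)$, we have

$$\begin{aligned}
\tan \phi_{i,k} &= \frac{\|\nabla F_i(w_{i,k}) - \nabla F_i(w_{i,0})\|^2}{\|\nabla F_i(w_{i,0})\|^2} + O(\phi_{i,k}^2) \\
&= \frac{\|\nabla F_i(w_{i,k-1}) - H \eta_l g_{i,k-1} - O(\|\eta_l g_{i,k-1}\|^2) - \nabla F_i(w_{i,0})\|^2}{\|\nabla F_i(w_{i,0})\|^2} + O(\phi_{i,k}^2) \\
&\overset{(a)}{\le} \left(1 + \frac{1}{K-1}\right) \frac{\|\nabla F_i(w_{i,k-1}) - \nabla F_i(w_{i,0})\|^2}{\|\nabla F_i(w_{i,0})\|^2} + K \frac{\|H \eta_l g_{i,k-1} + O(\|\eta_l g_{i,k-1}\|^2)\|^2}{\|\nabla F_i(w_{i,0})\|^2} + O(\phi_{i,k}^2) \\
&\overset{(b)}{\le} \left(1 + \frac{1}{K-1}\right) \frac{\|\nabla F_i(w_{i,k-1}) - \nabla F_i(w_{i,0})\|^2}{\|\nabla F_i(w_{i,0})\|^2} + KL^2\eta_l^2,
\end{aligned}$$

where (a) is from Lemma A.1 with $a = \frac{1}{K-1}$ and (b) holds because the maximum eigenvalue of $H$ is bounded by $L$, since $F$ is $L$-smooth. Unrolling the recursion above, we have

$$\frac{\|\nabla F_i(w_{i,k}) - \nabla F_i(w_{i,0})\|^2}{\|\nabla F_i(w_{i,0})\|^2} + O(\phi_{i,k}^2) \le \sum_{\tau=0}^{k-1} \left(1 + \frac{1}{K-1}\right)^{\tau} KL^2\eta_l^2 \le 2K^2L^2\eta_l^2.$$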
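Plugging the bound (9) above into (8), we have

$$\frac{1}{N} \sum_{i \in [N]} \mathbb{E}[\|\delta_{i,k} - \delta\|^2] \le 2K^2L^2\eta_l^2\rho^2.$$

This completes the proof.

Lemma B.2 (Bounded $\mathcal{E}_w$ of FedSAM). Suppose our functions satisfy Assumptions 1-2. Then, the updates of FedSAM for any learning rate satisfying $\eta_l \le \frac{1}{10KL}$ have the drift due to $w_{i,k} - w$:

$$\mathcal{E}_w = \frac{1}{N} \sum_i \mathbb{E}[\|w_{i,k} - w\|^2] \le 5K\eta_l^2 \left(2L^2\rho^2\sigma_l^2 + 6K(3\sigma_g^2 + 6L^2\rho^2) + 6K\|\nabla f(\tilde{w})\|^2\right) + 24K^3\eta_l^4L^4\rho^2.$$

Proof. Recall that the local update on client $i$ is $w_{i,k} = w_{i,k-1} - \eta_l \tilde{g}_{i,k-1}$. Then,

$$\begin{aligned}
\mathbb{E}\|w_{i,k} - w\|^2 &= \mathbb{E}\|w_{i,k-1} - w - \eta_l \tilde{g}_{i,k-1}\|^2 \\
&\overset{(a)}{\le} \mathbb{E}\|w_{i,k-1} - w - \eta_l(\tilde{g}_{i,k-1} - \nabla f_i(\tilde{w}_{i,k-1}) + \nabla f_i(\tilde{w}_{i,k-1}) - \nabla f_i(\tilde{w}) + \nabla f_i(\tilde{w}) - \nabla f(\tilde{w}) + \nabla f(\tilde{w}))\|^2 \\
&\overset{(b)}{\le} \left(1 + \frac{1}{2K-1}\right) \mathbb{E}\|w_{i,k-1} - w\|^2 + \mathbb{E}\|\eta_l(\tilde{g}_{i,k-1} - \nabla f_i(\tilde{w}_{i,k-1}))\|^2 + 6K\,\mathbb{E}\|\eta_l(\nabla f_i(\tilde{w}_{i,k-1}) - \nabla f_i(\tilde{w}))\|^2 \\
&\quad + 6K\,\mathbb{E}\|\eta_l(\nabla f_i(\tilde{w}) - \nabla f(\tilde{w}))\|^2 + 6K\|\eta_l \nabla f(\tilde{w})\|^2 \\
&\overset{(c)}{\le} \left(1 + \frac{1}{2K-1} + 2L^2\eta_l^2\right) \mathbb{E}\|w_{i,k-1} - w\|^2 + 2\eta_l^2L^2\rho^2\sigma_l^2 + 12K\eta_l^2L^2\,\mathbb{E}\|w_{i,k-1} - w\|^2 \\
&\quad + 12KL^2\eta_l^2\,\mathbb{E}\|\delta_{i,k-1} - \delta\|^2 + 6K\eta_l^2\,\mathbb{E}\|\nabla f_i(\tilde{w}) - \nabla f(\tilde{w})\|^2 + 6K\eta_l^2\|\nabla f(\tilde{w})\|^2 \\
&\overset{(d)}{\le} \left(1 + \frac{1}{2K-1} + 12K\eta_l^2L^2 + 2L^2\eta_l^2\right) \mathbb{E}\|w_{i,k-1} - w\|^2 + 2\eta_l^2L^2\rho^2\sigma_l^2 + 12KL^2\eta_l^2\,\mathbb{E}\|\delta_{i,k} - \delta\|^2 \\
&\quad + 6K\eta_l^2(3\sigma_g^2 + 6L^2\rho^2) + 6K\eta_l^2\|\nabla f(\tilde{w})\|^2,
\end{aligned}$$

where (a) follows from the fact that $\tilde{g}_{i,k-1}$ is an unbiased estimator of $\nabla f_i(\tilde{w}_{i,k-1})$ and Lemma A.3; (b) is from Lemma A.2; (c) is from Assumption 3 and Lemma A.2 and (d) is from Lemma A.5. Averaging over the clients $i$ and using a learning rate that satisfies $\eta_l \le \frac{1}{10KL}$, we have

$$\begin{aligned}
\frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|w_{i,k} - w\|^2 &\le \left(1 + \frac{1}{2K-1} + 12K\eta_l^2L^2 + 2L^2\eta_l^2\right) \frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|w_{i,k-1} - w\|^2 + 2\eta_l^2L^2\rho^2\sigma_l^2 \\
&\quad + 12KL^2\eta_l^2 \frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|\delta_{i,k} - \delta\|^2 + 6K\eta_l^2(3\sigma_g^2 + 6L^2\rho^2) + 6K\eta_l^2\|\nabla f(\tilde{w})\|^2 \\
&\overset{(a)}{\le} \left(1 + \frac{1}{K-1}\right) \frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|w_{i,k-1} - w\|^2 + 2\eta_l^2L^2\rho^2\sigma_l^2 \\
&\quad + 12KL^2\eta_l^2 \frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|\delta_{i,k} - \delta\|^2 + 6K\eta_l^2(3\sigma_g^2 + 6L^2\rho^2) + 6K\eta_l^2\|\nabla f(\tilde{w})\|^2 \\
&\le \sum_{\tau=0}^{k-1} \left(1 + \frac{1}{K-1}\right)^{\tau} \left[2\eta_l^2L^2\rho^2\sigma_l^2 + 6K\eta_l^2(3\sigma_g^2 + 6L^2\rho^2) + 6K\eta_l^2\|\nabla f(\tilde{w})\|^2\right] + 12KL^2\eta_l^2 \frac{1}{N} \sum_{i \in [N]} \mathbb{E}\|\delta_{i,k} - \delta\|^2 \\
&\overset{(b)}{\le} 5K\eta_l^2 \left(2L^2\rho^2\sigma_l^2 + 6K(3\sigma_g^2 + 6L^2\rho^2) + 6K\|\nabla f(\tilde{w})\|^2\right) + 24K^3\eta_l^4L^4\rho^2,
\end{aligned}$$

where (a) is due to the fact that $\eta_l \le \frac{1}{10KL}$ and (b) is from Lemma B.1.

B.2.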
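Convergence Analysis of Full Client Participation FedSAM

Lemma B.3 For the full client participation scheme, we have:

$$\left\langle \nabla f(\tilde{w}^r), \mathbb{E}_r[\Delta^r + \eta_l K \nabla f(\tilde{w}^r)] \right\rangle \le \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_l L^2 \mathcal{E}_w + K\eta_l L^2 \mathcal{E}_\delta - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2.$$

Proof.

$$\begin{aligned}
\left\langle \nabla f(\tilde{w}^r), \mathbb{E}_r[\Delta^r + \eta_l K \nabla f(\tilde{w}^r)] \right\rangle
&\overset{(a)}{=} \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \big(\nabla f_i(\tilde{w}^r_{i,k}) - \nabla f_i(\tilde{w}^r)\big) \Big\|^2 - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\overset{(b)}{\le} \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + \frac{\eta_l}{2N} \sum_{i,k} \mathbb{E}_r \|\nabla f_i(\tilde{w}^r_{i,k}) - \nabla f_i(\tilde{w}^r)\|^2 - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\overset{(c)}{\le} \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + \frac{\eta_l L^2}{2N} \sum_{i,k} \mathbb{E}_r \|\tilde{w}^r_{i,k} - \tilde{w}^r\|^2 - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\overset{(d)}{\le} \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + \frac{\eta_l L^2}{N} \sum_{i,k} \mathbb{E}_r \|w^r_{i,k} - w^r\|^2 + \frac{\eta_l L^2}{N} \sum_{i,k} \mathbb{E}_r \|\delta^r_{i,k} - \delta^r\|^2 - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&= \frac{\eta_l K}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_l L^2 \mathcal{E}_w + K\eta_l L^2 \mathcal{E}_\delta - \frac{\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2,
\end{aligned}$$

where (a) is from $\langle a, b \rangle = \frac{1}{2}(\|a\|^2 + \|b\|^2 - \|a - b\|^2)$ with $a = \sqrt{\eta_l K}\, \nabla f(\tilde{w}^r)$ and $b = -\frac{\sqrt{\eta_l}}{N\sqrt{K}} \sum_{i,k} (\nabla f_i(\tilde{w}^r_{i,k}) - \nabla f_i(\tilde{w}^r))$; (b) is from Lemma A.2; (c) is from Assumption 1 and (d) is from Lemma A.2.

Lemma B.4 For the full client participation scheme, we can bound $\mathbb{E}[\|\Delta^r\|^2]$ as follows:

$$\mathbb{E}_r[\|\Delta^r\|^2] \le \frac{K\eta_l^2L^2\rho^2}{N} \sigma_l^2 + \frac{\eta_l^2}{N^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2.$$

Proof. For the full client participation scheme, we have:

$$\begin{aligned}
\mathbb{E}_r[\|\Delta^r\|^2] &\overset{(a)}{\le} \frac{\eta_l^2}{N^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \tilde{g}^r_{i,k} \Big\|^2 \\
&\overset{(b)}{=} \frac{\eta_l^2}{N^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \big(\tilde{g}^r_{i,k} - \nabla f_i(\tilde{w}^r_{i,k})\big) \Big\|^2 + \frac{\eta_l^2}{N^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\overset{(c)}{\le} \frac{K\eta_l^2L^2\rho^2}{N} \sigma_l^2 + \frac{\eta_l^2}{N^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2,
\end{aligned}$$

where (a) is from Lemma A.2; (b) is from Lemma A.3 and (c) is from Lemma A.4.

Lemma B.5 (Descent Lemma). For all $r \in [R - 1]$ and $i \in \mathcal{S}^r$, with the choice of learning rates $\eta_l \le \frac{1}{10KL}$ and $\eta_g\eta_l \le \frac{1}{KL}$, the iterates generated by FedSAM in Algorithm 3 satisfy:

$$\begin{aligned}
\mathbb{E}_r[f(\tilde{w}^{r+1})] &\le f(\tilde{w}^r) - K\eta_g\eta_l \left(\frac{1}{2} - 30K^2L^2\eta_l^2\right) \|\nabla f(\tilde{w}^r)\|^2 \\
&\quad + K\eta_g\eta_l \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{\eta_g\eta_l L^3\rho^2}{2N} \sigma_l^2\right),
\end{aligned}$$

where the expectation is w.r.t. the stochasticity of the algorithm.

Proof. We first give the proof for the full client participation scheme.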
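By the smoothness in Assumption 1, taking the expectation of $f(\tilde{w}^{r+1})$ over the randomness at communication round $r$, we have:

$$\begin{aligned}
\mathbb{E}_r[F(w^{r+1})] = \mathbb{E}_r[f(\tilde{w}^{r+1})] &\le f(\tilde{w}^r) + \mathbb{E}_r \left\langle \nabla f(\tilde{w}^r), \tilde{w}^{r+1} - \tilde{w}^r \right\rangle + \frac{L}{2}\, \mathbb{E}_r[\|\tilde{w}^{r+1} - \tilde{w}^r\|^2] \\
&\overset{(a)}{=} f(\tilde{w}^r) + \mathbb{E}_r \left\langle \nabla f(\tilde{w}^r), \eta_g\Delta^r + K\eta_g\eta_l \nabla f(\tilde{w}^r) - K\eta_g\eta_l \nabla f(\tilde{w}^r) \right\rangle + \frac{L}{2} \eta_g^2\, \mathbb{E}_r[\|\Delta^r\|^2] \\
&\overset{(b)}{=} f(\tilde{w}^r) - K\eta_g\eta_l \|\nabla f(\tilde{w}^r)\|^2 + \eta_g \left\langle \nabla f(\tilde{w}^r), \mathbb{E}_r[\Delta^r + K\eta_l \nabla f(\tilde{w}^r)] \right\rangle + \frac{L}{2} \eta_g^2\, \mathbb{E}_r[\|\Delta^r\|^2] \\
&\overset{(c)}{\le} f(\tilde{w}^r) - \frac{K\eta_g\eta_l}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_g\eta_l L^2 \mathcal{E}_w + K\eta_g\eta_l L^2 \mathcal{E}_\delta - \frac{\eta_g\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 + \frac{L}{2} \eta_g^2\, \mathbb{E}_r[\|\Delta^r\|^2] \\
&\overset{(d)}{\le} f(\tilde{w}^r) - \frac{K\eta_g\eta_l}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_g\eta_l L^2 \mathcal{E}_w + K\eta_g\eta_l L^2 \mathcal{E}_\delta + \frac{K\eta_g^2\eta_l^2L^3\rho^2}{2N} \sigma_l^2 \\
&\overset{(e)}{\le} f(\tilde{w}^r) - K\eta_g\eta_l \left(\frac{1}{2} - 30K^2L^2\eta_l^2\right) \|\nabla f(\tilde{w}^r)\|^2 \\
&\quad + K\eta_g\eta_l \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{\eta_g\eta_l L^3\rho^2}{2N} \sigma_l^2\right),
\end{aligned}$$

where (a) is from the iterate update given in Algorithm 3; (b) results from the unbiased estimators; (c) is from Lemma B.3; (d) is from Lemma B.4 and the fact that $\eta_g\eta_l \le \frac{1}{KL}$ and (e) is from Lemmas B.1 and B.2.

Theorem B.6 Let the constant local and global learning rates $\eta_l$ and $\eta_g$ be chosen such that $\eta_l \le \frac{1}{10KL}$ and $\eta_g\eta_l \le \frac{1}{KL}$. Under Assumptions 1-2 and with full client participation, the sequence of outputs $\{w^r\}$ generated by FedSAM satisfies:

$$\min_{r \in [R]} \mathbb{E}\|\nabla F(w^r)\|^2 \le \frac{F^0 - F^*}{CK\eta_g\eta_l R} + \Phi,$$

where $\Phi = \frac{1}{C} \left[10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{\eta_g\eta_l L^3\rho^2}{N} \sigma_l^2\right]$. If we choose the learning rates $\eta_l = \frac{1}{\sqrt{R}KL}$, $\eta_g = \sqrt{KN}$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \frac{1}{\sqrt{R}}$, we have

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKN}} + \frac{\sigma_g^2}{R} + \frac{L^2\sigma_l^2}{R^{3/2}\sqrt{KN}}\right).$$

Proof. For full client participation, summing the result of Lemma B.5 for $r = [R]$ and multiplying both sides by $\frac{1}{CK\eta_g\eta_l R}$ with $\left(\frac{1}{2} - 30K^2L^2\eta_l^2\right) > C > 0$ if $\eta_l < \frac{1}{30KL}$, we have

$$\begin{aligned}
\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] &= \frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla f(\tilde{w}^{r+1})\|^2] \\
&\le \frac{f(\tilde{w}^r) - f(\tilde{w}^{r+1})}{CK\eta_g\eta_l R} + \frac{1}{C} \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{\eta_g\eta_l L^3\rho^2}{N} \sigma_l^2\right) \\
&\le \frac{F}{CK\eta_g\eta_l R} + \frac{1}{C} \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{\eta_g\eta_l L^3\rho^2}{N} \sigma_l^2\right),
\end{aligned}$$

where the second inequality uses $f(\tilde{w}^{r+1}) \ge f^*$ and $f(\tilde{w}^0) \ge f(\tilde{w}^r)$, with $F = f(\tilde{w}^0) - f^*$.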
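If we choose the learning rates $\eta_l = \frac{1}{\sqrt{R}KL}$, $\eta_g = \sqrt{KN}$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \frac{1}{\sqrt{R}}$, we have

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKN}} + \frac{\sigma_g^2}{R} + \frac{L^2\sigma_l^2}{R^2K} + \frac{L^2\sigma_l^2}{R^{3/2}\sqrt{KN}} + \frac{L^2}{R^2} + \frac{1}{R^4K^2} + \frac{L^2}{R^3K}\right).$$

Note that the term $\frac{\sigma_g^2}{R}$ is due to the heterogeneity between clients, the term $\frac{L^2\sigma_l^2}{R^{3/2}\sqrt{KN}}$ is due to the local SGD, and the terms $\frac{L^2}{R^2}$, $\frac{1}{R^4K^2}$ and $\frac{L^2}{R^3K}$ are due to the local SAM. The SAM terms are only of higher order in $R$, and hence the SAM part does not have a large influence on convergence. After omitting the higher-order terms, we have

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKN}} + \frac{\sigma_g^2}{R} + \frac{L^2\sigma_l^2}{R^{3/2}\sqrt{KN}}\right).$$

This completes the proof.

B.3. Convergence Analysis of Partial Client Participation FedSAM

Lemma B.7 For the partial client participation scheme, we can bound $\mathbb{E}_r[\|\Delta^r\|^2]$:

$$\mathbb{E}_r[\|\Delta^r\|^2] \le \frac{K\eta_l^2L^2\rho^2}{S} \sigma_l^2 + \frac{\eta_l^2}{NS} \sum_i \Big\| \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 + \frac{(S-1)\eta_l^2}{SN^2} \Big\| \sum_i \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2.$$

Proof. For the partial client participation scheme without replacement, we have:

$$\begin{aligned}
\mathbb{E}_r[\|\Delta^r\|^2] &\overset{(a)}{\le} \frac{\eta_l^2}{S^2}\, \mathbb{E}_r \Big\| \sum_{i \in \mathcal{S}^r} \sum_{j=0}^{K-1} \tilde{g}^r_{i,j} \Big\|^2 = \frac{\eta_l^2}{S^2}\, \mathbb{E}_r \Big\| \sum_i \mathbb{I}\{i \in \mathcal{S}^r\} \sum_{j=0}^{K-1} \tilde{g}^r_{i,j} \Big\|^2 \\
&\overset{(b)}{=} \frac{\eta_l^2}{SN}\, \mathbb{E}_r \Big[ \sum_i \Big\| \sum_{j=0}^{K-1} \big(\tilde{g}^r_{i,j} - \nabla f_i(\tilde{w}^r_{i,j})\big) \Big\|^2 \Big] + \frac{\eta_l^2}{S^2}\, \mathbb{E}_r \Big\| \sum_i \mathbb{I}\{i \in \mathcal{S}^r\} \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 \\
&\overset{(c)}{\le} \frac{K\eta_l^2L^2\rho^2}{S} \sigma_l^2 + \frac{\eta_l^2}{S^2}\, \mathbb{E}_r \Big\| \sum_i \mathbb{I}\{i \in \mathcal{S}^r\} \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 \\
&= \frac{K\eta_l^2L^2\rho^2}{S} \sigma_l^2 + \frac{\eta_l^2}{NS} \sum_i \Big\| \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 + \frac{(S-1)\eta_l^2}{SN^2} \Big\| \sum_i \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2,
\end{aligned}$$

where (a) is from Lemma A.2; (b) is from Lemma A.3 and (c) is from Lemma A.4.

Lemma B.8 For $\mathbb{E}[\|\sum_k \nabla f_i(\tilde{w}_{i,k})\|^2]$, where $\nabla f_i(\tilde{w}_{i,k})$ for all $k \in [K]$ and $i \in [N]$ is chosen according to FedSAM, we have:

$$\begin{aligned}
\sum_i \mathbb{E} \Big\| \sum_k \nabla f_i(\tilde{w}_{i,k}) \Big\|^2 &\le 30NK^2L^2\eta_l^2 \left(2L^2\rho^2\sigma_l^2 + 6K(3\sigma_g^2 + 6L^2\rho^2) + 6K\|\nabla f(\tilde{w})\|^2\right) + 144K^4L^6\eta_l^4\rho^2 \\
&\quad + 12NK^4L^2\eta_l^2\rho^2 + 3NK^2(3\sigma_g^2 + 6L^2\rho^2) + 3NK^2\|\nabla f(\tilde{w})\|^2,
\end{aligned}$$

where the expectation is w.r.t. the stochasticity of the algorithm.

Proof.

$$\begin{aligned}
\sum_i \mathbb{E} \Big\| \sum_k \nabla f_i(\tilde{w}_{i,k}) \Big\|^2 &= \sum_i \mathbb{E} \Big\| \sum_k \big(\nabla f_i(\tilde{w}_{i,k}) - \nabla f_i(\tilde{w}) + \nabla f_i(\tilde{w}) - \nabla f(\tilde{w}) + \nabla f(\tilde{w})\big) \Big\|^2 \\
&\overset{(a)}{\le} 6KL^2 \sum_{i,k} \mathbb{E}[\|w_{i,k} - w\|^2] + 6KL^2 \sum_{i,k} \mathbb{E}[\|\delta_{i,k} - \delta\|^2] + 3NK^2(3\sigma_g^2 + 6L^2\rho^2) + 3NK^2\|\nabla f(\tilde{w})\|^2 \\
&\overset{(b)}{\le} 30NK^2L^2\eta_l^2 \left(2L^2\rho^2\sigma_l^2 + 6K(3\sigma_g^2 + 6L^2\rho^2) + 6K\|\nabla f(\tilde{w})\|^2\right) + 144K^4L^6\eta_l^4\rho^2 \\
&\quad + 12NK^4L^2\eta_l^2\rho^2 + 3NK^2(3\sigma_g^2 + 6L^2\rho^2) + 3NK^2\|\nabla f(\tilde{w})\|^2,
\end{aligned}$$

where (a) is from Assumption 1, Lemmas A.2 and A.5; (b) is from Lemmas B.1 and B.2.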
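The step-size conditions used in this appendix are easy to check numerically. The sketch below assumes the conditions read $\eta_l \le \frac{1}{10KL}$ and $\eta_g\eta_l \le \frac{1}{KL}$ (Theorem B.6) and returns the quantity that must stay above the constant $C > 0$: $\frac{1}{2} - 30K^2L^2\eta_l^2$ for full participation, minus $\frac{L\eta_g\eta_l}{2S}(3K + 180K^3L^2\eta_l^2)$ when a number of sampled clients $S$ is given (the condition of Theorem B.9). The function name `descent_margin` and the example values are hypothetical.

```python
def descent_margin(K, L, eta_l, eta_g, S=None):
    """Return the descent coefficient that must exceed some constant C > 0.

    K: local epochs, L: smoothness constant, eta_l / eta_g: local / global
    learning rates, S: number of sampled clients (None for full participation).
    """
    if eta_l > 1.0 / (10 * K * L):
        raise ValueError("eta_l must satisfy eta_l <= 1 / (10 K L)")
    if eta_g * eta_l > 1.0 / (K * L):
        raise ValueError("eta_g * eta_l must satisfy eta_g eta_l <= 1 / (K L)")
    # Full-participation coefficient of the squared gradient norm.
    margin = 0.5 - 30 * K**2 * L**2 * eta_l**2
    if S is not None:
        # Extra term that appears under partial participation.
        margin -= L * eta_g * eta_l / (2 * S) * (3 * K + 180 * K**3 * L**2 * eta_l**2)
    return margin


# Example with hypothetical values K = 10, L = 1, eta_l = 0.005, eta_g = 1.
full = descent_margin(10, 1.0, 0.005, 1.0)
partial = descent_margin(10, 1.0, 0.005, 1.0, S=20)
```

Since $\eta_l \le \frac{1}{10KL}$ implies $30K^2L^2\eta_l^2 \le 0.3$, the full-participation value is always at least $0.2$; only the partial-participation term can push the margin towards zero, and it shrinks as $S$ grows.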
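Theorem B.9 Let the constant local and global learning rates $\eta_l$ and $\eta_g$ be chosen such that $\eta_l \le \frac{1}{10KL}$, $\eta_g\eta_l \le \frac{1}{KL}$ and the condition $\left(\frac{1}{2} - 30K^2L^2\eta_l^2 - \frac{L\eta_g\eta_l}{2S}(3K + 180K^3L^2\eta_l^2)\right) > 0$ holds. Under Assumptions 1-3 and with partial client participation, the sequence of outputs $\{w^r\}$ generated by FedSAM satisfies:

$$\min_{r \in [R]} \mathbb{E}\|\nabla F(w^r)\|^2 \le \frac{F^0 - F^*}{CK\eta_g\eta_l R} + \Phi,$$

where

$$\begin{aligned}
\Phi = \frac{1}{C} \Big[ &10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{L^3\eta_g\eta_l\rho^2}{2S} \sigma_l^2 \\
&+ \frac{\eta_g\eta_l}{S} \left(30KL^5\eta_l^2\rho^2\sigma_l^2 + 180K^2L^3\eta_l^2\sigma_g^2 + 360KL^5\eta_l^2\rho^2 + 72K^3L^7\eta_l^4\rho^2 + 6K^3L^3\eta_l^2\rho^2 + 6KL\sigma_g^2 + 6KL^3\rho^2\right) \Big].
\end{aligned}$$

If we choose the learning rates $\eta_l = \frac{1}{\sqrt{R}KL}$, $\eta_g = \sqrt{KS}$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \frac{1}{\sqrt{R}}$, we have:

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKS}} + \cdots\right).$$

Proof.

$$\begin{aligned}
\mathbb{E}_r[f(\tilde{w}^{r+1})] &\overset{(a)}{\le} f(\tilde{w}^r) - \frac{K\eta_g\eta_l}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_g\eta_l L^2 \mathcal{E}_w + K\eta_g\eta_l L^2 \mathcal{E}_\delta - \frac{\eta_g\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 + \frac{L}{2} \eta_g^2\, \mathbb{E}_r[\|\Delta^r\|^2] \\
&\overset{(b)}{\le} f(\tilde{w}^r) - \frac{K\eta_g\eta_l}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_g\eta_l L^2 \mathcal{E}_w + K\eta_g\eta_l L^2 \mathcal{E}_\delta + \frac{K\eta_g^2\eta_l^2L^3\rho^2}{2S} \sigma_l^2 - \frac{\eta_g\eta_l}{2KN^2}\, \mathbb{E}_r \Big\| \sum_{i,k} \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\quad + \frac{L\eta_g^2\eta_l^2}{2NS} \sum_i \Big\| \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 + \frac{L\eta_g^2\eta_l^2(S-1)}{2SN^2} \Big\| \sum_i \sum_{j=0}^{K-1} \nabla f_i(\tilde{w}^r_{i,j}) \Big\|^2 \\
&\overset{(c)}{\le} f(\tilde{w}^r) - \frac{K\eta_g\eta_l}{2} \|\nabla f(\tilde{w}^r)\|^2 + K\eta_g\eta_l L^2 \mathcal{E}_w + K\eta_g\eta_l L^2 \mathcal{E}_\delta + \frac{K\eta_g^2\eta_l^2L^3\rho^2}{2S} \sigma_l^2 + \frac{L\eta_g^2\eta_l^2}{2NS} \sum_i \mathbb{E}_r \Big\| \sum_k \nabla f_i(\tilde{w}^r_{i,k}) \Big\|^2 \\
&\overset{(d)}{\le} f(\tilde{w}^r) - K\eta_g\eta_l \left(\frac{1}{2} - 30K^2L^2\eta_l^2 - \frac{L\eta_g\eta_l}{2S}(3K + 180K^3L^2\eta_l^2)\right) \|\nabla f(\tilde{w}^r)\|^2 \\
&\quad + K\eta_g\eta_l \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{L^3\eta_g\eta_l\rho^2}{2S} \sigma_l^2\right) \\
&\quad + \frac{K\eta_g^2\eta_l^2}{S} \left(30KL^5\eta_l^2\rho^2\sigma_l^2 + 180K^2L^3\eta_l^2\sigma_g^2 + 360KL^5\eta_l^2\rho^2 + 72K^3L^7\eta_l^4\rho^2 + 6K^3L^3\eta_l^2\rho^2 + 6KL\sigma_g^2 + 6KL^3\rho^2\right) \\
&\overset{(e)}{\le} f(\tilde{w}^r) - CK\eta_g\eta_l \|\nabla f(\tilde{w}^r)\|^2 \\
&\quad + K\eta_g\eta_l \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{L^3\eta_g\eta_l\rho^2}{2S} \sigma_l^2\right) \\
&\quad + \frac{K\eta_g^2\eta_l^2}{S} \left(30KL^5\eta_l^2\rho^2\sigma_l^2 + 180K^2L^3\eta_l^2\sigma_g^2 + 360KL^5\eta_l^2\rho^2 + 72K^3L^7\eta_l^4\rho^2 + 6K^3L^3\eta_l^2\rho^2 + 6KL\sigma_g^2 + 6KL^3\rho^2\right),
\end{aligned}$$

where (a) is from the proof of Lemma B.5; (b) is from Lemma B.7; (c) is based on taking the expectation of the $r$-th round, provided the learning rates satisfy $KL\eta_g\eta_l \le \frac{S-1}{S}$; (d) is from Lemmas B.1, B.2 and B.8 and (e) holds because there exists a constant $C > 0$ satisfying $\left(\frac{1}{2} - 30K^2L^2\eta_l^2 - \frac{L\eta_g\eta_l}{2S}(3K + 180K^3L^2\eta_l^2)\right) > C > 0$.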
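Summing the above result for $r = [R]$ and multiplying both sides by $\frac{1}{CK\eta_g\eta_l R}$, we have

$$\begin{aligned}
\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] &\le \frac{f(\tilde{w}^r) - f(\tilde{w}^{r+1})}{CK\eta_g\eta_l R} \\
&\quad + \frac{1}{C} \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{L^3\eta_g\eta_l\rho^2}{2S} \sigma_l^2\right) \\
&\quad + \frac{\eta_g\eta_l}{CS} \left(30KL^5\eta_l^2\rho^2\sigma_l^2 + 180K^2L^3\eta_l^2\sigma_g^2 + 360KL^5\eta_l^2\rho^2 + 72K^3L^7\eta_l^4\rho^2 + 6K^3L^3\eta_l^2\rho^2 + 6KL\sigma_g^2 + 6KL^3\rho^2\right) \\
&\le \frac{F}{CK\eta_g\eta_l R} \\
&\quad + \frac{1}{C} \left(10KL^4\eta_l^2\rho^2\sigma_l^2 + 90K^2L^2\eta_l^2\sigma_g^2 + 180K^2L^4\eta_l^2\rho^2 + 120K^4L^6\eta_l^6\rho^2 + 16K^3\eta_l^4L^6\rho^2 + \frac{L^3\eta_g\eta_l\rho^2}{2S} \sigma_l^2\right) \\
&\quad + \frac{\eta_g\eta_l}{CS} \left(30KL^5\eta_l^2\rho^2\sigma_l^2 + 180K^2L^3\eta_l^2\sigma_g^2 + 360KL^5\eta_l^2\rho^2 + 72K^3L^7\eta_l^4\rho^2 + 6K^3L^3\eta_l^2\rho^2 + 6KL\sigma_g^2 + 6KL^3\rho^2\right),
\end{aligned}$$

where the second inequality uses $F = f(\tilde{w}^0) - f^* \ge f(\tilde{w}^r) - f(\tilde{w}^{r+1})$. If we choose the learning rates $\eta_l = \frac{1}{\sqrt{R}KL}$, $\eta_g = \sqrt{KS}$ and the perturbation amplitude $\rho$ proportional to the learning rate, e.g., $\rho = \frac{1}{\sqrt{R}}$, we have:

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKS}} + \frac{\sigma_g^2}{R} + \frac{\sigma_g^2}{R^{3/2}\sqrt{KS}} + \frac{L^2\sigma_l^2}{R^2K} + \frac{L^2\sigma_l^2}{R^{3/2}\sqrt{KS}} + \frac{L^2\sigma_l^2}{R^{5/2}K\sqrt{KS}} + \frac{L^2}{R^2} + \frac{1}{R^4K^2} + \frac{L^2}{R^{5/2}K\sqrt{KS}} + \frac{L^2}{R^{7/2}K\sqrt{KS}} + \cdots\right).$$

If the number of sampled clients is larger than the number of epochs, i.e., $S \ge K$, and we omit the higher-order terms of each part, we have:

$$\frac{1}{R} \sum_{r=1}^{R} \mathbb{E}[\|\nabla F(w^{r+1})\|^2] = O\left(\frac{FL}{\sqrt{RKS}} + \cdots\right).$$

This completes the proof.

C. Generalization Bounds

The generalization bound of FedSAM follows the margin-based generalization bounds in (Neyshabur et al., 2018; Bartlett et al., 2017; Farnia et al., 2018). We consider the following margin-based error for analyzing the generalization error of FedSAM with a general neural network:

$$\mathcal{L}^{SAM}_{\gamma}(F(w)) := \frac{1}{N} \sum_{i \in [N]} \mathbb{P}_{D_i}\left( f_i(w + \delta_i, X)[Y] - \max_{j \ne Y} F_i(w + \delta_i, X)[j] \le \gamma \right). \tag{11}$$

Our generalization bound is based on the following two lemmas from (Chatterji et al., 2019) and (Neyshabur et al., 2018):

Lemma C.1 ((Chatterji et al., 2019)). Let $F(w)$ be any predictor function with parameters $w$ and $P$ be a prior distribution on parameters $w$.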
Then, for any γ, m, ζ > 0, with probability 1 ζ over training set M of size m, for any parameter w and any perturbation distribution Q over parameters such that Pδ Q[max X |F(w + δ) F(w)| γ 2, we have: LSAM(F(w)) ˆLSAM γ (F(w)) + 2KL(w + δ P) + log m where KL( P) is the KL-divergence. Generalized Federated Learning via Sharpness Aware Minimization Lemma C.2 ((Neyshabur et al., 2018)). Let norm of input X be bounded by A. For any A > 0, let F(w) be a neural network with Re LU activations and depth d with h units per hidden-layer. Then for any w, X X, and any perturbation δ s.t. δj Wj , where δj is the size of layer j, the change in the output of the network can be bounded as follows: F(w + δ, X) F(w, X) 2 e A δj 2 Wj 2 . Lemma C.1 gives a data-independent deterministic bound which depends on the maximum change of the output function over the domain after a perturbation. Lemma C.2 bounds the change in the output a network based on the magnitude of the perturbation. Theorem C.3 Let input X be an n n image whose norm is bounded by A, f(w) be the classification function with d hidden-layer neural network with h units per hidden-layer, and satisfy 1-Lipschitz activation θ(0) = 0. We assume the constant M 1 for each layer Wj satisfies: 1 M Wj where φw := (Qd j=1 Wj )1/d denotes the geometric mean of f(w) s spectral norms across all layers. Then, for any margin value γ, size of local training dataset on each client m, ζ > 0, with probability 1 ζ over the training set, any parameter of SAM local optimizer w = w + δ such that max X Di Fi(w) f( w) γ 8 , we can obtain the following generalization bound: LSAM(F(w)) ˆLSAM γ (F(w + δ)) + O 32Bd2h log(dh)Q(F(w)) + d log Nmd log(M) where Q(F(w)) := Qd j=1 Wj Pd i=1 Wj 2 F Wj and Wj 2 F is the Frobenius norm. Proof. 
Based on Lemma C.1, we choose the perturbation δj of each layer which is a zero-mean multivariate Gaussian distribution with diagonal covariance matrix, i.e., N(0, λ2 j I) and λj = Wj ϵ W λ, where ϵ W := (Qd j=1 Wj )1/d is the geometric average of spectral norms across all layers. We consider F( W) with weights W. Since (1 + 1 d)d e and 1 e (1 1 d)d 1, for any weight vector of Wj such that | Wj 2 Wj 2| 1 d for every j, we have: (1/e) d d 1 Then, for the jth layer s random perturbation vector δj N(0, λ2 j I), we have the following bound from (Tropp, 2012) with h representing the width of the jth hidden layer: P ϵ W δj Wj > t 2he t2 Based on (Farnia et al., 2018), we now use a union bound over all layers for a maximum union probability of 1/2, which implies the normalized ϵ W δj Wj for each layer can be upper-bounded by λ p 2h log(4hd). Then, for any W satisfying | Wj Wj | 1 d Wj for all layer j s, we obtain the following: F(W + δ, X) F(W, X) e A d Y δj Wj 2 4ed Aϵd 1 W λ p 2h log(4hd) γ where the last inequality is from choosing λ = γ 32ed Aϵd 1 W h log(4hd), where the perturbation satisfies the Lemma C.2. 
Then, Generalized Federated Learning via Sharpness Aware Minimization we can bound the KL-divergence in Lemma C.1 as follows: KL(w + δ||P) 2λ2 j = 322e2d2A2ϵ2d W h log(4hd) 322e2d2A2 Qd j=1 Wj h log(4hd) Wj = O d2A2h log(hd) Qd j=1 Wj Based on (Farnia et al., 2018), we have the following result given a fixed underlying distribution P and any ζ > 0 with probability 1 ζ for any W: LSAM(F(w)) ˆLSAM γ (F(w + δ)) + O d2A2h log(hd) Qd j=1 Wj Pd j=1 Wj F Now, we use a cover of size O(d log(M)dd) points, and hence it can demonstrate that for a fixed underlying distribution for any ζ > 0, with probability 1 ζ, we have: LSAM(F(w)) ˆLSAM γ (F(w + δ)) + O d2A2h log(hd) Qd j=1 Wj Pd j=1 Wj F Wj + d log dm log(M) To apply the above result to the FL network of N clients, we apply a union bound to have the bound hold simultaneously for the distribution of every client, which proves for every ζ > 0 with probability at least 1 ζ, the average SAM loss of the clients satisfies the following margin-based bound: LSAM(F(w)) ˆLSAM γ (F(w + δ)) + O d2A2h log(hd) Qd j=1 Wj Pd j=1 Wj F Wj + d log d Nm log(M) This completes the proof. D. Convergence Analysis of Mo Fed SAM D.1. Description of Fed SAM Algorithm and Key Lemmas We outline the Mo Fed SAM algorithm in Algorithm 2. In round r, we sample Sr [N] clients with |Sr| = S and then perform the following updates: Starting from the shared global parameters wr i,0 = wr 1, we update the local parameters for k [K]: wr i,k = wr i,k 1 + ρ gr i,k 1 gr i,k 1 vr i,k 1 = β gr i,k 1 + (1 β) r wr i,k = wr i,k 1 ηlvr i,k 1, After K times local epochs, we obtain the following: r i = wr i,K wr. Compute the new global parameters using only updates from the clients i Sr and a global step-size ηg: r+1 = 1 ηl KS wr+1 = wr + ηg r. Generalized Federated Learning via Sharpness Aware Minimization To prove the convergence of Mo Fed SAM, we first propose some lemmas for Mo Fed SAM as follows: Lemma D.1 (Bounded Ew of Mo Fed SAM). 
Suppose our functions satisfies Assumptions 1-2. Then, for any i [N], k [K] and r [R] the updates of Mo Fed SAM for any learning rate satisfying ηl 1 30βKL have the drift due to wi,k w: i E[ wi,k w 2] 5Kη2 l (2β2L2η2 l ρ2σ2 l + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2) + 28β2K3L4η4 l ρ2. Proof. Recall that the local update on client i is wi,k = wi,k 1 βηlgi,k 1 + (1 β) r. Then, we have: E wi,k w 2 = E wi,k 1 w ηl(β gi,k 1 + (1 β) ) 2 (a) E wi,k 1 w βηl( gi,k 1 fi( wi,k 1) + fi( wi,k 1) fi( w) + fi( w) f( w) + f( w)) + ηl(1 β) 2 (b) 1 + 1 2K 1 + 2β2L2η2 l E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 7K2βη2 l E fi( wi,k 1) fi( w) 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 7Kβ2η2 l f( w) 2 + 7Kη2 l (1 β)2 2 (c) 1 + 1 2K 1 + 2β2L2η2 l + 14Kβ2L2η2 l E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 7K(1 β)2η2 l 2 + 14Kβ2L2η2 l E δi,k 1 δ 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 7β2KE f( w) 2 (d) 1 + 1 2K 1 + 2β2L2η2 l + 14β2KL2η2 l E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 14Kβ2L2η2 l E δi,k δ 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 7K(1 β)2η2 l 2 + 7β2KE f( w) 2 (e) 1 + 1 2K 1 + 2β2L2η2 l + 14β2KL2η2 l E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 14Kβ2L2η2 l E δi,k δ 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2, where (a) follows from the fact that gi,k 1 is an unbiased estimator of fi( wi,k 1) and Lemma A.3; (b) is from Lemmas A.2 and A.5; (c) is from Assumption 3; Lemma A.2; (d) is from Assumption 2 and (e) is due to the fact that f( w) and β < 1 Averaging over the clients i and learning rate satisfies ηl 1 30βKL, we have: Ew 1 + 1 2K 1 + 2β2L2η2 l + 14β2KL2η2 l E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 14Kβ2L2η2 l E δi,k δ 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2 (a) 1 + 1 K 1 i [N] E wi,k 1 w 2 + 2β2L2η2 l ρ2σ2 l + 14Kβ2L2η2 l 1 N i [N] E δi,k δ 2 + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2 τ [2β2L2η2 l ρ2σ2 l + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2] + 14Kβ2L2η2 l 1 N i [N] E δi,k δ 2 (b) 5K(2β2L2η2 l ρ2σ2 l + 7Kβ2η2 l (3σ2 g + 6L2ρ2) + 14K(1 β)2η2 l f( w) 2) + 28β2K3L4η4 l ρ2, where (a) is due to the 
fact that ηl 1 30βKL and β 1 2 and (b) is from Lemma B.1. Generalized Federated Learning via Sharpness Aware Minimization D.2. Convergence Analysis of Full client participation Mo Fed SAM Lemma D.2 For the full client participation scheme, we can bound E[ r 2] as follows: Er[ r 2] 2β2L2ρ2 KN σ2 l + 2 K2N 2 Er i,k β fi( wr i,k) + (1 + β) r Proof. For the full client participation strategy, we have: Er[ r+1 2] (a) 1 K2N 2η2 l Er i,k βηl gr i,k + (1 β)ηl r i,k ( gr i,k fi( wr i,k)) 2 + 1 K2N 2 Er i,k β fi( wi,k) + (1 β) r KN σ2 l + 1 K2N 2 Er i,k β fi( wr i,k) + (1 β) r KN σ2 l + 2(1 β)2 KN f( wr) 2 + β2 i,k fi( wr i,k) where (a) is from Lemma A.2; (b) is from Lemma A.3 and (c) is from Lemma A.4. Lemma D.3 (Descent Lemma of full client participation Mo Fed SAM). For all r R 1 and i Sr, with the choice of learning rate, the iterates generated by Mo Fed SAM under full client participation in Algorithm 2 satisfy: Er[f( wr+1)] f( wr) Kηgηl 2 20K2L2η2 l B2 f( wr) 2 + Kηgηl(6K2η2 l β4ρ2 + 5K2ηlβ4ρ2σ2 + 20K3η3 l β2G2 + 16K3η4 l β6ρ2 + ηgηlβ3ρ2 where the expectation is w.r.t. the stochasticity of the algorithm. Er[F(wr+1)] = Er[f( wr+1)] f( wr) + Er f( wr), wr+1 wr] + L 2 Er[ wr+1 wr 2] (a)= f( wr) + ηg Er f( wr), r+1 + β f( wr) β f( wr) + L 2 η2 g Er[ r+1 2] (b)= f( wr) βηg f( wr) 2 + ηg f( wr), Er[ r+1 + β f( wr)] + L 2 η2 g Er[ r+1 2], where (a) is from the iterate update given in Algorithm 3 and (b) results from the unbiased estimators. 
Generalized Federated Learning via Sharpness Aware Minimization For the third term, we bound it as follows: f( wr), Er[ r+1 + β f( wr)] = (1 β) f( wr) 2 + p β f( wr), Er i,k (ηl fi( wr i,k) ηl fi( wr)) 2 1 f( wr) 2 + β 2K2N 2 Er i,k ( fi( wr i,k) fi( wr)) 2 β 2K2N 2 Er i,k fi( wr i,k) 2 1 f( wr) 2 + β 2KN i,k Er fi( wr i,k) fi( wr) 2 β 2K2N 2 Er i,k fi( wr i,k) 2 1 f( wr) 2 + β 2KN i,k Er fi( wr i,k) fi( wr) 2 β 2K2N 2 Er i,k fi( wr i,k) 2 1 f( wr) 2 + βL2 i,k Er wr i,k wr 2 β 2K2N 2 Er i,k fi( wr i,k) 2 1 f( wr) 2 + βL2(Ew + Eδ) β 2K2N 2 Er i,k fi( wr i,k) where (a) is from that r = f( wr) and f( wr) = P i fi( wr); (b), (c) and (e) are from Lemma A.2 and (d) is from Assumption 1. Plugging (13) into (12), we have: Er[f( wr+1)] f( wr) ηg βηg f( wr) 2 + βL2ηg(Ew + Eδ) βηg 2K2N 2 Er i,k fi( wr i,k) 2 + Lη2 g 2 Er[ r+1 2] (a) f( wr) 3βηg 4 2(1 β)2Lηg f( wr) 2 + βL2ηg(Ew + Eδ) + β2L3ρ2η2 g 2KN σ2 l βηg 2K2N 2 Er i,k fi( wr i,k) 2 + Lβ2η2 g 2K2N 2 Er i,k fi( wr i,k) (b) f( wr) βηg KN 70(1 β)K2L2η2 l 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 (c) f( wr) Cβηg f( wr) 2 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 (a) is from Lemma D.2; (b) is from Lemmas B.1, D.1 and due to the fact that ηg 1 βL and (c) is due to the fact that the condition 3 KN 70(1 β)K2L2η2 l > C > 0 and β 1 Theorem D.4 (Convergence of Mo Fed SAM). Let constant local and global learning rates ηl 1 30βKL, ηg 1 βL and 2 and the condition 3 KN 70(1 β)K2L2η2 l > C > 0 holds. Under Assumptions 1-3 and with full client participation, the sequence of outputs {wr} generated by Fed GSAM satisfies: min r [R] E F(wr) 2 f 0 f where Φ = 1 C (20β3L4η2 l ρ2σ2+25β2K2L2η2 l G2+20β2K4L5η4 l ρ2+4β2KL4η2 l ρ2+ βL3ρ2η2 g 2KN ). 
If we choose the learning rates ηl = O( 1 RKβL), ηg = O( RβL) and the perturbation amplitude ρ proportional to the learning rate, e.g., ρ = 1 Generalized Federated Learning via Sharpness Aware Minimization we have: 1 R r=1 E[ F(wr+1) ] = O FβL RKN + β2σ2 g R + Lσ2 l R2β + L2 Proof. Summing the result of Lemma D.3 for r = [R] and multiplying both sides by 1 Cβηg R, we have: r=1 E[ F(wr+1) ] F Cβηg R 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 where it is from that F = f( w0) f f( wr) f( wr+1). If we choose the learning rates ηl = O( 1 RKβL), ηg = RβL) and the perturbation amplitude ρ proportional to the learning rate, e.g., ρ = 1 r=1 E[ F(wr+1) ] = O FβL RKN + β2σ2 g R + L2σ2 l R2K + Lσ2 l R2β + βL2 If we omit the larger order of each part, we have: r=1 E[ F(wr+1) ] = O FβL RKN + β2σ2 g R + Lσ2 l R2β + L2 This completes the proof. D.3. Convergence Analysis of Partial client participation Mo Fed SAM Lemma D.5 For the partial client participation, we can bound Er[ r 2] as follows: Er[ r 2] KL2η2 l ρ2 S σ2 l + η2 l S2 j=0 fi( wr i,j) Er[ r 2] (a) 1 K2S2η2 l Er k βηl gr i,k + (1 β)ηl r = 1 K2S2η2 l Er i I{i Sr} X k βηl gr i,k (1 β) r j=0 ( gr i,j fi( wr i,j)) 2 + 1 K2S2 Er j=0 β fi( wr i,j) + (1 β) r KS σ2 l + 2(1 β)2 KS f( wr) 2 + 2β2 j=0 fi( wi,j) KS σ2 l + 2(1 β)2 KS f( wr) 2 + 2β2 j=0 fi( wr i,j) 2 + 2β2(S 1) j=0 fi( wr i,j) where (a) is from Lemma A.2; (b) is from Lemma A.3 and (c) is from Lemma A.4. Generalized Federated Learning via Sharpness Aware Minimization Lemma D.6 (Descent Lemma of partial client participation Mo Fed SAM). For all r R 1 and i Sr, with the choice of learning rate, the iterates generated by Mo Fed SAM under partial client participation in Algorithm 2 satisfy: Er[f( wr+1)] f( wr) Kηgηl 2 20K2L2η2 l B2 f( wr) 2 + Kηgηl(6K2η2 l β4ρ2 + 5K2ηlβ4ρ2σ2 + 20K3η3 l β2G2 + 16K3η4 l β6ρ2 + ηgηlβ3ρ2 where the expectation is w.r.t. the stochasticity of the algorithm. 
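To make the update analysed in Lemma D.6 concrete, the following is a minimal sketch of one MoFedSAM communication round with partial client participation, following the local rule of Section D.1 (SAM-perturbed gradient, momentum mix with weight $\beta$, local step $\eta_l$). It is not the authors' implementation. The function names, the plain-list vectors and the toy quadratic clients are illustrative assumptions. The sign convention is also an assumption, since the extracted update rules do not show signs unambiguously: here $\Delta$ is kept gradient-like, $\Delta^{r+1} = -\frac{1}{\eta_l K S}\sum_{i \in S_r}(w^r_{i,K} - w^r)$, and the server steps $w^{r+1} = w^r - \eta_g \Delta^{r+1}$.

```python
import math
import random


def sam_gradient(grad_fn, w, rho):
    """Gradient evaluated at the SAM-perturbed point w + rho * g / ||g||."""
    g = grad_fn(w)
    norm = math.sqrt(sum(x * x for x in g)) or 1.0  # guard against a zero gradient
    w_tilde = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    return grad_fn(w_tilde)


def mofedsam_round(w_global, delta, client_grad_fns, S, K,
                   eta_l, eta_g, rho, beta, rng):
    """One communication round; returns (new global weights, new momentum)."""
    dim = len(w_global)
    drift_sum = [0.0] * dim
    for i in rng.sample(range(len(client_grad_fns)), S):  # sample S clients
        w = list(w_global)  # each client starts from the global model
        for _ in range(K):
            g_tilde = sam_gradient(client_grad_fns[i], w, rho)
            # v = beta * g_tilde + (1 - beta) * Delta^r
            v = [beta * g + (1.0 - beta) * d for g, d in zip(g_tilde, delta)]
            w = [wi - eta_l * vi for wi, vi in zip(w, v)]
        for j in range(dim):  # accumulate Delta_i^r = w_{i,K} - w^r
            drift_sum[j] += w[j] - w_global[j]
    # Assumed sign convention: Delta is gradient-like, so the server subtracts it.
    new_delta = [-s / (eta_l * K * S) for s in drift_sum]
    new_w = [wg - eta_g * d for wg, d in zip(w_global, new_delta)]
    return new_w, new_delta


def quadratic_client(center):
    """Toy client objective f_i(w) = 0.5 * ||w - center||^2 (illustrative only)."""
    return lambda w: [wi - ci for wi, ci in zip(w, center)]
```

A caller would build `client_grad_fns` from one gradient function per client, start with `delta` set to zeros, and call `mofedsam_round` repeatedly with a seeded `random.Random` instance. Setting `beta = 1` removes the momentum term, so each client runs plain SAM steps.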
Er[F(wr+1)] = Er[f( wr+1)] f( wr) + Er f( wr), wr+1 wr] + L 2 Er[ wr+1 wr 2] = f( wr) βηg f( wr) 2 + ηg f( wr), Er[ r+1 + β f( wr)] + L 2 η2 g Er[ r+1 2]. (14) Similar to full client participation strategy, we bound the third term in (14) as follows: f( wr), Er[ r+1 + β f( wr)] 3β 2 1 f( wr) 2 + βL2(Ew + Eδ) β 2K2N 2 Er i,k fi( wr i,k) Plugging (15) into (14), we have: Er[f( wr+1)] f( wr) ηg βηg f( wr) 2 + βL2ηg(Ew + Eδ) βηg 2K2N 2 Er i,k fi( wr i,k) 2 + Lη2 g 2 Er[ r+1 2] (a) f( wr) 3βηg 4 2(1 β)2Lηg f( wr) 2 + βL2ηg(Ew + Eδ) + β2L3ρ2η2 g 2KS σ2 l βηg 2K2N 2 Er i,k β fi( wr i,k) 2 + Lβ2η2 g 2K2SN k fi( wr i,k) 2 + Lβ2(S 1)η2 g K2SN 2 Er i,k fi( wr i,k) (b) f( wr) βηg KN 70(1 β)K2L2η2 l 90βL3ηgη2 l S 3βLηg 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 30NK2L4η2 l ρ2σ2 l + 270NK3L2η2 l σ2 g + 540NK2L4η2 l ρ2 + 72K4L6η4 l ρ2 + 6NK4L2η2 l ρ2 + 4NK2σ2 g + 3NK2L2ρ2 (c) f( wr) Cβηg f( wr) 2 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 30NK2L4η2 l ρ2σ2 l + 270NK3L2η2 l σ2 g + 540NK2L4η2 l ρ2 + 72K4L6η4 l ρ2 + 6NK4L2η2 l ρ2 + 4NK2σ2 g + 3NK2L2ρ2 , (a) is from Lemma D.2; (b) is from Lemmas B.1, D.1 and due to the fact that ηg S 2βL(S 1) and (c) is due to the fact that the condition 3 KN 70(1 β)K2L2η2 l 90βL3ηgη2 l S 3βLηg 2S > C > 0 and β 1 Generalized Federated Learning via Sharpness Aware Minimization Theorem D.7 Let constant local and global learning rates ηl and ηg be chosen as such that ηl 1 30βKL, ηg S 2βL(S 1) and the condition 3 KN 70(1 β)K2L2η2 l 90βL3ηgη2 l S 3βLηg 2S > 0 holds. Under Assumption 1-3 and with partial client participation, the sequence of outputs {wr} generated by Mo Fed SAM satisfies: min r [R] E F(wr) 2 F 0 F where Φ = 1 C [10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 30NK2L4η2 l ρ2σ2 l + 270NK3L2η2 l σ2 g + 540NK2L4η2 l ρ2 + 72K4L6η4 l ρ2 + 6NK4L2η2 l ρ2 + 4NK2σ2 g + 3NK2L2ρ2 ]. 
If we choose the learning rates ηl = 1 RβL and perturbation amplitude ρ proportional to the learning rate, e.g., ρ = 1 R, we have: r=1 E[ F(wr+1) ] = O FL

Proof. Summing the above result for r = [R] and multiplying both sides by 1 Cβηg R, we have r=1 E[ F(wr+1) ] f( wr) f( wr+1) 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 30NK2L4η2 l ρ2σ2 l + 270NK3L2η2 l σ2 g + 540NK2L4η2 l ρ2 + 72K4L6η4 l ρ2 + 6NK4L2η2 l ρ2 + 4NK2σ2 g + 3NK2L2ρ2 F Cβηg R + 1 10β2L4η2 l ρ2σ2 l + 35β2KL2η2 l (3σ2 g + 6L2ρ2) + 28β2K3L6η4 l ρ2 + 2K2L4η2 l ρ2 + βL3η2 gρ2 30NK2L4η2 l ρ2σ2 l + 270NK3L2η2 l σ2 g + 540NK2L4η2 l ρ2 + 72K4L6η4 l ρ2 + 6NK4L2η2 l ρ2 + 4NK2σ2 g + 3NK2L2ρ2 , where the second inequality uses F = f( w0) f f( wr) f( wr+1).

If we choose the learning rates ηl = 1 RβL and perturbation amplitude ρ proportional to the learning rate, e.g., ρ = 1 R, we have: r=1 E[ F(wr+1) ] = O βFL RKS + σ2 g R + β KSσ2 g R3/2 + L2σ2 l R2K + Lσ2 l R2β + L2 KSσ2 l R5/2 Sβ4 + K3/2L

If the number of sampling clients are larger than the number of epochs, i.e., S K, and omitting the larger order of each part, we have: 1 R r=1 E[ F(wr+1) ] = O βFL

This completes the proof.

Table 3. Datasets and models.

| Dataset | Task | Clients | Total samples | Model |
|---|---|---|---|---|
| EMNIST (Cohen et al., 2017) | Handwritten character recognition | 100/50 | 81,425 | 2-layer CNN + 2-layer FFN |
| CIFAR-10 (Krizhevsky et al., 2009) | Image classification | 100/50 | 60,000 | ResNet-18 (He et al., 2016) |
| CIFAR-100 (Krizhevsky et al., 2009) | Image classification | 100/50 | 60,000 | ResNet-18 (He et al., 2016) |

E. Experimental Setup

We ran the experiments on a CPU/GPU cluster with RTX 2080Ti GPUs, and used PyTorch (Paszke et al., 2019) to build and train our models. The datasets are described in Table 3.

E.1. Dataset Description

EMNIST (Cohen et al., 2017) is a 62-class image classification dataset.
In this paper, we use 20% of the dataset and partition it across clients using a Dirichlet allocation with parameter 0.6 over 100 clients by default. We train the same CNN as in (Reddi et al., 2020; Dieuleveut et al., 2021), which includes two convolutional layers with 3 × 3 kernels, max pooling, and dropout, followed by a 128-unit dense layer.

CIFAR-10 and CIFAR-100 (Krizhevsky et al., 2009) are labeled subsets of the 80 million images dataset. They both share the same 60,000 input images. CIFAR-100 has a finer labeling, with 100 unique labels, compared with the 10 unique labels of CIFAR-10. The Dirichlet allocation parameter for these two datasets is also 0.6. For both of them, we train the ResNet-18 (He et al., 2016) architecture.

E.2. Hyperparameters

For each algorithm and each dataset, the learning rate was set via grid search on the set $\{10^{-0.5}, 10^{-1}, 10^{-1.5}, 10^{-2}\}$. The momentum term $\beta$ of FedCM, MimeLite and MoFedSAM was tuned via grid search on $\{0.01, 0.1, 0.2, 0.5, 1\}$. By default, the global learning rate is $\eta_g = 1$ and the local learning rate is $\eta_l = 0.1$.

F. Additional Experiments

F.1. Training accuracy on different datasets

[Figure 4, three panels: (a) EMNIST dataset, (b) CIFAR-10 dataset, (c) CIFAR-100 dataset. x-axis: Communication Round; y-axis: Validation Accuracy; curves: FedAvg, SCAFFOLD, FedRobust, FedCM, MimeLite, FedSAM, MoFedSAM.]

Figure 4. Training accuracy on different datasets.

Figure 4 shows the training accuracy on different datasets. Compared with the validation accuracy results in Figure 1, the performance gap between algorithms is less clear. The reason is that the global model easily overfits the training dataset.
Since the distribution of the validation dataset on each client differs from that of the training dataset, the compared benchmarks generalize less well. This indicates the benefit of our proposed algorithms. Although FedSAM does not outperform the momentum FL algorithms, i.e., FedCM and MimeLite, it incurs lower transmission costs, since it does not need to download $\Delta^r$. For example, on the CIFAR-100 dataset, FedCM achieves 85.26% training accuracy with 4.41% deviation of local models; however, it obtains 54.09% validation accuracy with 14.38% deviation. MoFedSAM achieves 86.02% training accuracy with 3.23% deviation of local models, and 55.13% validation accuracy with 3.25% deviation of local models.

Table 4. Impact of the heterogeneity on EMNIST dataset (IID, Dirichlet 0.6 and Dirichlet 0.3).

| Algorithm | IID Train | IID Validation | IID Round | Dirichlet 0.6 Train | Dirichlet 0.6 Validation | Dirichlet 0.6 Round | Dirichlet 0.3 Train | Dirichlet 0.3 Validation | Dirichlet 0.3 Round |
|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 96.98 (0.73) | 89.95 (1.95) | 32 | 95.07 (0.94) | 84.38 (4.03) | 43 | 93.66 (1.27) | 82.83 (4.42) | 61 |
| SCAFFOLD | 96.04 (1.01) | 88.79 (2.38) | 51 | 93.85 (1.31) | 84.09 (4.56) | 69 | 92.85 (1.68) | 82.01 (4.95) | 88 |
| FedRobust | 95.63 (0.56) | 87.67 (1.63) | 66 | 93.17 (0.62) | 83.70 (3.37) | 91 | 92.10 (1.00) | 81.80 (3.79) | 103 |
| FedCM | 97.47 (0.87) | 91.13 (2.07) | 18 | 96.22 (1.16) | 84.85 (4.22) | 25 | 94.83 (1.29) | 83.09 (4.58) | 47 |
| MimeLite | 97.26 (0.85) | 91.29 (2.11) | 16 | 95.73 (0.49) | 84.88 (3.04) | 38 | 94.90 (1.33) | 83.14 (4.55) | 46 |
| FedSAM | 97.42 (0.49) | 90.22 (1.50) | 22 | 96.16 (1.14) | 84.75 (4.11) | 28 | 94.32 (0.91) | 82.97 (3.56) | 53 |
| MoFedSAM | 97.58 (0.51) | 91.52 (1.53) | 13 | 96.42 (0.42) | 85.07 (2.95) | 24 | 94.98 (0.95) | 83.28 (3.59) | 41 |

Table 5. Impact of the heterogeneity on CIFAR-100 dataset (IID, Dirichlet 0.6 and Dirichlet 0.3).

| Algorithm | IID Train | IID Validation | IID Round | Dirichlet 0.6 Train | Dirichlet 0.6 Validation | Dirichlet 0.6 Round | Dirichlet 0.3 Train | Dirichlet 0.3 Validation | Dirichlet 0.3 Round |
|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 84.68 (1.46) | 58.97 (3.56) | 253 | 79.57 (1.84) | 53.57 (5.40) | 302 | 77.61 (1.99) | 51.22 (6.17) | 593 |
| SCAFFOLD | 83.41 (2.07) | 57.16 (4.32) | 327 | 78.49 (2.02) | 51.49 (5.87) | 551 | 76.30 (2.67) | 48.89 (6.59) | - |
| FedRobust | 82.58 (1.35) | 55.87 (3.35) | 378 | 76.80 (1.70) | 49.06 (4.75) | 893 | 75.26 (1.87) | 47.92 (5.86) | - |
| FedCM | 87.05 (1.48) | 59.64 (3.73) | 149 | 82.46 (2.00) | 55.73 (5.11) | 189 | 79.91 (2.02) | 52.57 (6.28) | 410 |
| MimeLite | 87.42 (1.56) | 59.87 (3.67) | 143 | 82.53 (2.08) | 55.82 (5.04) | 182 | 79.96 (2.00) | 52.60 (6.31) | 397 |
| FedSAM | 85.65 (1.27) | 59.11 (3.11) | 228 | 81.04 (1.59) | 54.69 (4.36) | 245 | 78.05 (1.71) | 51.78 (5.43) | 561 |
| MoFedSAM | 87.82 (1.32) | 60.02 (3.20) | 129 | 82.62 (1.53) | 56.60 (4.42) | 124 | 80.09 (1.77) | 52.90 (5.62) | 373 |

Tables 2, 4 and 5 show the impact of the degree of heterogeneity in FL. From these results, we can clearly see that increasing the degree of heterogeneity substantially degrades learning performance. However, it does not affect the training accuracy significantly. For example, on the CIFAR-10 dataset, FedSAM obtains 95.42%, 94.20%, and 92.90% training accuracy when the heterogeneity is IID, Dirichlet 0.6 and Dirichlet 0.3, respectively, and 87.36%, 82.55% and 79.82% validation accuracy. More specifically, the influence of heterogeneity on our proposed algorithms is smaller than on the compared benchmarks, because the more generalized the global model is, the smaller the impact of distribution shift.

F.2. Impact of hyperparameters

[Figure 5, four panels: (a) Impact of S (x-axis: Sampling rate), (b) Impact of K (x-axis: Local epochs K), (c) Impact of ρ, (d) Impact of β. y-axis: Validation Accuracy.]

Figure 5. Impacts of different parameters on EMNIST dataset.

Figures 3-6 show the impacts of different hyperparameters, i.e., the number of participating clients S in each communication round, the number of local epochs K, the perturbation control parameter ρ of the SAM optimizer, and the momentum parameter β. We can see that increasing S improves the performance. However, increasing K does not guarantee better performance. The best values of ρ and β depend on the algorithm and the dataset. A grid search is sufficient to find suitable values.

[Figure 6, four panels: (a) Impact of S (x-axis: Sampling rate), (b) Impact of K (x-axis: Local epochs K), (c) Impact of ρ, (d) Impact of β. y-axis: Validation Accuracy.]

Figure 6. Impacts of different parameters on CIFAR-100 dataset.
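The heterogeneity levels compared in this appendix (IID, Dirichlet 0.6, Dirichlet 0.3) come from the Dirichlet allocation described in Section E.1. The paper does not spell out the exact procedure, so the sketch below shows one common reading as an assumption: for each class, the samples are split across clients with proportions drawn from a symmetric Dirichlet distribution with parameter `alpha`. The function name and arguments are hypothetical, and only the Python standard library is used (a Dirichlet draw is obtained by normalising independent Gamma draws).

```python
import random


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients, class by class.

    For every class, client shares are drawn from Dirichlet(alpha, ..., alpha).
    Smaller alpha gives more skewed (more heterogeneous) client datasets.
    """
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        # Normalised Gamma(alpha, 1) draws form a symmetric Dirichlet sample.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        start, cumulative = 0, 0.0
        for c in range(num_clients):
            cumulative += gammas[c] / total
            # The last client takes the remainder so that no index is dropped.
            if c == num_clients - 1:
                end = len(idxs)
            else:
                end = int(round(cumulative * len(idxs)))
            clients[c].extend(idxs[start:end])
            start = end
    return clients


if __name__ == "__main__":
    # Illustrative labels only: 1,000 samples over 10 classes, 100 clients.
    toy_labels = [i % 10 for i in range(1000)]
    parts = dirichlet_partition(toy_labels, num_clients=100, alpha=0.6)
    print(sorted(len(p) for p in parts)[-5:])  # sizes of the five largest clients
```

With this reading, a very large `alpha` approaches an even split of every class across clients, while `alpha = 0.3` concentrates each class on fewer clients than `alpha = 0.6`.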