# Distributed Personalized Empirical Risk Minimization

Yuyang Deng, Pennsylvania State University, yzd82@psu.edu
Mohammad Mahdi Kamani, Wyze Labs, mmkamani@alumni.psu.edu
Pouria Mahdavinia, Pennsylvania State University, pxm5426@psu.edu
Mehrdad Mahdavi, Pennsylvania State University, mzm616@psu.edu

This paper advocates a new paradigm, Personalized Empirical Risk Minimization (PERM), to facilitate learning from heterogeneous data sources without imposing stringent constraints on the computational resources shared by participating devices. In PERM, we aim to learn a distinct model for each client by learning who to learn with. We personalize the aggregation of local empirical losses by effectively estimating the statistical discrepancy among data distributions. This entails optimal statistical accuracy for all local distributions and overcomes the data heterogeneity issue. To learn personalized models at scale, we propose a distributed algorithm that replaces standard model averaging with model shuffling to simultaneously optimize the PERM objectives of all devices. This also allows us to learn distinct model architectures (e.g., neural networks with different numbers of parameters) for different clients, thus conforming to the memory and compute resources of individual clients. We rigorously analyze the convergence of the proposed algorithm and conduct experiments that corroborate the effectiveness of the proposed paradigm.

## 1 Introduction

Recently, federated learning (FL) has emerged as an alternative to centralized learning. It encourages federated model sharing and provides a framework for edge intelligence by shifting model training and inference from data centers to the potentially scattered, and perhaps self-interested, systems where data is generated [1].
37th Conference on Neural Information Processing Systems (NeurIPS 2023).

While FL is arguably a better paradigm than centralized learning, its widespread adoption necessitates foundational advances in the efficient use of statistical and computational resources. Such advances would encourage a large pool of individuals or corporations to share their private data and resources.

Specifically, due to the heterogeneity of data and compute resources among participants, it is necessary, if not imperative, to develop distributed algorithms that are:

i) cognizant of statistical heterogeneity (data-awareness), i.e., that effectively deal with highly heterogeneous data distributions across devices; and
ii) confined to learning models that fit the computational resources available on participating devices (system-awareness).

To mitigate the negative effect of data heterogeneity (non-IIDness), two common approaches are clustering and personalization.

The key idea behind clustering-based methods [2, 3, 4, 5] is to partition the devices into clusters (coalitions) with similar data distributions and then learn a single shared model for all clients within each cluster. While appealing, partitioning methods are limited to heuristics that do not take the actual data distributions into account, such as clustering based on the geographical distribution of devices. They also either lack theoretical guarantees or postulate strong assumptions on initial models or data distributions [4, 5].

In personalization-based methods [6, 7, 8, 2, 9, 10, 11], the idea is to learn a distinct personalized model for each device alongside the global model, which can be unified as a bi-level optimization problem [12]. Personalization aims to learn a model that has the generalization capabilities of the global model but also performs well on the specific data distribution of each participant. However, it suffers from a few key limitations.
First, as the number of clients grows, the amount of training data increases, but so does the number of parameters to be learned. Balancing the tradeoff between data and overall model complexity therefore prevents increasing the number of clients beyond a certain point, a phenomenon known as the incidental parameters problem [13].

Moreover, since knowledge transfer among data sources happens through a single global model, it might lead to suboptimal results. To see this, consider an extreme example in which half of the users have identical data distributions, say $\mathcal{D}$, while the other half share a substantially different distribution, say $\mathcal{D}'$ (e.g., two distributions with the same marginal distribution on features but opposite labeling functions). In this case, the global model obtained by naively aggregating local models (e.g., with fixed mixture weights) converges to a solution with low test accuracy on all local distributions. It is then preferable to learn a model for each client solely from its local data, or from a carefully chosen subset of data sources.

Focusing on system heterogeneity, most existing works require models of identical architecture to be deployed across the clients and the server (model homogeneity) [14, 15]. They mostly focus on reducing the number [15] or size [16, 17, 18] of communications, or on sampling to handle the chaotic availability of clients [19, 20]. Requiring the same model makes it infeasible to train large models under system heterogeneity, where client devices have drastically different computational resources. A few recent studies aim to overcome this issue either by leveraging knowledge distillation (KD) methods [21, 22, 23, 24] or by partial training (PT) strategies via model subsampling (static [25, 26], random [27], or rolling [28]). However, KD-based methods require the server to have access to a public dataset representative of all local datasets, and they largely ignore data heterogeneity in the distillation stage.
PT methods mostly focus on learning a single server model using the heterogeneous resources of devices. They do not aim to deploy a model onto each client after the global server model is trained (this is left as a future direction in [28]).

These issues lead to a fundamental question: What is the best strategy to learn from heterogeneous data sources so as to achieve optimal accuracy with respect to each data source, without imposing stringent constraints on the computational resources shared by participating devices?

We address this question by proposing a new data- and system-aware paradigm, dubbed Personalized Empirical Risk Minimization (PERM), to facilitate learning from massively fragmented private data under resource constraints. Motivated by generalization bounds in multiple-source domain adaptation [29, 30, 31, 32], PERM learns a distinct model for each client by personalizing the aggregation of the empirical losses of different data sources. This enables each client to learn who to learn with, through an effective method for empirically estimating the statistical discrepancy between their data distributions. We argue that PERM entails optimal statistical accuracy for all local distributions, thus overcoming the data heterogeneity issue. PERM can also be employed in other learning settings with multiple heterogeneous data sources, such as domain adaptation and multitask learning.

While PERM overcomes the data heterogeneity issue, the number of optimization problems to be solved (i.e., distinct personalized ERMs) scales linearly with the number of data sources. To optimize all objectives simultaneously in a scalable and computationally efficient manner, we propose a novel idea that replaces standard model averaging in distributed learning with model shuffling, and we establish its convergence rate.
This also allows us to learn distinct model architectures (e.g., neural networks with different numbers of parameters) for different clients. PERM thus conforms to the memory and compute resources of individual clients and overcomes the system heterogeneity issue. This addresses an open question in [28], where only a single global model can be trained in a model-heterogeneous setting; PERM instead allows deploying distinct models to different clients. Our empirical evaluation of PERM corroborates its statistical benefits in comparison to existing methods.

## 2 Personalized Empirical Risk Minimization

In this section, we formally state the problem and introduce PERM as an ideal paradigm for learning from heterogeneous data sources.

We assume there are $N$ distributed devices. Each holds a distinct data shard $S_i = \{(x_{i,j}, y_{i,j})\}_{j=1}^{n_i}$ of $n_i$ training samples, drawn from a local source distribution $\mathcal{D}_i$ over the instance space $\Xi = \mathcal{X} \times \mathcal{Y}$. The data distributions across devices are not independent and identically distributed (non-IID or heterogeneous), i.e., $\mathcal{D}_1 \neq \mathcal{D}_2 \neq \dots \neq \mathcal{D}_N$.

Each distribution corresponds to a local generalization error, or true risk, on unseen samples:
$$\mathcal{L}_i(h) = \mathbb{E}_{(x,y)\sim\mathcal{D}_i}[\ell(h(x), y)], \quad i = 1, 2, \dots, N,$$
for any model $h \in \mathcal{H}$. Here $\mathcal{H}$ is the hypothesis set (e.g., linear models or deep neural networks), and $\ell: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_+$ is a given convex or non-convex loss function.

We use $\widehat{\mathcal{L}}_i(h) = \frac{1}{n_i}\sum_{(x,y)\in S_i} \ell(h(x), y)$ to denote the local empirical risk, or training loss, on the $i$th data shard $S_i$. We seek to collaboratively learn a model, or personalized models, that generalize well on all local distributions. That is, we want to minimize the true risks $\mathcal{L}_i(\cdot)$, $i = 1, \dots, N$, for all data sources (all-for-all [33]).
A simple non-personalized solution, particularly in FL, minimizes a (weighted) empirical risk over all data shards in a communication-efficient manner [34]:

$$\arg\min_{h \in \mathcal{H}} \sum_{i=1}^{N} p(i)\, \widehat{\mathcal{L}}_i(h) \quad \text{with } p \in \Delta_N, \tag{WERM}$$

where $\Delta_N = \{p \in \mathbb{R}_+^N \mid \sum_{i=1}^N p(i) = 1\}$ denotes the simplex.

Consider a single model learned by WERM, for example with fixed mixing weights $p(i) = n_i/n$ (where $n$ is the total number of training samples), or even one that is agnostic to the mixture of distributions [35, 36]. It has been shown that such a model performs well on the combined dataset of all devices but can generalize poorly on individual datasets as the diversity among distributions increases [37, 38, 39, 40].

To overcome this issue, there has been a surge of interest in methods that personalize the global model to individual local distributions. These methods can be unified as the following bi-level problem (a similar unification was made in [12]):

$$\arg\min_{h_1, h_2, \dots, h_N \in \mathcal{H}} \widehat{F}_i(h_i \circ h^*) \quad \text{subject to} \quad h^* = \arg\min_{h \in \mathcal{H}} \sum_{j=1}^{N} \alpha(j)\, \widehat{\mathcal{L}}_j(h), \tag{BERM}$$

where $\circ$ denotes the mixing operation that combines local and global models, and $\widehat{F}_i$ is a modified local loss that is not necessarily the same as the local risk $\widehat{\mathcal{L}}_i$.

By carefully designing the local loss $\widehat{F}_i$ and the mixing operation $\circ$, we can recover different personalization schemes for FL as special cases. These include linear interpolation of global and local models [11, 2], multi-task learning [10], and meta-learning [9]. For example:

- BERM reduces to the zero-personalization objective WERM when $h_i \circ h^* = h^*$ and $\widehat{F}_i = \widehat{\mathcal{L}}_i$.
- At the other end of the spectrum lies zero collaboration, where the $i$th client trains its own model without any influence from other clients. It is obtained by setting $h_i \circ h^* = h_i$ and $\widehat{F}_i = \widehat{\mathcal{L}}_i$.
- The personalized model that interpolates global and local models is recovered by setting $h_i \circ h^* = \alpha h_i + (1 - \alpha) h^*$ and $\widehat{F}_i = \widehat{\mathcal{L}}_i$.
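The three special cases of the mixing operation $\circ$ can be made concrete with a short sketch. It treats models as plain parameter vectors, which is our assumption; for neural networks the same arithmetic would apply to each weight tensor. The function names are hypothetical and are not part of any library.

```python
from typing import List

Vector = List[float]

def mix_zero_personalization(h_local: Vector, h_global: Vector) -> Vector:
    """h_i o h* = h*: every client deploys the global model (WERM)."""
    return list(h_global)

def mix_zero_collaboration(h_local: Vector, h_global: Vector) -> Vector:
    """h_i o h* = h_i: each client ignores the global model."""
    return list(h_local)

def mix_interpolate(h_local: Vector, h_global: Vector, a: float) -> Vector:
    """h_i o h* = a * h_i + (1 - a) * h*, with interpolation weight a in [0, 1]."""
    if not 0.0 <= a <= 1.0:
        raise ValueError("interpolation weight must lie in [0, 1]")
    return [a * l + (1.0 - a) * g for l, g in zip(h_local, h_global)]

h_local, h_global = [1.0, 0.0], [0.0, 1.0]
print(mix_interpolate(h_local, h_global, 0.25))  # [0.25, 0.75]
```

Setting the interpolation weight to 1 or 0 recovers the two extreme cases. This is why interpolation is often viewed as a continuum between no collaboration and no personalization.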
While more effective than a single global model learned via WERM, personalization methods suffer from three key issues:

i) the global model is still obtained by minimizing the average empirical loss, which might limit the statistical benefits of collaboration;
ii) the overall model complexity increases linearly with the number of clients; and
iii) the same model space is shared by the server and the clients.

To motivate our proposal, consider the empirical loss $\sum_{i=1}^N \alpha(i)\widehat{\mathcal{L}}_i(h)$ in WERM (or the inner-level objective in BERM) with fixed mixing weights $\alpha \in \Delta_N$, and denote its optimal solution by $\hat{h}_\alpha$. Let $h_i^* = \arg\min_{h\in\mathcal{H}} \mathcal{L}_{\mathcal{D}_i}(h)$ be the optimal local model (i.e., all-for-one). The excess risk of $\hat{h}_\alpha$ on the $i$th local distribution $\mathcal{D}_i$ with respect to $h_i^*$ can be bounded as follows (informally) [31]:

$$\mathcal{L}_i(\hat{h}_\alpha) \le \mathcal{L}_i(h_i^*) + \sum_{j=1}^N \alpha(j)\,\mathfrak{R}_j(\mathcal{H}) + 2\sum_{j=1}^N \alpha(j)\,\mathrm{disc}_{\mathcal{H}}(\mathcal{D}_j, \mathcal{D}_i) + C\sqrt{\sum_{j=1}^N \frac{\alpha(j)^2}{n_j}}, \tag{GEN}$$

where:

- $\mathfrak{R}_j(\mathcal{H})$ is the empirical Rademacher complexity of $\mathcal{H}$ with respect to $S_j$;
- $C$ is a constant;
- $\mathrm{disc}_{\mathcal{H}}(\mathcal{D}_i, \mathcal{D}_j)$ is a pseudo-distance on the set of probability measures on $\Xi$. It assesses the discrepancy between the distributions $\mathcal{D}_i$ and $\mathcal{D}_j$ with respect to the hypothesis class $\mathcal{H}$, as defined below [29].

Definition 1. For a model space $\mathcal{H}$ and two probability distributions $\mathcal{D}, \mathcal{D}'$ on $\Xi = \mathcal{X}\times\mathcal{Y}$,
$$\mathrm{disc}_{\mathcal{H}}(\mathcal{D}, \mathcal{D}') = \sup_{h\in\mathcal{H}} \left| \mathbb{E}_{\xi\sim\mathcal{D}}[\ell(h,\xi)] - \mathbb{E}_{\xi'\sim\mathcal{D}'}[\ell(h,\xi')] \right|.$$

Intuitively, the discrepancy between two distributions is large if there exists a predictor that performs well on one of them and badly on the other. Conversely, if all functions in the hypothesis class perform similarly on both, then $\mathcal{D}$ and $\mathcal{D}'$ have low discrepancy.

This metric is a special case of a popular family of distance measures in probability theory and mathematical statistics known as integral probability metrics (IPMs) [41]. It can be estimated from finite data by replacing the expected losses with their empirical counterparts (i.e., $\mathcal{L}_i$ with $\widehat{\mathcal{L}}_i$).

From (GEN), a mismatch between pairs of distributions limits the benefits of ERM on all distributions. Indeed, the generalization risk w.r.t.
$\mathcal{D}_i$ will increase significantly when the divergence terms $\mathrm{disc}_{\mathcal{H}}(\mathcal{D}_j, \mathcal{D}_i)$ are large. Conversely, when the divergence is small, the pairwise discrepancies vanish. Choosing $\alpha(j) = 1/N$ then yields the ideal sample complexity $1/\sqrt{n}$, where $n = n_1 + n_2 + \dots + n_N$ is the total number of samples; this is what could have been obtained in the IID setting.

Note also that the global model may achieve a small training error over the union of all data (e.g., in the over-parameterized setting) and generalize well with respect to the average distribution. Even so, the divergence term remains, which explains the poor performance of the global model on the local distributions $\mathcal{D}_i$, $i = 1, 2, \dots, N$. Consequently, even personalizing the global model as in BERM cannot guarantee good generalization on all local distributions. Under high heterogeneity among local distributions, there is no effective transfer of positive knowledge among data sources. Similar impossibility results, even under seemingly generous assumptions on how distributions relate, have been established in multi-source domain adaptation [42].

Interestingly, the bound suggests that seeking optimal accuracy on all local distributions requires choosing, for each client $i$, a distinct mixing of local losses that minimizes the right-hand side of (GEN). In an ideal setting (i.e., all-for-all), we can thus achieve the best accuracy for each local distribution $\mathcal{D}_i$ by personalizing WERM:

(i) first estimate $\alpha_i$, $i = 1, 2, \dots, N$, for each client individually; then
(ii) solve a variant of WERM for each client with the obtained mixing parameters:

$$\arg\min_{h \in \mathcal{H}_i} \sum_{j=1}^N \alpha_i(j)\, \widehat{\mathcal{L}}_j(h), \quad i = 1, 2, \dots, N. \tag{PERM}$$

By doing so, each device achieves the optimal local generalization error by learning who to learn with. This choice is based on the number of samples at each source and on the mismatch between the device's data distribution and those of the other clients.
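The quantities that drive the choice of $\alpha_i$ can be computed on toy data. The sketch below is illustrative only and rests on several assumptions that are ours, not the paper's:

- hypotheses are one-dimensional threshold classifiers from a finite grid, so the supremum in Definition 1 becomes a maximum over that grid;
- the loss is the 0-1 loss;
- the last term of (GEN) scales as $\sqrt{\sum_j \alpha(j)^2/n_j}$.

All helper names (`empirical_disc`, `sample_term`, `perm_objective`) are hypothetical.

```python
import math

def zero_one_risk(h, samples):
    """Empirical 0-1 risk of classifier h on (x, y) pairs with y in {-1, +1}."""
    return sum(1 for x, y in samples if h(x) != y) / len(samples)

def threshold_classifiers(grid):
    """Finite stand-in for H: h(x) = s * sign(x - t) for t in grid, s in {-1, +1}."""
    return [
        (lambda x, t=t, s=s: s * (1 if x >= t else -1))
        for t in grid
        for s in (-1, 1)
    ]

def empirical_disc(H, S, S_prime):
    """Definition 1 with empirical risks and the sup restricted to the finite set H."""
    return max(abs(zero_one_risk(h, S) - zero_one_risk(h, S_prime)) for h in H)

def sample_term(alpha, sizes):
    """sqrt(sum_j alpha(j)^2 / n_j), the sample-size term assumed for (GEN)."""
    return math.sqrt(sum(a * a / n for a, n in zip(alpha, sizes)))

def perm_objective(alpha_i, local_risks):
    """Client i's PERM loss sum_j alpha_i(j) * L_hat_j(h) for one fixed model h."""
    return sum(a * r for a, r in zip(alpha_i, local_risks))

H = threshold_classifiers([-1.5, -0.5, 0.0, 0.5, 1.5])
S = [(-2.0, -1), (-1.0, -1), (1.0, 1), (2.0, 1)]
S_flipped = [(x, -y) for x, y in S]        # same features, opposite labeling
print(empirical_disc(H, S, S))             # 0.0
print(empirical_disc(H, S, S_flipped))     # 1.0
print(sample_term([0.25] * 4, [100] * 4))  # approximately 0.05 = 1/sqrt(400)
```

The two discrepancy values reproduce the extreme example from the introduction. Identical sources have zero discrepancy, while opposite labeling functions give the maximal value. With uniform weights and equal shard sizes, the sample-size term equals $1/\sqrt{n}$.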
Unlike WERM and BERM, PERM solves a different aggregated empirical loss for each client. We can therefore pick a different model space (model architecture) $\mathcal{H}_i$ for each client to meet its available computational resources.

While this two-stage method is guaranteed to entail optimal test accuracy for all local distributions $\mathcal{D}_i$, making it scalable requires overcoming two issues:

- First, estimating the statistical discrepancies between each pair of data sources (i.e., $\alpha_i$, $i = 1, \dots, N$) is a computational burden. It requires solving $O(N^2)$ maximization problems over differences of (non)convex functions in a distributed manner, and it needs enough samples from each source to estimate the pairwise discrepancies accurately [41].
- Second, we need to solve $N$ variants of the optimization problem in PERM, possibly each with a different model space. This is infeasible when the number of devices is huge (e.g., in cross-device federated learning).

In the next section, we propose a simple yet effective idea that overcomes these issues in a computationally efficient manner.

## 3 PERM at Scale via Model Shuffling

In this section, we propose a method to efficiently estimate the empirical discrepancies among data sources. This is followed by a model shuffling idea that simultaneously solves $N$ versions of PERM to learn a personalized model for each client.

We start with a two-stage algorithm: estimating the mixing parameters, followed by model shuffling. We then propose a unified single-loop algorithm that enjoys the same computation and communication overhead as BERM (twice the communication of FedAvg). For ease of exposition, we present the algorithms assuming all clients share the same model architecture, and we later discuss the generalization to heterogeneous model spaces.

Specifically, we assume that the model space $\mathcal{H}$ is parameterized by a convex set $\mathcal{W} \subseteq \mathbb{R}^d$. We use $f_i(w) := \widehat{\mathcal{L}}_i(w) = \frac{1}{n_i}\sum_{(x,y)\in S_i} \ell(w; (x, y))$ to denote the empirical loss on the $i$th data shard.
### 3.1 Warmup: a two-stage algorithm

We start with a two-stage method for solving the $N$ variants of PERM in parallel. In the first stage, we propose an efficient method to learn the mixing parameters of all clients. In the second stage, we propose a model shuffling method to solve all personalized empirical losses in parallel.

Stage 1: Mixing parameter estimation. In the first stage, we aim to efficiently estimate the pairwise discrepancies among local distributions in order to construct the mixing parameters $\alpha_i$, $i = 1, 2, \dots, N$. Given the generalization bound (GEN) and Definition 1, a direct way to estimate $\alpha_i$ is to solve the following convex-nonconcave minimax problem for each client:

$$\alpha_i^* = \arg\min_{\alpha\in\Delta_N} \sum_{j=1}^N \alpha(j) \max_{w\in\mathcal{W}} |f_i(w) - f_j(w)| + \sum_{j=1}^N \alpha(j)^2/n_j. \tag{1}$$

Here we replace the true risks in the pairwise discrepancy terms with their empirical counterparts. We also drop the complexity term: it becomes identical for all sources once the hypothesis space $\mathcal{H}$ is fixed and bounded by a computable distribution-independent quantity such as the VC dimension [43]. Alternatively, it can be controlled through the choice of $\mathcal{H}$ or through data-dependent regularization.

However, solving this minimax problem is itself challenging. The inner maximization is a nonconcave (difference-of-convex) problem, so most existing minimax algorithms fail on it. To the best of our knowledge, the only provable deterministic algorithm is [44], and it is hard to generalize to stochastic and distributed settings. Moreover, since we have $N$ clients, we need to solve $N$ variants of (1), which makes designing a scalable algorithm even harder. To overcome these issues, we make two relaxations to estimate the per-client mixing parameters.
First, we replace the pairwise empirical discrepancies $\sup_{w}|f_i(w) - f_j(w)|$ by an upper bound in terms of the gradient dissimilarity $\|\nabla f_i(w) - \nabla f_j(w)\|$ between local objectives [45]. This quantity measures how different the local empirical losses are and is widely used in the analysis of learning from heterogeneous losses, as in FL [46].

Second, a discrepancy measure based on the supremum can be excessively pessimistic in real-world scenarios. Drawing inspiration from the average drift at the optimum as the right metric for the effect of data heterogeneity in federated learning [47], we propose to measure the discrepancy at the optimal solution of WERM, i.e., $w^* := \arg\min_{w\in\mathcal{W}} \frac{1}{N}\sum_{i=1}^N f_i(w)$. Given the optimal global solution, the problem then reduces to a simple minimization for each client.

These two relaxations lead to the following tractable optimization problem for the per-client mixing parameters:

$$\arg\min_{\alpha\in\Delta_N} g_i(w^*, \alpha) := \sum_{j=1}^N \alpha(j)\,\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + \lambda \sum_{j=1}^N \alpha(j)^2/n_j, \tag{2}$$

where we add a regularization parameter $\lambda$ and use the squared gradient dissimilarity for computational convenience. Obtaining all $N$ mixing parameters thus requires solving a single ERM for the optimal global solution, plus $N$ variants of (2).

To obtain the optimal solution with reduced communication, we adapt the Local SGD algorithm [48] (or FedAvg [14]) to the intermittent communication setting [49]. There, the clients work in parallel for $R$ consecutive rounds and may make $K$ stochastic updates between two communication rounds. The detailed steps are given in Algorithm A1 in Appendix B for completeness.

After obtaining the global model $w^R$, we optimize over $\alpha$ in $g_i(w^R, \alpha)$ using $T_\alpha$ iterations of GD to obtain $\hat{\alpha}_i$. In fact, we show that as long as $w^R$ converges to $w^*$, $\hat{\alpha}_i$, $i = 1, \dots, N$, converges to the solution of (2) very fast.
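Problem (2) is a strongly convex quadratic in $\alpha$ over the simplex, so a few projected gradient steps suffice once the gradients $\nabla f_j(w^R)$ are available. The sketch below makes the following assumptions:

- the server already holds these gradients as plain lists;
- it uses projected GD with step $1/L_g$, where $L_g = 2\lambda/n_{\min}$ is the smoothness constant of $g_i(w,\cdot)$;
- it uses the standard sort-based Euclidean projection onto the simplex.

Reading "GD" in the text as projected GD is our interpretation, and the helper names are hypothetical.

```python
from typing import List, Sequence

def project_simplex(v: Sequence[float]) -> List[float]:
    """Euclidean projection onto {a : a >= 0, sum(a) = 1} (sort-based algorithm)."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for k, uk in enumerate(u, start=1):
        cumsum += uk
        t = (cumsum - 1.0) / k
        if uk - t > 0:  # holds for a prefix of k; the last such k defines theta
            theta = t
    return [max(x - theta, 0.0) for x in v]

def dissimilarity_row(grads: Sequence[Sequence[float]], i: int) -> List[float]:
    """zeta_{i,j} = ||grad f_i(w) - grad f_j(w)||^2 for every client j."""
    return [sum((gi - gj) ** 2 for gi, gj in zip(grads[i], g)) for g in grads]

def estimate_alpha(zeta_i, sizes, lam, T_alpha):
    """T_alpha projected-gradient steps on g_i(w, alpha), starting from uniform alpha."""
    N = len(zeta_i)
    step = min(sizes) / (2.0 * lam)  # 1 / L_g with L_g = 2 * lam / n_min
    alpha = [1.0 / N] * N
    for _ in range(T_alpha):
        grad = [z + 2.0 * lam * a / n for z, a, n in zip(zeta_i, alpha, sizes)]
        alpha = project_simplex([a - step * g for a, g in zip(alpha, grad)])
    return alpha

# Clients 0 and 1 share a gradient at w; clients 2 and 3 disagree strongly with them.
grads = [[1.0, 0.0], [1.0, 0.0], [-6.0, 5.0], [-6.0, 5.0]]
zeta_0 = dissimilarity_row(grads, 0)   # [0.0, 0.0, 74.0, 74.0]
alpha_0 = estimate_alpha(zeta_0, sizes=[100] * 4, lam=1.0, T_alpha=50)
print([round(a, 6) for a in alpha_0])  # [0.5, 0.5, 0.0, 0.0]
```

Client 0 puts all its weight on the clients whose gradients agree with its own. The quadratic regularizer $\lambda\sum_j\alpha(j)^2/n_j$ spreads that weight evenly between them. Larger $\lambda$ pushes $\alpha$ further toward uniform.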
Our proof is based on the following Lipschitzness observation:

$$\|\alpha^*_{g_i}(w^R) - \alpha^*_{g_i}(w^*)\|^2 \le 4L^2\kappa_g^2\Big(2\sum_{j=1}^N\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + 4L^2\|w^R - w^*\|^2\Big)\|w^R - w^*\|^2,$$

where $\alpha^*_{g_i}(w) := \arg\min_{\alpha\in\Delta_N} g_i(w, \alpha)$, and $\kappa_g := n_{\max}/(2\lambda)$ is the condition number of $g_i(w, \cdot)$ with $n_{\max} = \max_{i\in[N]} n_i$.

The Lipschitz constant mainly depends on the gradient dissimilarity at the optimum. As $w^R$ tends to $w^*$, $\alpha^*_{g_i}(\cdot)$ becomes more Lipschitz continuous, i.e., the coefficient in front of $\|w^R - w^*\|^2$ gets smaller. This leads to more accurate mixing parameters.

To establish convergence, we make the following standard assumptions.

Assumption 1 (Smoothness and strong convexity). The $f_i$'s are $L$-smooth and $\mu$-strongly convex, i.e., for all $x, y$:
$$\|\nabla f_i(x) - \nabla f_i(y)\| \le L\|x - y\|, \qquad f_i(y) \ge f_i(x) + \langle \nabla f_i(x), y - x\rangle + \frac{\mu}{2}\|y - x\|^2.$$
We denote the condition number by $\kappa = L/\mu$.

Assumption 2 (Bounded variance). The variance of the stochastic gradients computed at each local function is bounded, i.e., for all $i\in[N]$ and $w\in\mathcal{W}$: $\mathbb{E}[\|\nabla f_i(w;\xi) - \nabla f_i(w)\|^2] \le \delta^2$.

Assumption 3 (Bounded domain). The domain $\mathcal{W}\subseteq\mathbb{R}^d$ is a bounded convex set with diameter $D$ under the $\ell_2$ metric, i.e., for all $w, w'\in\mathcal{W}$: $\|w - w'\| \le D$.

Definition 2 (Gradient dissimilarity). We define the following quantities to measure the gradient dissimilarity among local functions:
$$\zeta_{i,j}(w) := \|\nabla f_i(w) - \nabla f_j(w)\|^2, \quad \zeta_i(w) := \frac{1}{N}\sum_{j=1}^N \zeta_{i,j}(w), \quad \zeta := \sup_{w\in\mathcal{W}} \max_{i\in[N]} \Big\|\nabla f_i(w) - \frac{1}{N}\sum_{j=1}^N \nabla f_j(w)\Big\|^2.$$

The following theorem gives the convergence rate of the estimated mixing parameters to their optimal counterparts.

Theorem 1. Suppose Assumptions 1-3 hold, and Algorithm A1 is run on $F(w) := \frac{1}{N}\sum_{j=1}^N f_j(w)$ with $\gamma = \Theta\big(\frac{\log(RK)}{\mu RK}\big)$ for $R$ rounds with synchronization gap $K$. Then, for $\kappa_g = 1/(\lambda n_{\min})$, it holds that
$$\mathbb{E}\|\alpha_i^R - \alpha_i^*\|^2 \le O\Big(\exp\big(-T_\alpha/\kappa_g\big) + \kappa_g^2\,\zeta_i(w^*)\,L^2 D^2 \cdots\Big).$$

An immediate implication of Theorem 1 is that, even though we solve (2) at $w^R$, the algorithm eventually converges to the optimal solution of (2) at $w^*$. As mentioned, the core technique in the proof is to show that the function $\alpha^*_{g_i}(w)$ becomes more Lipschitz for parameters within a small region centered at $w^*$.
This property is characterized rigorously in Lemma 3 in the appendix.

Stage 2: Scalable personalized optimization with model shuffling. After obtaining the per-client mixing parameters, the second stage learns local models by solving the $N$ personalized variants of PERM, denoted by $\Phi(\hat\alpha_1, v), \Phi(\hat\alpha_2, v), \dots, \Phi(\hat\alpha_N, v)$, where
$$\min_{v\in\mathcal{W}} \Phi(\hat\alpha_i, v) := \frac{1}{N}\sum_{j=1}^N \hat\alpha_i(j) f_j(v). \tag{3}$$

We devise an iterative algorithm based on distributed SGD with periodic averaging (a.k.a. Local SGD [48]) that solves these $N$ optimization problems in parallel with no extra overhead. The idea is to replace the model averaging in vanilla distributed (Local) SGD with model shuffling. As shown in Algorithm 1, it proceeds as follows:

- The algorithm runs for $R$ epochs, and each epoch consists of $N$ communication rounds.
- At the beginning of each epoch $r$, the server generates a random permutation $\sigma_r$ over the $N$ clients.
- At each communication round $j$ within the epoch, the server sends the model of client $i$, along with $\hat\alpha_i(i_j)$, to client $i_j = \sigma_r((i + j) \bmod N)$.
- After receiving a model from the server, the client updates it for $K$ local steps and returns it to the server.

Because $(i + j) \bmod N$ takes distinct values for distinct $i$ within a round, each client receives exactly one model per round, and over an epoch each model visits every client once. The updates of each loss $\Phi(\hat\alpha_i, v)$, $i = 1, 2, \dots, N$, during an epoch are therefore equivalent to sequentially processing the individual components of (3). This is permutation-based SGD, except that each component is now updated for $K$ steps. By interleaving the permutations, we simultaneously optimize all $N$ objectives.

The computation and communication complexity of the proposed algorithm is the same as that of Local SGD, with two differences: model averaging is replaced with model shuffling, and the algorithm runs over a permutation of devices. The convergence rate of Local SGD is well established in the literature [50, 51, 52, 53, 54]. Here we establish the convergence of the permutation-based variant, which is interesting in its own right.
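The shuffling scheme can be illustrated with a single-process simulation. It is a minimal sketch under several assumptions that are not part of the paper:

- the losses are toy quadratics $f_j(v) = \frac12\|v - c_j\|^2$, with exact gradients standing in for $\nabla f_j(v;\xi)$;
- $\mathcal{W}$ is a Euclidean ball, so that $\mathcal{P}_{\mathcal W}$ has a closed form;
- clients are indexed from 0;
- the step size decays as $\eta_r = 0.5/(r+1)$;
- the "parallel for" runs sequentially, which gives the same result because the $N$ model trajectories never interact.

All function names are hypothetical.

```python
import random
from typing import List, Sequence

Vector = List[float]

def project_ball(v: Sequence[float], radius: float) -> Vector:
    """Stand-in for P_W when W is the Euclidean ball of the given radius."""
    norm = sum(x * x for x in v) ** 0.5
    if norm <= radius:
        return list(v)
    return [x * radius / norm for x in v]

def sgd_update(v, eta, j, K, alpha, grad_fns, N):
    """SGD-Update: K local steps on f_j at client j, scaled by alpha(j) * N."""
    v = list(v)
    for _ in range(K):
        g = grad_fns[j](v)
        v = [vk - eta * alpha[j] * N * gk for vk, gk in zip(v, g)]
    return v

def shuffling_local_sgd(grad_fns, alphas, v0, K, R, eta_fn, radius, seed=0):
    """Simulates the shuffling scheme; returns the last iterate v_i^R of every client."""
    N = len(grad_fns)
    rng = random.Random(seed)
    models = [list(v0) for _ in range(N)]
    for r in range(R):
        sigma = list(range(N))
        rng.shuffle(sigma)                 # server draws the permutation sigma_r
        eta = eta_fn(r)
        next_models = []
        for i in range(N):                 # "parallel for" over the N models
            v = models[i]
            for j in range(N):             # N communication rounds per epoch
                i_j = sigma[(i + j) % N]   # client that trains model i in round j
                v = sgd_update(v, eta, i_j, K, alphas[i], grad_fns, N)
            next_models.append(project_ball(v, radius))
        models = next_models
    return models

centers = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
grad_fns = [lambda v, c=c: [vk - ck for vk, ck in zip(v, c)] for c in centers]
# Clients 0 and 1 mix the first two losses equally; client 2 learns alone.
alphas = [[0.5, 0.5, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]
models = shuffling_local_sgd(grad_fns, alphas, v0=[0.0, 0.0], K=2, R=500,
                             eta_fn=lambda r: 0.5 / (r + 1), radius=10.0)
```

For these quadratics, the minimizer of $\Phi(\alpha_i, \cdot)$ is $\sum_j \alpha_i(j) c_j$. That is $[0.5, 0.5]$ for clients 0 and 1 and $[-1, -1]$ for client 2, and the returned iterates approach these points. With a constant step size, shuffling leaves a bias of order $\eta$, which is why the sketch decays $\eta_r$.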
Assumption 4 (Bounded gradient). The gradients of each local function are bounded, i.e., for all $i\in[N]$: $\sup_{v\in\mathcal{W}} \|\nabla f_i(v)\| \le G$.

Assumption 4 can be satisfied because we work with a bounded domain $\mathcal{W}$.

Theorem 2. Let Assumptions 1-4 hold, and let $\alpha_i^*$ be the solution of (2). Suppose Algorithm 1 is run with the $\hat\alpha_i$ obtained from Algorithm A1 and $\eta = \Theta\big(\frac{\log(NKR^3)}{\mu R}\big)$. Then it outputs $\hat v_i$, $i\in[N]$, such that, with probability at least $1 - p$,
$$\mathbb{E}\big[\Phi(\alpha_i^*, \hat v_i) - \Phi(\alpha_i^*, v^*(\alpha_i^*))\big] \le O(\cdots) + \kappa_\Phi^2 L\cdot O\Big(\exp\big(-T_\alpha/\kappa_g\big) + \kappa_g^2\,\zeta_i(w^*)\,L^2\cdots\Big).$$

The first term depends on $D$, $L$, $\mu$, $G$, $N$, $\delta$, $\log(1/p)$, and $R$, and is the optimization error of Algorithm 1. The second term is the mixing-parameter estimation error inherited from Theorem 1. Here $\kappa = L/\mu$, $\kappa_g = n_{\max}/(2\lambda)$, and $\kappa_\Phi$ is the corresponding constant for $\Phi$ (which scales inversely with $\mu$). The expectation is taken over the randomness of Algorithm A1.

Consequently, to guarantee $\mathbb{E}[\Phi(\alpha_i^*, \hat v_i) - \Phi(\alpha_i^*, v_i^*)] \le \epsilon$, we choose
$$R = O\Big(\max\Big\{\frac{L\delta^2}{\mu\epsilon}, \frac{\kappa_\Phi^2\kappa_g^2L^3\zeta_i(w^*)D^2}{\epsilon}\Big\}\Big), \qquad T_\alpha = O\Big(\kappa_g\log\frac{L\kappa_\Phi^2}{\epsilon}\Big).$$

Algorithm 1: Shuffling Local SGD
Input: clients $1, \dots, N$; number of local steps $K$; number of epochs $R$; mixing parameters $\hat\alpha_1, \dots, \hat\alpha_N$.
for epoch $r = 0, \dots, R-1$ do
  Server generates a permutation $\sigma_r: [N] \to [N]$.
  parallel for $i = 1, \dots, N$ do
    Client $i$ sets the initial model $v_i^{r,0} = v_i^r$.
    for $j = 0, \dots, N-1$ do
      Set index $i_j = \sigma_r((i + j) \bmod N)$.
      Server sends $v_i^{r,j}$ to client $i_j$.
      $v_i^{r,j+1} = \text{SGD-Update}(v_i^{r,j}, \eta, i_j, K, \hat\alpha_i)$.
    Client $i$ does the projection $v_i^{r+1} = \mathcal{P}_{\mathcal{W}}(v_i^{r,N})$.
Output: $\hat v_i = \mathcal{P}_{\mathcal{W}}\big(v_i^R - \frac{1}{L}\nabla_v\Phi(\hat\alpha_i, v_i^R)\big)$, $i\in[N]$.

SGD-Update$(v, \eta, j, K, \alpha)$:
  Initialize $v^0 = v$.
  for $t = 1, \dots, K$ do
    $v^t = v^{t-1} - \eta\,\alpha(j)\,N\,\nabla f_j(v^{t-1}; \xi_j^{t-1})$
  Output: $v^K$.

The above theorem shows that, even though we optimize $\Phi(\hat\alpha_i, v)$, the obtained model $\hat v_i$ still converges to the optimal solution of $\Phi(\alpha_i^*, v)$. The convergence rate has two contributions: the convergence of $\hat\alpha_i$ (Algorithm A1) and the convergence of the personalized model $\hat v_i$ (Algorithm 1). Notice that, for the convergence rate of $\hat v_i$, we roughly recover the optimal rate of shuffling SGD [55], which is $O(1/R^2)$.
However, we suffer from an $O(\delta^2/R)$ term, since each client runs vanilla SGD on its local data (the SGD-Update procedure in Algorithm 1). One remedy for this variance term could be variance reduction, or shuffling data locally at each client before applying SGD.

A recent work [56] also considers client-level shuffling, but our work differs from it in two aspects:

1) They work with a Local SGD-type algorithm and employ shuffling for model averaging within a subset of clients. In our algorithm, by contrast, each client runs shuffling SGD directly on another client's model during each local update period.
2) From a theoretical perspective, we are mainly interested in whether the algorithm converges to the true optimal solution of $\Phi(\alpha_i^*, v)$ when we only optimize the surrogate function $\Phi(\hat\alpha_i, v)$.

One drawback of Algorithm 1 is that it cannot start until Algorithm A1 has finished and output $\hat\alpha_i$. If the precision of $\hat\alpha_i$ then turns out to be unsatisfactory, there is no way to go back and refine it. Hence, in the next subsection, we interleave Algorithm 1 and Algorithm A1 and introduce a single-loop variant of PERM that jointly optimizes the mixture weights and learns the personalized models.

### 3.2 A unified single-loop algorithm

We now introduce a single-stage algorithm, Algorithm 2, that jointly optimizes the $\alpha_i$'s and $v_i$'s by intertwining the two stages of Algorithm A1 and Algorithm 1 in a single unified method. The idea is to learn the global model, which is used to estimate the mixing parameters, concurrently with the personalized models. At each communication round, the clients compute gradients of the global model on their local data. After collecting these gradients, the server performs one mini-batch SGD step on the global model and then updates the mixing parameters.
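The interleaving of the global-model step and the $\alpha$ update can be sketched as follows. This is only a partial, single-process illustration under assumptions of ours:

- each $f_j$ is a scalar quadratic $f_j(w) = \frac{a_j}{2}(w - c_j)^2$;
- exact gradients replace the $M$-sample mini-batches;
- $\mathcal{W} = \mathbb{R}$, so no projection is needed;
- each $\alpha_i$ is warm-started from its previous value;
- the personalized-model updates, which would follow Algorithm 1 with $\alpha_i^r$, are omitted.

Function names are hypothetical.

```python
from typing import List, Sequence

def project_simplex(v: Sequence[float]) -> List[float]:
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for k, uk in enumerate(u, start=1):
        cumsum += uk
        t = (cumsum - 1.0) / k
        if uk - t > 0:
            theta = t
    return [max(x - theta, 0.0) for x in v]

def local_grads(w: float, a: Sequence[float], c: Sequence[float]) -> List[float]:
    """Gradients of the toy losses f_j(w) = 0.5 * a_j * (w - c_j)^2."""
    return [aj * (w - cj) for aj, cj in zip(a, c)]

def alpha_steps(alpha, zeta, sizes, lam, steps):
    """Projected GD on g_i(w, alpha) = sum_j alpha_j zeta_j + lam * sum_j alpha_j^2 / n_j."""
    step = min(sizes) / (2.0 * lam)  # 1 / L_g
    for _ in range(steps):
        grad = [z + 2.0 * lam * al / n for z, al, n in zip(zeta, alpha, sizes)]
        alpha = project_simplex([al - step * g for al, g in zip(alpha, grad)])
    return alpha

def single_loop(a, c, sizes, lam, gamma, R, T_alpha):
    N = len(a)
    w = 0.0
    alphas = [[1.0 / N] * N for _ in range(N)]  # alpha_i^0 = uniform
    for _ in range(R):
        # Personalized models v_i would be updated here with alphas[i] (Algorithm 1 step).
        w = w - gamma * sum(local_grads(w, a, c)) / N   # global model update
        g = local_grads(w, a, c)                        # gradients at w^{r+1}
        for i in range(N):                              # alpha update for every client
            zeta_i = [(g[i] - gj) ** 2 for gj in g]
            alphas[i] = alpha_steps(alphas[i], zeta_i, sizes, lam, T_alpha)
    return w, alphas

w, alphas = single_loop(a=[1.0, 2.0, 1.0], c=[0.0, 1.0, 3.0], sizes=[50, 50, 50],
                        lam=20.0, gamma=0.5, R=60, T_alpha=20)
# w approaches w* = 1.25; alphas[0] approaches [0.8515625, 0.1484375, 0.0]
```

Because $\alpha_i$ is recomputed at every intermediate $w^{r+1}$ rather than once at the end, its accuracy improves as the global model converges. This mirrors the motivation for the single-loop design.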
We then update the personalized models as in Algorithm 1. Unlike the two-stage method, where the mixing parameters are computed at the final global model, here they are updated adaptively based on the intermediate global models.

Theorem 3. Let Assumptions 1-4 hold, and let $\alpha_i^*$ be the solution of (2). Suppose Algorithm 2 is run with $\eta = \Theta\big(\frac{\log(NKR^3)}{\mu R}\big)$ and $\gamma = \Theta\big(\frac{\log(NKR^3)}{\mu R}\big)$. Then it outputs $\hat v_i$, $i\in[N]$, such that, with probability at least $1 - p$,
$$\mathbb{E}\big[\Phi(\alpha_i^*, \hat v_i) - \Phi(\alpha_i^*, v_i^*)\big] \le O(\cdots),$$
with an explicit bound in terms of $L$, $D$, $G$, $N$, $\log(1/p)$, $\delta$, $\kappa$, $\kappa_g$, $\zeta_i(w^*)$, $R$, $M$, and $T_\alpha$. Here $\kappa = L/\mu$, and $\kappa_g$ is the corresponding condition-number constant (involving $n_{\max}$ and $\mu$). The expectation is taken over the randomness of the stochastic samples in Algorithm 2.

Consequently, to guarantee $\mathbb{E}[\Phi(\alpha_i^*, \hat v_i) - \Phi(\alpha_i^*, v_i^*)] \le \epsilon$, we choose
$$M = O\Big(\frac{L^2\kappa^2\kappa_g^2\kappa_\Phi^2\zeta_i(w^*)\delta^2}{\mu^2\epsilon}\Big), \qquad R = O\Big(\max\Big\{\frac{L\delta^2}{\mu\epsilon}, \frac{\kappa_\Phi^2\kappa^2\kappa_g^2L^3\zeta_i(w^*)DG}{\epsilon}\Big\}\Big), \qquad T_\alpha = O\big(\kappa_g\log(LR^2\cdots)\big).$$

Algorithm 2: Single-Loop PERM
Input: clients $1, \dots, N$; number of local steps $K$; number of epochs $R$; initial mixing parameters $\alpha_1^0 = \dots = \alpha_N^0 = [1/N, \dots, 1/N]$.
for epoch $r = 0, \dots, R-1$ do
  Server generates a permutation $\sigma_r: [N] \to [N]$.
  parallel for client $i = 1, \dots, N$ do
    Client $i$ sets the initial model $v_i^{r,0} = v_i^r$.
    for $j = 0, \dots, N-1$ do
      Set index $i_j = \sigma_r((i + j) \bmod N)$.
      Server sends $v_i^{r,j}$ to client $i_j$.
      $v_i^{r,j+1} = \text{SGD-Update}(v_i^{r,j}, \eta, i_j, K, \alpha_i^r)$ (personalized model update)
    Client $i$ does the projection $v_i^{r+1} = \mathcal{P}_{\mathcal{W}}(v_i^{r,N})$.
  $w^{r+1} = \mathcal{P}_{\mathcal{W}}\big(w^r - \gamma\frac{1}{N}\sum_{i=1}^N\frac{1}{M}\sum_{j=1}^M \nabla f_i(w^r; \xi_{i,j}^r)\big)$ (global model update)
  Compute $\alpha_i^{r+1}$ by running $T_\alpha$ steps of GD on $g_i(w^{r+1}, \alpha)$ ($\alpha$ update)
Output: $\hat v_i = \mathcal{P}_{\mathcal{W}}\big(v_i^R - \frac{1}{L}\nabla_v\Phi(\alpha_i^R, v_i^R)\big)$, $\hat\alpha_i = \alpha_i^R$, $i\in[N]$.

SGD-Update$(v, \eta, j, K, \alpha)$:
  Initialize $v^0 = v$.
  for $t = 1, \dots, K$ do
    $v^t = v^{t-1} - \eta\,\alpha(j)\,N\,\nabla f_j(v^{t-1}; \xi_j^{t-1})$
  Output: $v^K$.

Compared to Theorem 2, we obtain a slightly worse rate, since updating the global model $w$ requires a large mini-batch. However, the single-loop algorithm has two advantages.
First, as we mentioned in the previous subsection, we have the freedom to optimize ˆαi to arbitrary accuracy, while in double loop algorithm (Algorithm A1 + Algorithm 1), once we get ˆαi, we do not have the chance to further refine it. Second, in practice, a single-loop algorithm is often easier to implement and can make better use of caches by operating on data sequentially, leading to improved performance, especially on modern processors with complex memory hierarchies. 3.3 Extension to heterogeneous model setting In the homogeneous model setting, we assumed a shared model space W for clients and the server. However, in real-world FL applications, devices have diverse resources and can only train models that match their capacities. We demonstrate that the PERM paradigm can be extended to support learning in model-heterogeneous settings, where different models with varying capacities are used by the server and clients. Focusing on learning the global optimal model to estimate pairwise statistical discrepancies, we note that by utilizing partial training methods [28], where at each communication round a sub-model with a size proportional to resources of each client is sampled from the server s global model (extracted either random, static, or rolling) and is transmitted to be updated locally. Upon receiving updated sub-models, the server can simply aggregate (average) heterogeneous submodel updates sent from the clients to update the global model. We can consider the complexity of 0 20 40 60 80 100 Number of Communications Personal Validation Accuracy PERM Localized Fed Avg Per Fed Avg p Fed ME (a) Personalized Accuracy 0 20 40 60 80 100 Number of Communications Personal Validation Loss PERM Localized Fed Avg Per Fed Avg p Fed ME (b) Personalized Loss Figure 1: Comparative analysis of personalization methods, including our single-loop PERM algorithm, localized Fed Avg, per Fed Avg, and p Fed ME, with synthetic data. 
The disparity in personalized accuracy and loss highlights PERM's capability to leverage relevant client correlations.

models used by clients when estimating the mixing parameters by solving a modified version of (2):
$$\min_{\alpha\in\Delta_N}\ \sum_{j=1}^N \alpha(j)\,\frac{\mathrm{VC}(\mathcal H_j)}{n_j} + \sum_{j=1}^N \alpha(j)\,\big\|\nabla f_i(m_i\odot w^*) - \nabla f_j(m_j\odot w^*)\big\|^2 + \lambda\sum_{j=1}^N \frac{\alpha(j)^2}{n_j},$$
where we simply upper bounded the Rademacher complexity w.r.t. each data source in (GEN) by its VC dimension [57]. Here $m_i$ is the masking operator that extracts a sub-model of the global model to compute local gradients at client i based on its available resources. In this way, the mixing parameters account for the complexity of the underlying models, since different sub-models of the global model (i.e., $m_i\odot w^*$ versus $m_j\odot w^*$) are used to compute the drift between each pair of gradients at the optimal solution. Regarding training personalized models with heterogeneous local models: since we solve a distinct aggregated empirical loss for each client by interleaving permutations and shuffling models, we can use different model spaces $\mathcal W_i$, $i = 1, \ldots, N$, for different clients to meet their available resources, using the aforementioned partial training strategies.

4 Experimental Results

In this section we benchmark PERM on synthetic data with 50 clients, where it outperforms the competing personalization methods (Figure 1). We then evaluate it on the CIFAR10 dataset with a 2-layer convolutional neural network, where PERM, despite a warm-up phase, converges better than the baselines (Figure 2). Additional experiments are reported in the appendix. Across all datasets we consider, PERM consistently achieved the best personalized performance among the compared methods.

Experiment on synthetic data. To demonstrate the effectiveness of our single-loop PERM algorithm compared to existing personalization methods, we conducted an experiment using synthetic data generated according to the following specifications.
We consider a scenario with a total of N clients, where we draw samples from the distribution $\mathcal N(\mu_1, \Sigma_i)$ for half of the clients, $i \in [1, N/2]$, and from $\mathcal N(\mu_2, \Sigma_i)$ for the remaining clients, $i \in (N/2, N]$. Following [58], we use the same diagonal covariance for all samples, with $(\Sigma_i)_{k,k} = k^{-1.2}$. We then draw a labeling model $w \sim \mathcal N(\mu_w, \Sigma_w)$. Given a data sample $x \in \mathbb R^d$, clients $1, \ldots, N/2$ assign labels $y = \mathrm{sign}(w^\top x)$, while clients $N/2 + 1, \ldots, N$ assign labels $y = \mathrm{sign}(-w^\top x)$. For this experiment, we set $\mu_1 = 0.2$, $\mu_2 = -0.2$, and $\mu_w = 0.1$. The data dimension is $d = 60$, and there are 2 output classes. We have a total of 50 clients, each generating 500 samples as described above. We train a logistic regression model on each client's data. We compare PERM against other prominent personalized approaches, including the fine-tuned model of FedAvg [14] (referred to as localized FedAvg), Per-FedAvg [9], and pFedMe [7]. The results in Figure 1 show that PERM efficiently learns personalized models for individual clients. In contrast, competing methods that rely on globally trained models fall short of PERM in this highly heterogeneous scenario, in both personalized accuracy and loss. This illustrates PERM's ability to learn from relevant clients.

Figure 2 panels: (a) Personalized Accuracy and (b) Personalized Loss, showing personal validation accuracy and loss versus the number of communications for PERM, Localized FedAvg, Per-FedAvg, and pFedMe.
Figure 2: Comparative analysis of our single-loop PERM algorithm, localized FedAvg, pFedMe, and Per-FedAvg on the CIFAR10 dataset with a 2-layer CNN model.
Each client has access to only 2 classes of data. PERM rapidly catches up after 10 warm-up rounds without personalization.

Figure 3: Runtime of different algorithms in a limited environment. We compare PERM (single loop), Per-FedAvg, FedAvg, and pFedMe. PERM has minimal overhead over FedAvg and is comparable to other personalization methods.

Experiment on CIFAR10 dataset. We extend our experiments to the CIFAR10 dataset using a 2-layer convolutional neural network. In this test, 50 clients participate, each limited to data from just 2 classes, resulting in a highly heterogeneous data distribution. We benchmark our algorithm against Per-FedAvg, pFedMe, and localized FedAvg. As illustrated in Figure 2, PERM converges better than the other personalized strategies. Note that PERM's initial personalized validation accuracy is significantly lower than that of approaches like Per-FedAvg and pFedMe. This is because we run 10 communication rounds as a warm-up phase before starting personalization, whereas the other methods personalize from the outset.

Computational overhead. To assess the computational efficiency of PERM, we compare the wall-clock time needed to complete one communication round for PERM and the other methods. Each method performs 20 local steps along with its own personalization-specific computation. As depicted in Figure 3, we compare the runtime of PERM (single loop) against Per-FedAvg, FedAvg, and pFedMe. PERM incurs only a small computational overhead: its runtime is slightly higher than that of FedAvg due to the cost of estimating the mixing parameters.
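The two-group synthetic setup used in the synthetic-data experiment can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' generator: we read the extracted values as $\mu_1 = 0.2$, $\mu_2 = -0.2$ and the second group's rule as $y = \mathrm{sign}(-w^\top x)$, we take each scalar mean to apply to every coordinate, we assume the unspecified labeling covariance $\Sigma_w$ is the identity, and we map sign ties to +1. The function name `make_synthetic_clients` is hypothetical.

```python
import random


def make_synthetic_clients(num_clients=50, samples_per_client=500, dim=60,
                           mu1=0.2, mu2=-0.2, mu_w=0.1, seed=0):
    """Hypothetical generator for the two-group synthetic benchmark.

    Clients 0..N/2-1 draw x ~ N(mu1, Sigma) and label y = sign(w^T x);
    the remaining clients draw x ~ N(mu2, Sigma) and label y = sign(-w^T x).
    Each scalar mean is applied to every coordinate (assumption).
    """
    rng = random.Random(seed)
    # Diagonal covariance Sigma_{k,k} = k^{-1.2}, k = 1..dim, so std = k^{-0.6}.
    stds = [k ** -0.6 for k in range(1, dim + 1)]
    # Labeling vector w ~ N(mu_w, I); the identity covariance is an assumption.
    w = [rng.gauss(mu_w, 1.0) for _ in range(dim)]
    clients = []
    for i in range(num_clients):
        first_group = i < num_clients // 2
        mean = mu1 if first_group else mu2
        flip = 1.0 if first_group else -1.0  # second group uses sign(-w^T x)
        xs, ys = [], []
        for _ in range(samples_per_client):
            x = [rng.gauss(mean, s) for s in stds]
            score = flip * sum(wk * xk for wk, xk in zip(w, x))
            ys.append(1 if score >= 0 else -1)  # ties mapped to +1 (assumption)
            xs.append(x)
        clients.append((xs, ys))
    return w, clients


w, clients = make_synthetic_clients()
print(len(clients), len(clients[0][0]), len(clients[0][0][0]))  # 50 500 60
```

Because the two groups share one labeling vector but flip its sign, a single global classifier cannot fit both halves, which is the kind of heterogeneity that personalized aggregation is meant to exploit.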
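The model-shuffling schedule of Algorithm 2 can be illustrated on a toy scalar problem. This hypothetical sketch (all function names are ours) assumes quadratic local losses $f_j(v) = \frac12(v - c_j)^2$, exact rather than stochastic gradients, no projection $\mathcal P_{\mathcal W}$, and fixed mixing weights $\alpha_i$; it covers only the personalized-model updates, not the global-model or α updates. It uses the index rule $i_j = \sigma_r((i + j) \bmod N)$ with clients and steps indexed $0, \ldots, N-1$. Each personalized model visits every client once per epoch and takes K steps of size $\eta\alpha_i(j)N$ there, so for a small step size it settles near the $\alpha_i$-weighted minimizer $\sum_j \alpha_i(j)c_j$ of $\Phi(\alpha_i, \cdot)$.

```python
import random


def shuffling_schedule(num_clients, rng):
    """Return schedule[j][i]: the client hosting personalized model i at step j.

    Follows the index rule i_j = sigma((i + j) mod N), 0-indexed.
    """
    sigma = list(range(num_clients))
    rng.shuffle(sigma)  # server-side permutation sigma_r of the clients
    return [[sigma[(i + j) % num_clients] for i in range(num_clients)]
            for j in range(num_clients)]


def sgd_update(v, eta, j, local_steps, alpha_i, centers):
    """K local steps on client j with step size eta * alpha_i[j] * N.

    Toy assumption: f_j(v) = 0.5 * (v - centers[j])**2 with exact gradients.
    """
    n = len(alpha_i)
    for _ in range(local_steps):
        grad = v - centers[j]
        v -= eta * alpha_i[j] * n * grad
    return v


def run_perm_toy(alphas, centers, eta=0.002, local_steps=1, epochs=3000, seed=0):
    """Shuffled personalized updates for all N models (no projection)."""
    n = len(centers)
    rng = random.Random(seed)
    models = [0.0] * n
    for _ in range(epochs):
        schedule = shuffling_schedule(n, rng)  # fresh permutation each epoch
        # Models never interact, so updating them one after another is
        # equivalent to the parallel loop over clients in Algorithm 2.
        for i in range(n):
            for j in range(n):
                models[i] = sgd_update(models[i], eta, schedule[j][i],
                                       local_steps, alphas[i], centers)
    return models


alphas = [[0.4, 0.3, 0.2, 0.1], [0.25] * 4, [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
centers = [0.0, 1.0, 2.0, 3.0]
models = run_perm_toy(alphas, centers)
targets = [sum(a * c for a, c in zip(alpha, centers)) for alpha in alphas]
print([round(v, 2) for v in models], [round(t, 2) for t in targets])
```

In this sketch every inner step j assigns exactly one model to each client, so each client performs the same number of local updates per epoch as it would when training a single model; only the visiting order and the per-client weight $\alpha_i(j)N$ differ between models.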
5 Discussion & Conclusion

This paper introduces a new data- and system-aware paradigm for learning from multiple heterogeneous data sources that achieves optimal statistical accuracy across all data distributions without imposing stringent constraints on the computational resources shared by participating devices. The proposed PERM scheme, though simple, enables each client to efficiently learn a personalized model by learning who to learn with, i.e., by personalizing the aggregation of data sources through an efficient empirical statistical discrepancy estimation module. To efficiently solve all aggregated personalized losses, we propose a model-shuffling idea that optimizes all losses in parallel. PERM can also be employed in other learning settings with multiple data sources, such as domain adaptation and multi-task learning, to attain optimal statistical accuracy. Finally, we comment on the scalability of PERM. Thanks to shuffling, the compute burden on clients and server is roughly the same as in existing methods, except for the extra overhead of estimating the mixing parameters, which amounts to running FedAvg in the two-stage approach and one extra communication in the interleaved approach. The only hurdle is the memory required at the server to maintain the mixing parameters, which grows quadratically with the number of clients; this can be alleviated by clustering devices, which we leave as future work.

Acknowledgement

This work was partially supported by NSF CAREER Award #2239374 and NSF CNS Award #1956276.

[1] Peter Kairouz, H. Brendan McMahan, Brendan Avent, Aurélien Bellet, Mehdi Bennis, Arjun Nitin Bhagoji, Keith Bonawitz, Zachary Charles, Graham Cormode, Rachel Cummings, Rafael G. L. D'Oliveira, Salim El Rouayheb, David Evans, Josh Gardner, Zachary Garrett, Adrià Gascón, Badih Ghazi, Phillip B.
Gibbons, Marco Gruteser, Zaid Harchaoui, Chaoyang He, Lie He, Zhouyuan Huo, Ben Hutchinson, Justin Hsu, Martin Jaggi, Tara Javidi, Gauri Joshi, Mikhail Khodak, Jakub Konečný, Aleksandra Korolova, Farinaz Koushanfar, Sanmi Koyejo, Tancrède Lepoint, Yang Liu, Prateek Mittal, Mehryar Mohri, Richard Nock, Ayfer Özgür, Rasmus Pagh, Mariana Raykova, Hang Qi, Daniel Ramage, Ramesh Raskar, Dawn Song, Weikang Song, Sebastian U. Stich, Ziteng Sun, Ananda Theertha Suresh, Florian Tramèr, Praneeth Vepakomma, Jianyu Wang, Li Xiong, Zheng Xu, Qiang Yang, Felix X. Yu, Han Yu, and Sen Zhao. Advances and open problems in federated learning. Foundations and Trends in Machine Learning, 2021.
[2] Yishay Mansour, Mehryar Mohri, Jae Ro, and Ananda Theertha Suresh. Three approaches for personalization with applications to federated learning. arXiv preprint arXiv:2002.10619, 2020.
[3] Chengxi Li, Gang Li, and Pramod K Varshney. Federated learning with soft clustering. IEEE Internet of Things Journal, 9(10):7773–7782, 2021.
[4] Avishek Ghosh, Jichan Chung, Dong Yin, and Kannan Ramchandran. An efficient framework for clustered federated learning. Advances in Neural Information Processing Systems, 33:19586–19597, 2020.
[5] Jie Ma, Guodong Long, Tianyi Zhou, Jing Jiang, and Chengqi Zhang. On the convergence of clustered federated learning. arXiv preprint arXiv:2202.06187, 2022.
[6] Hubert Eichner, Tomer Koren, Brendan McMahan, Nathan Srebro, and Kunal Talwar. Semi-cyclic stochastic gradient descent. In Proceedings of the 36th International Conference on Machine Learning, PMLR, volume 97, 2019.
[7] Canh T Dinh, Nguyen Tran, and Tuan Dung Nguyen. Personalized federated learning with Moreau envelopes. Advances in Neural Information Processing Systems, 33, 2020.
[8] Yutao Huang, Lingyang Chu, Zirui Zhou, Lanjun Wang, Jiangchuan Liu, Jian Pei, and Yong Zhang. Personalized federated learning: An attentive collaboration approach. arXiv preprint arXiv:2007.03797, 2020.
[9] Alireza Fallah, Aryan Mokhtari, and Asuman Ozdaglar. Personalized federated learning: A meta-learning approach. Advances in Neural Information Processing Systems, 2020.
[10] Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, and Ameet S Talwalkar. Federated multi-task learning. In Advances in Neural Information Processing Systems, pages 4424–4434, 2017.
[11] Yuyang Deng, Mohammad Mahdi Kamani, and Mehrdad Mahdavi. Adaptive personalized federated learning. arXiv preprint arXiv:2003.13461, 2020.
[12] Filip Hanzely, Boxin Zhao, and Mladen Kolar. Personalized federated learning: A unified framework and universal optimization techniques. arXiv preprint arXiv:2102.09743, 2021.
[13] Tony Lancaster. The incidental parameter problem since 1948. Journal of Econometrics, 95(2):391–413, 2000.
[14] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282. PMLR, 2017.
[15] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, and Ananda Theertha Suresh. Scaffold: Stochastic controlled averaging for federated learning. In International Conference on Machine Learning, pages 5132–5143. PMLR, 2020.
[16] Jenny Hamer, Mehryar Mohri, and Ananda Theertha Suresh. Fedboost: A communication-efficient algorithm for federated learning. In International Conference on Machine Learning, pages 3973–3983. PMLR, 2020.
[17] Farzin Haddadpour, Mohammad Mahdi Kamani, Aryan Mokhtari, and Mehrdad Mahdavi. Federated learning with compression: Unified analysis and sharp guarantees. In International Conference on Artificial Intelligence and Statistics, pages 2350–2358. PMLR, 2021.
[18] Yan Sun, Li Shen, Tiansheng Huang, Liang Ding, and Dacheng Tao. Fedspeed: Larger local interval, less communication round, and higher generalization accuracy.
In The Eleventh International Conference on Learning Representations, 2023.
[19] Haibo Yang, Xin Zhang, Prashant Khanduri, and Jia Liu. Anarchic federated learning. In International Conference on Machine Learning, pages 25331–25363. PMLR, 2022.
[20] Shiqiang Wang and Mingyue Ji. A unified analysis of federated learning with arbitrary client participation. Advances in Neural Information Processing Systems, 35:19124–19137, 2022.
[21] Chuhan Wu, Fangzhao Wu, Lingjuan Lyu, Yongfeng Huang, and Xing Xie. Communication-efficient federated learning via knowledge distillation. Nature Communications, 13(1):1–8, 2022.
[22] Chaoyang He, Murali Annavaram, and Salman Avestimehr. Group knowledge transfer: Federated learning of large cnns at the edge. Advances in Neural Information Processing Systems, 33:14068–14080, 2020.
[23] Tao Lin, Lingjing Kong, Sebastian U Stich, and Martin Jaggi. Ensemble distillation for robust model fusion in federated learning. Advances in Neural Information Processing Systems, 33:2351–2363, 2020.
[24] Sohei Itahara, Takayuki Nishio, Yusuke Koda, Masahiro Morikura, and Koji Yamamoto. Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-iid private data. IEEE Transactions on Mobile Computing, 22(1):191–205, 2021.
[25] Enmao Diao, Jie Ding, and Vahid Tarokh. Heterofl: Computation and communication efficient federated learning for heterogeneous clients. arXiv preprint arXiv:2010.01264, 2020.
[26] Samuel Horvath, Stefanos Laskaridis, Mario Almeida, Ilias Leontiadis, Stylianos Venieris, and Nicholas Lane. Fjord: Fair and accurate federated learning under heterogeneous targets with ordered dropout. Advances in Neural Information Processing Systems, 34:12876–12889, 2021.
[27] Sebastian Caldas, Jakub Konečný, H Brendan McMahan, and Ameet Talwalkar. Expanding the reach of federated learning by reducing client resource requirements. arXiv preprint arXiv:1812.07210, 2018.
[28] Samiul Alam, Luyang Liu, Ming Yan, and Mi Zhang. Fedrolex: Model-heterogeneous federated learning with rolling sub-model extraction. Advances in Neural Information Processing Systems, 35:29677–29690, 2022.
[29] Shai Ben-David, John Blitzer, Koby Crammer, Alex Kulesza, Fernando Pereira, and Jennifer Wortman Vaughan. A theory of learning from different domains. Machine Learning, 79:151–175, 2010.
[30] Yishay Mansour and Mariano Schain. Robust domain adaptation. Annals of Mathematics and Artificial Intelligence, 71(4):365–380, 2014.
[31] Nikola Konstantinov and Christoph Lampert. Robust learning from untrusted sources. In International Conference on Machine Learning, pages 3488–3498. PMLR, 2019.
[32] Koby Crammer, Michael Kearns, and Jennifer Wortman. Learning from multiple sources. Journal of Machine Learning Research, 9(8), 2008.
[33] Mathieu Even, Laurent Massoulié, and Kevin Scaman. On sample optimality in personalized collaborative and federated learning. In NeurIPS 2022 - 36th Conference on Neural Information Processing Systems, 2022.
[34] Jakub Konečný, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016.
[35] Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In International Conference on Machine Learning, pages 4615–4625. PMLR, 2019.
[36] Yuyang Deng, Mohammad Mahdi Kamani, and Mehrdad Mahdavi. Distributionally robust federated averaging. Advances in Neural Information Processing Systems, 33:15111–15122, 2020.
[37] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. Feddane: A federated newton-type method. In 2019 53rd Asilomar Conference on Signals, Systems, and Computers, pages 1227–1231. IEEE, 2019.
[38] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank J Reddi, Sebastian U Stich, and Ananda Theertha Suresh.
Scaffold: Stochastic controlled averaging for on-device federated learning. International Conference on Machine Learning, 119:5132–5143, 2020.
[39] Farzin Haddadpour and Mehrdad Mahdavi. On the convergence of local descent methods in federated learning. arXiv preprint arXiv:1910.14425, 2019.
[40] Tao Yu, Eugene Bagdasaryan, and Vitaly Shmatikov. Salvaging federated learning by local adaptation. arXiv preprint arXiv:2002.04758, 2020.
[41] Bharath K Sriperumbudur, Kenji Fukumizu, Arthur Gretton, Bernhard Schölkopf, and Gert RG Lanckriet. On integral probability metrics, φ-divergences and binary classification. arXiv preprint arXiv:0901.2698, 2009.
[42] Steve Hanneke and Samory Kpotufe. A no-free-lunch theorem for multitask learning. The Annals of Statistics, 50(6):3119–3143, 2022.
[43] Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge University Press, 2014.
[44] Zi Xu, Huiling Zhang, Yang Xu, and Guanghui Lan. A unified single-loop alternating gradient projection algorithm for nonconvex-concave and convex-nonconcave minimax problems. Mathematical Programming, pages 1–72, 2023.
[45] Yatin Dandi, Luis Barba, and Martin Jaggi. Implicit gradient alignment in distributed and federated learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 36, pages 6454–6462, 2022.
[46] Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. Federated learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018.
[47] Jianyu Wang, Rudrajit Das, Gauri Joshi, Satyen Kale, Zheng Xu, and Tong Zhang. On the unreasonable effectiveness of federated averaging with heterogeneous data. arXiv preprint arXiv:2206.04723, 2022.
[48] Sebastian U Stich. Local sgd converges fast and communicates little. In International Conference on Learning Representations, 2018.
[49] Blake E Woodworth, Jialei Wang, Adam Smith, Brendan McMahan, and Nati Srebro.
Graph oracle models, lower bounds, and gaps for parallel stochastic optimization. In Advances in Neural Information Processing Systems, pages 8496–8506, 2018.
[50] Blake E Woodworth, Kumar Kshitij Patel, and Nati Srebro. Minibatch vs local sgd for heterogeneous distributed learning. Advances in Neural Information Processing Systems, 33:6281–6292, 2020.
[51] Konstantin Mishchenko, Grigory Malinovsky, Sebastian Stich, and Peter Richtárik. Proxskip: Yes! local gradient steps provably lead to communication acceleration! finally! In International Conference on Machine Learning, pages 15750–15769. PMLR, 2022.
[52] Honglin Yuan and Tengyu Ma. Federated accelerated stochastic gradient descent. Advances in Neural Information Processing Systems, 33:5332–5344, 2020.
[53] Farzin Haddadpour, Mohammad Mahdi Kamani, Mehrdad Mahdavi, and Viveck Cadambe. Local sgd with periodic averaging: Tighter analysis and adaptive synchronization. In Advances in Neural Information Processing Systems, pages 11080–11092, 2019.
[54] Eduard Gorbunov, Filip Hanzely, and Peter Richtárik. Local sgd: Unified theory and new efficient methods. In International Conference on Artificial Intelligence and Statistics, pages 3556–3564. PMLR, 2021.
[55] Kwangjun Ahn, Chulhee Yun, and Suvrit Sra. Sgd with shuffling: optimal rates without component convexity and large epoch requirements. Advances in Neural Information Processing Systems, 33:17526–17535, 2020.
[56] Yae Jee Cho, Pranay Sharma, Gauri Joshi, Zheng Xu, Satyen Kale, and Tong Zhang. On the convergence of federated averaging with cyclic client participation. arXiv preprint arXiv:2302.03109, 2023.
[57] Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. Foundations of machine learning. MIT Press, 2018.
[58] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127, 2018.
[59] Sebastian Caldas, Peter Wu, Tian Li, Jakub Konečný, H Brendan McMahan, Virginia Smith, and Ameet Talwalkar. Leaf: A benchmark for federated settings. arXiv preprint arXiv:1812.01097, 2018.
[60] Tianyi Lin, Chi Jin, and Michael I Jordan. On gradient descent ascent for nonconvex-concave minimax problems. arXiv preprint arXiv:1906.00331, 2019.
[61] Markus Schneider. Probability inequalities for kernel embeddings in sampling without replacement. In Artificial Intelligence and Statistics, pages 66–74. PMLR, 2016.

A Additional Experiments

In addition to the experiments on the synthetic and CIFAR10 datasets reported above, we conducted experiments on the EMNIST dataset, which highlight PERM's capability to derive better personalized models by exploiting inter-client data similarities. Our tests on the MNIST dataset provide further insight into how PERM's learned mixture weights adapt to both homogeneous and highly heterogeneous data.

Experiment on EMNIST dataset. In addition to the synthetic and CIFAR10 datasets discussed in the main body, we run experiments on the EMNIST dataset [59], which is naturally distributed in a federated setting. We choose 50 clients and use a 2-layer MLP model with 200 neurons in each layer. We compare PERM with the localized FedAvg model and Per-FedAvg [9]. As can be seen in Figure 4, PERM learns a better personalized model by attending to each client's data according to the similarity of the data distributions between clients. The learned values of α in Figure 5 show that clients learn from each other's data rather than focusing only on their own. This indicates that the data distribution among clients in this dataset is not highly heterogeneous. Note that, since we train on only a subset of EMNIST clients (50 clients for 100 communication rounds), the results are sub-optimal.
Nonetheless, the experiments are designed to compare the effectiveness of the different algorithms. In terms of performance, PERM consistently outperforms its peers on the benchmark datasets we consider.

Figure 4 panels: (a) Personalized Accuracy and (b) Personalized Loss, showing personal validation accuracy and loss versus the number of communications for PERM, Localized FedAvg, and Per-FedAvg.
Figure 4: Comparative analysis of personalization methods, including our single-loop PERM algorithm, localized FedAvg, and Per-FedAvg, on the EMNIST dataset. The disparity in personalized accuracy and loss highlights PERM's capability to leverage relevant client correlations.

(Figure 5: heat map of learned α values; both axes index the 50 clients.)
Figure 5: Heat map of the learned α values of the PERM algorithm on the EMNIST dataset with a 2-layer MLP model. The weights show that clients benefit mutually from one another's data, which also indicates that the data distribution in this dataset is not significantly heterogeneous.

The effectiveness of learned mixture weights. To show the effectiveness of the two-stage PERM algorithm, as well as the effect of data heterogeneity across clients on the learned weights α, we run this algorithm on the MNIST dataset. We use 50 clients and an MLP model similar to the one in the EMNIST experiment. We consider two cases: distributing the data randomly across clients (homogeneous) and allocating only 1 class per client (highly heterogeneous). As can be seen from Figure 6, when the data distribution is homogeneous, the learned values of α are diffused across clients.
However, when the data is highly heterogeneous, the learned α values are highly sparse, indicating that each client mostly learns from its own data and from a few other clients with partial distribution similarity. In other words, each client selectively leverages information from only a limited, well-chosen subset of other clients.

(Figure 6 panels: (a) homogeneous distribution and (b) highly heterogeneous distribution; heat maps of learned α values, where both axes index the 50 clients.)
Figure 6: Learned α values of the two-stage PERM algorithm on homogeneous and heterogeneous data distributions. We use the MNIST dataset across 50 clients with homogeneous and heterogeneous distributions.

B Proof of Two-Stage Algorithm

In this section we provide the convergence proof of the two-stage implementation of PERM (computing the mixing parameters, followed by learning personalized models via model shuffling, using a permutation-based variant of distributed SGD with periodic communication).

B.1 Technical Lemmas

Lemma 1. Define $v^*(\alpha) := \arg\min_{v\in\mathcal W}\Phi(\alpha, v)$, and assume $\Phi(\alpha, \cdot)$ is µ-strongly convex and $\nabla_v\Phi(\alpha, v)$ is L-Lipschitz in α. Then $v^*(\cdot)$ is κ-Lipschitz, where $\kappa = L/\mu$.

Proof. The proof is similar to Lin et al.'s result on minimax objectives [60].
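First, by the optimality conditions we have, for all $v \in \mathcal W$:
$$\langle v - v^*(\alpha), \nabla_2\Phi(\alpha, v^*(\alpha))\rangle \ge 0,\qquad \langle v - v^*(\alpha'), \nabla_2\Phi(\alpha', v^*(\alpha'))\rangle \ge 0.$$
Substituting $v$ with $v^*(\alpha')$ and $v^*(\alpha)$ in the first and second inequalities, respectively, yields
$$\langle v^*(\alpha') - v^*(\alpha), \nabla_2\Phi(\alpha, v^*(\alpha))\rangle \ge 0,\qquad \langle v^*(\alpha) - v^*(\alpha'), \nabla_2\Phi(\alpha', v^*(\alpha'))\rangle \ge 0.$$
Adding up the above two inequalities yields
$$\langle v^*(\alpha') - v^*(\alpha), \nabla_2\Phi(\alpha, v^*(\alpha)) - \nabla_2\Phi(\alpha', v^*(\alpha'))\rangle \ge 0. \qquad (4)$$
Since $\Phi(\alpha, \cdot)$ is µ-strongly convex, we have
$$\langle v^*(\alpha') - v^*(\alpha), \nabla_2\Phi(\alpha, v^*(\alpha')) - \nabla_2\Phi(\alpha, v^*(\alpha))\rangle \ge \mu\|v^*(\alpha') - v^*(\alpha)\|^2. \qquad (5)$$
Adding up (4) and (5) yields
$$\langle v^*(\alpha') - v^*(\alpha), \nabla_2\Phi(\alpha, v^*(\alpha')) - \nabla_2\Phi(\alpha', v^*(\alpha'))\rangle \ge \mu\|v^*(\alpha') - v^*(\alpha)\|^2.$$
Finally, using the Cauchy-Schwarz inequality and the L-Lipschitzness of $\nabla_2\Phi(\cdot, v)$ in α concludes the proof:
$$L\|v^*(\alpha') - v^*(\alpha)\|\,\|\alpha - \alpha'\| \ge \mu\|v^*(\alpha') - v^*(\alpha)\|^2 \;\Longrightarrow\; \|v^*(\alpha') - v^*(\alpha)\| \le \kappa\|\alpha - \alpha'\|.$$

Lemma 2 (Optimality Gap). Let $\Phi(\alpha, v)$ be defined in (3), and let $\hat v = \mathcal P_{\mathcal W}\big(v - \frac{1}{L}\nabla_v\Phi(\hat\alpha, v)\big)$. If each $f_i$ is L-smooth, µ-strongly convex, and has gradient bounded by G, then
$$\Phi(\alpha^*, \hat v) - \Phi(\alpha^*, v^*) \le 2L\|v - v^*(\hat\alpha)\|^2 + \Big(2\kappa_\Phi^2L + \frac{4NG^2}{\mu}\Big)\|\hat\alpha - \alpha^*\|^2,\qquad v^* = \arg\min_{v\in\mathcal W}\Phi(\alpha^*, v).$$

Proof. First we show that $\nabla_v\Phi(\alpha, v)$ is NG-Lipschitz in α:
$$\|\nabla_v\Phi(\alpha, v) - \nabla_v\Phi(\alpha', v)\| = \Big\|\sum_{j=1}^N \alpha(j)\nabla f_j(v) - \sum_{j=1}^N \alpha'(j)\nabla f_j(v)\Big\| \le G\|\alpha - \alpha'\|_1 \le \sqrt N G\|\alpha - \alpha'\| \le NG\|\alpha - \alpha'\|.$$
Hence, by Lemma 1, $v^*(\alpha)$ is $\kappa_\Phi$-Lipschitz with $\kappa_\Phi := \frac{NG}{\mu}$. By the property of the projection, for any $v' \in \mathcal W$ we have
$$0 \le \langle v' - \hat v, L(\hat v - v) + \nabla_v\Phi(\hat\alpha, v)\rangle = \underbrace{\langle v' - \hat v, L(\hat v - v) + \nabla_v\Phi(\alpha^*, v)\rangle}_{T_1} + \underbrace{\langle v' - \hat v, \nabla_v\Phi(\hat\alpha, v) - \nabla_v\Phi(\alpha^*, v)\rangle}_{T_2}.$$
For $T_1$, we note that
$$T_1 = L\langle v' - v, \hat v - v\rangle - L\|v - \hat v\|^2 + \langle v' - \hat v, \nabla_v\Phi(\alpha^*, v)\rangle \le \frac{3L}{4}\|v' - v\|^2 + \frac{L}{3}\|\hat v - v\|^2 - L\|v - \hat v\|^2 + \langle v' - \hat v, \nabla_v\Phi(\alpha^*, v)\rangle,$$
where in the last step we used Young's inequality. To bound the last inner product, we apply the L-smoothness and µ-strong convexity of $\Phi(\alpha^*, \cdot)$:
$$\langle v' - \hat v, \nabla_v\Phi(\alpha^*, v)\rangle = \langle v' - v, \nabla_v\Phi(\alpha^*, v)\rangle + \langle v - \hat v, \nabla_v\Phi(\alpha^*, v)\rangle \le \Phi(\alpha^*, v') - \Phi(\alpha^*, v) - \frac{\mu}{2}\|v' - v\|^2 + \Phi(\alpha^*, v) - \Phi(\alpha^*, \hat v) + \frac{L}{2}\|v - \hat v\|^2.$$
Putting the above bound back and dropping the nonpositive $-\frac{\mu}{2}\|v' - v\|^2$ term yields
$$T_1 \le \Phi(\alpha^*, v') - \Phi(\alpha^*, \hat v) + \frac{3L}{4}\|v' - v\|^2 - \frac{L}{6}\|v - \hat v\|^2.$$
Now we switch to bounding $T_2$. Applying the Cauchy-Schwarz and Young inequalities yields
$$T_2 = \langle v' - v, \nabla_v\Phi(\hat\alpha, v) - \nabla_v\Phi(\alpha^*, v)\rangle + \langle v - \hat v, \nabla_v\Phi(\hat\alpha, v) - \nabla_v\Phi(\alpha^*, v)\rangle \le \frac{L}{4}\|v' - v\|^2 + \frac{L}{6}\|v - \hat v\|^2 + \frac{4}{L}\|\nabla_v\Phi(\hat\alpha, v) - \nabla_v\Phi(\alpha^*, v)\|^2 \le \frac{L}{4}\|v' - v\|^2 + \frac{L}{6}\|v - \hat v\|^2 + \frac{4NG^2}{L}\|\hat\alpha - \alpha^*\|^2,$$
where in the last step we used the $\sqrt N G$-Lipschitzness of $\nabla_v\Phi(\cdot, v)$ shown above. Putting the pieces together yields
$$0 \le \Phi(\alpha^*, v') - \Phi(\alpha^*, \hat v) + L\|v' - v\|^2 + \frac{4NG^2}{L}\|\hat\alpha - \alpha^*\|^2.$$
Re-arranging terms and setting $v' = v^* = v^*(\alpha^*) = \arg\min_{v\in\mathcal W}\Phi(\alpha^*, v)$ yields
$$\Phi(\alpha^*, \hat v) - \Phi(\alpha^*, v^*) \le L\|v - v^*\|^2 + \frac{4NG^2}{L}\|\hat\alpha - \alpha^*\|^2.$$
Finally, by the $\kappa_\Phi$-Lipschitzness of $v^*(\cdot)$ (Lemma 1), it follows that
$$L\|v - v^*(\alpha^*)\|^2 \le 2L\|v - v^*(\hat\alpha)\|^2 + 2L\|v^*(\hat\alpha) - v^*(\alpha^*)\|^2 \le 2L\|v - v^*(\hat\alpha)\|^2 + 2\kappa_\Phi^2L\|\hat\alpha - \alpha^*\|^2,$$
which, together with $\frac{4NG^2}{L} \le \frac{4NG^2}{\mu}$, gives the desired result.

Algorithm A1: Discrepancy Estimation at Optimum
Input: number of clients N, number of local steps K, number of communication rounds R
for r = 0, ..., R−1 do
    parallel for client i = 0, ..., N−1 do
        Client i initializes its model $w_i^{r,0} = w^r$.
        for t = 0, ..., K−1 do
            $w_i^{r,t+1} = w_i^{r,t} - \gamma\nabla f_i(w_i^{r,t}; \xi_i^{r,t})$, where $\xi_i^{r,t}$ is a mini-batch sampled from $S_i$.
        Client i sends $w_i^{r,K}$ to the server.
    Server computes $w^{r+1} = \mathcal P_{\mathcal W}\big(\frac{1}{N}\sum_{i=1}^N w_i^{r,K}\big)$.
    Server broadcasts $w^{r+1}$ to all clients.
Server computes $\hat\alpha_i$, $i = 1, 2, \ldots, N$, by running $T_\alpha$ steps of GD on $g_i(w^R, \alpha)$.
Output: $\hat\alpha_1, \ldots, \hat\alpha_N$.

B.2 Proof of Theorem 1

In this section we prove Theorem 1. To this end, we need to show that the mixing parameters computed by first learning the global model and then solving the optimization problem in objective (2) (as in Algorithm A1) converge to the optimal values. Note that Algorithm A1 does not solve $g_i(w^*, \alpha)$ directly; it optimizes $g_i(w^R, \alpha)$ over α for $T_\alpha$ iterations of GD. Hence, we first need to show that optimizing this surrogate function still guarantees the convergence of the algorithm's output $\hat\alpha$ to $\alpha^*$, by deriving a property of the objective in (2). This property is captured by the following lemma.

Lemma 3. Let $g(w, \alpha) := \sum_{j=1}^N \alpha_j\|\nabla f_i(w) - \nabla f_j(w)\|^2 + \lambda\sum_{j=1}^N \alpha_j^2/n_j$ and $\alpha_g^*(w) = \arg\min_{\alpha\in\Delta_N} g(w, \alpha)$.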
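Let $w^R$ be the output of Algorithm A1. Then the following statement holds:
$$\|\alpha_g^*(w^R) - \alpha_g^*(w^*)\|^2 \le 32\kappa_g^2L^2\|w^R - w^*\|^2\sum_{j=1}^N\Big(\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + L^2\|w^R - w^*\|^2\Big),$$
where $\kappa_g := \frac{n_{\max}}{2\lambda}$.

Proof. Define the function
$$W(z, \alpha) = \sum_{j=1}^N \alpha_j z_j + \lambda\sum_{j=1}^N \alpha_j^2/n_j. \qquad (6)$$
Clearly, $W(z, \alpha)$ is linear in z and $\frac{2\lambda}{n_{\max}}$-strongly convex in α. Next we show that $\nabla_\alpha W(z, \alpha)$ is Lipschitz in z:
$$\|\nabla_\alpha W(z, \alpha) - \nabla_\alpha W(z', \alpha)\| = \|[z_1, \ldots, z_N] - [z_1', \ldots, z_N']\| = \|z - z'\|.$$
Then, by Lemma 1, $\alpha_W^*(z) := \arg\min_{\alpha\in\Delta_N} W(z, \alpha)$ is $\kappa_g$-Lipschitz in z with $\kappa_g = \frac{n_{\max}}{2\lambda}$, i.e., $\|\alpha_W^*(z) - \alpha_W^*(z')\| \le \kappa_g\|z - z'\|$. Now consider the objective (2),
$$g(w, \alpha) := \sum_{j=1}^N \alpha_j\|\nabla f_i(w) - \nabla f_j(w)\|^2 + \lambda\sum_{j=1}^N \alpha_j^2/n_j,$$
with $\alpha_g^*(w) = \arg\min_{\alpha\in\Delta_N} g(w, \alpha)$, and define
$$z^R = \big[\|\nabla f_i(w^R) - \nabla f_1(w^R)\|^2, \ldots, \|\nabla f_i(w^R) - \nabla f_N(w^R)\|^2\big],\qquad z^* = \big[\|\nabla f_i(w^*) - \nabla f_1(w^*)\|^2, \ldots, \|\nabla f_i(w^*) - \nabla f_N(w^*)\|^2\big].$$
Then
$$\|\alpha_g^*(w^R) - \alpha_g^*(w^*)\|^2 = \|\alpha_W^*(z^R) - \alpha_W^*(z^*)\|^2 \le \kappa_g^2\|z^R - z^*\|^2 \qquad (7)$$
$$= \kappa_g^2\sum_{j=1}^N\Big(\|\nabla f_i(w^R) - \nabla f_j(w^R)\|^2 - \|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2\Big)^2 \qquad (8)$$
$$\le \kappa_g^2\sum_{j=1}^N\big\|\nabla f_i(w^R) - \nabla f_j(w^R) + \nabla f_i(w^*) - \nabla f_j(w^*)\big\|^2\,\big\|\nabla f_i(w^R) - \nabla f_j(w^R) - \nabla f_i(w^*) + \nabla f_j(w^*)\big\|^2$$
$$\le \kappa_g^2\sum_{j=1}^N\big\|\nabla f_i(w^R) - \nabla f_j(w^R) + \nabla f_i(w^*) - \nabla f_j(w^*)\big\|^2\cdot 4L^2\|w^R - w^*\|^2,$$
where the first inequality uses $(\|a\|^2 - \|b\|^2)^2 = \langle a + b, a - b\rangle^2 \le \|a + b\|^2\|a - b\|^2$ and the second uses the L-smoothness of $f_i$ and $f_j$. Since $\|\nabla f_i(w^R) - \nabla f_j(w^R)\| \le \|\nabla f_i(w^*) - \nabla f_j(w^*)\| + 2L\|w^R - w^*\|$, the first factor is at most $\big(2\|\nabla f_i(w^*) - \nabla f_j(w^*)\| + 2L\|w^R - w^*\|\big)^2 \le 8\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + 8L^2\|w^R - w^*\|^2$, and we conclude that
$$\|\alpha_g^*(w^R) - \alpha_g^*(w^*)\|^2 \le 32\kappa_g^2L^2\|w^R - w^*\|^2\sum_{j=1}^N\Big(\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + L^2\|w^R - w^*\|^2\Big).$$

With the above lemma, to show the convergence of $\hat\alpha$ to $\alpha^*$, we use the decomposition
$$\|\hat\alpha - \alpha^*\|^2 \le 2\|\hat\alpha - \alpha_g^*(w^R)\|^2 + 2\|\alpha_g^*(w^R) - \alpha_g^*(w^*)\|^2 \le 2(1 - \mu\eta_\alpha)^{T_\alpha}\|\alpha^0 - \alpha_g^*(w^R)\|^2 + 64\kappa_g^2L^2\|w^R - w^*\|^2\sum_{j=1}^N\Big(\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + L^2\|w^R - w^*\|^2\Big),$$
where the first term uses the linear convergence of $T_\alpha$ steps of GD, started from $\alpha^0$, on the strongly convex function $g(w^R, \cdot)$. It now remains to show the convergence of the last iterate $w^R$ of Local SGD to the optimal solution $w^*$. By convention, we use $w^{r,t} = \frac{1}{N}\sum_{i=1}^N w_i^{r,t}$ to denote the virtual average iterates.

Lemma 4 (One-iteration analysis of Local SGD). Under the conditions of Theorem 1, the following statement holds for any $0 \le t \le K - 1$:
$$\mathbb{E}\|w^{r,t+1} - w^*\|^2 \le (1 - \mu\gamma)\mathbb{E}\|w^{r,t} - w^*\|^2 + (2\gamma - 4\gamma^2L)\,\mathbb{E}[F(w^*) - F(w^{r,t})] + (\gamma L + 2\gamma^2L^2)\frac{1}{N}\sum_{i=1}^N\mathbb{E}\|w_i^{r,t} - w^{r,t}\|^2 + \gamma^2\delta^2.$$

Proof. By the update rule in Algorithm A1, we have the identity
$$\mathbb{E}\|w^{r,t+1} - w^*\|^2 = \mathbb{E}\|w^{r,t} - w^*\|^2 - 2\gamma\,\mathbb{E}\underbrace{\Big\langle\frac{1}{N}\sum_{i=1}^N\nabla f_i(w_i^{r,t}), w^{r,t} - w^*\Big\rangle}_{T_1} + \gamma^2\underbrace{\mathbb{E}\Big\|\frac{1}{N}\sum_{i=1}^N\nabla f_i(w_i^{r,t}; \xi_i^{r,t})\Big\|^2}_{T_2}. \qquad (10)$$
For $T_1$, since each $f_i$ is L-smooth and µ-strongly convex, we have
$$-T_1 = -\frac{1}{N}\sum_{i=1}^N\Big[\langle\nabla f_i(w_i^{r,t}), w^{r,t} - w_i^{r,t}\rangle + \langle\nabla f_i(w_i^{r,t}), w_i^{r,t} - w^*\rangle\Big] \le \frac{1}{N}\sum_{i=1}^N\Big[f_i(w^*) - f_i(w^{r,t}) - \frac{\mu}{2}\|w_i^{r,t} - w^*\|^2 + \frac{L}{2}\|w_i^{r,t} - w^{r,t}\|^2\Big].$$
By Jensen's inequality, $\frac{1}{N}\sum_{i=1}^N\frac{\mu}{2}\|w_i^{r,t} - w^*\|^2 \ge \frac{\mu}{2}\|w^{r,t} - w^*\|^2$. Hence
$$-T_1 \le F(w^*) - F(w^{r,t}) - \frac{\mu}{2}\|w^{r,t} - w^*\|^2 + \frac{L}{2}\cdot\frac{1}{N}\sum_{i=1}^N\|w_i^{r,t} - w^{r,t}\|^2.$$
For $T_2$, we have
$$T_2 \le \mathbb{E}\Big\|\frac{1}{N}\sum_{i=1}^N\nabla f_i(w_i^{r,t})\Big\|^2 + \delta^2 \le 2\,\mathbb{E}\Big\|\frac{1}{N}\sum_{i=1}^N\nabla f_i(w_i^{r,t}) - \nabla F(w^{r,t})\Big\|^2 + 2\,\mathbb{E}\|\nabla F(w^{r,t})\|^2 + \delta^2 \le \frac{2L^2}{N}\sum_{i=1}^N\mathbb{E}\|w_i^{r,t} - w^{r,t}\|^2 + 4L\,\mathbb{E}[F(w^{r,t}) - F(w^*)] + \delta^2.$$
Plugging $T_1$ and $T_2$ back into (10) yields the claim.

Lemma 5 ([50, Lemma 8]). For the iterates $\{w_i^{r,t}\}$ generated by Algorithm A1, the following statement holds:
$$\frac{1}{N}\sum_{i=1}^N\mathbb{E}\|w_i^{r,t} - w^{r,t}\|^2 \le 3K\gamma^2\delta^2 + 6K^2\gamma^2\zeta^2.$$

Lemma 6 (Last-iterate convergence of Local SGD). Under the conditions of Theorem 1, the following statement holds for the iterates of Algorithm A1:
$$\mathbb{E}\|w^R - w^*\|^2 \le (1 - \mu\gamma)^{RK}\mathbb{E}\|w^0 - w^*\|^2 + \frac{1}{\mu\gamma}(\gamma L + 2\gamma^2L^2)\big(3K\gamma^2\delta^2 + 6K^2\gamma^2\zeta^2\big) + \frac{\gamma\delta^2}{\mu}.$$

Proof. We first unroll the recursion in Lemma 4 from $t = K$ to $0$ within one communication round:
$$\mathbb{E}\|w^{r,K} - w^*\|^2 \le (1 - \mu\gamma)^K\mathbb{E}\|w^{r,0} - w^*\|^2 + \sum_{t=0}^{K-1}(1 - \mu\gamma)^{K-1-t}(2\gamma - 4\gamma^2L)\,\mathbb{E}[F(w^*) - F(w^{r,t})] + \sum_{t=0}^{K-1}(1 - \mu\gamma)^{K-1-t}(\gamma L + 2\gamma^2L^2)\frac{1}{N}\sum_{i=1}^N\mathbb{E}\|w_i^{r,t} - w^{r,t}\|^2 + \sum_{t=0}^{K-1}(1 - \mu\gamma)^{K-1-t}\gamma^2\delta^2.$$
Since we choose $\gamma \le \frac{1}{2L}$, we have $2\gamma - 4\gamma^2L \ge 0$, and since $F(w^*) \le F(w^{r,t})$, the second sum is nonpositive. Plugging in Lemma 5, using $\|w^{r+1} - w^*\| \le \|w^{r,K} - w^*\|$ (nonexpansiveness of $\mathcal P_{\mathcal W}$, as $w^* \in \mathcal W$), unrolling over $r = 0, \ldots, R-1$, and bounding the geometric sums by $\sum_{k\ge0}(1 - \mu\gamma)^k \le \frac{1}{\mu\gamma}$ yields the stated bound. Plugging in $\gamma = \frac{\log(RK)}{\mu RK}$ gives the resulting convergence rate of $\mathbb{E}\|w^R - w^*\|^2$, which concludes the proof.

Equipped with the above results, we are now ready to prove Theorem 1.

Proof of Theorem 1. The proof follows from Lemma 3:
$$\|\hat\alpha - \alpha^*\|^2 \le 2\|\hat\alpha - \alpha_g^*(w^R)\|^2 + 2\|\alpha_g^*(w^R) - \alpha_g^*(w^*)\|^2 \le 2(1 - \mu\eta_\alpha)^{T_\alpha}\|\alpha^0 - \alpha_g^*(w^R)\|^2 + 64\kappa_g^2L^2\|w^R - w^*\|^2\sum_{j=1}^N\Big(\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + L^2\|w^R - w^*\|^2\Big)$$
$$\le 2(1 - \mu\eta_\alpha)^{T_\alpha}\|\alpha^0 - \alpha_g^*(w^R)\|^2 + 64\kappa_g^2L^2\|w^R - w^*\|^2\Big(\zeta_i(w^*) + NL^2\|w^R - w^*\|^2\Big),$$
where the last step uses the definition of $\zeta_i(w^*)$. Plugging in the convergence of $\mathbb{E}\|w^R - w^*\|^2$ from Lemma 6 and the step size $\eta_\alpha = \frac{1}{L_g}$ for α yields
$$\mathbb{E}\|\alpha_i^R - \alpha_i^*\|^2 \le O\Big(\exp\Big(-\frac{T_\alpha}{\kappa_g}\Big) + \kappa_g^2\zeta_i(w^*)L^2D^2\cdots\Big).$$

B.3 Proof of Convergence of Shuffling Local SGD

In this section, we prove the convergence of the proposed shuffled variant of Local SGD (Theorem 2). The proof framework follows the analysis of vanilla shuffling SGD, with two differences. First, in vanilla shuffling SGD the algorithm updates on each component function $f_j$ only once per epoch, whereas here we take K SGD steps on each component function. Second, we consider a weighted-sum objective instead of the averaged objective in [56], which means we need to rescale the objective when applying the without-replacement concentration inequality. Although our algorithm solves for N client models, for simplicity the proof only shows the convergence of one client's model. The algorithm from one client's point of view is described in Algorithm A2, where we drop the client index for notational convenience.

Algorithm A2: Shuffling Local SGD (One Client)
Input: clients 0, ..., N−1; number of local steps K; number of epochs R; mixing parameter $\hat\alpha$
for epoch r = 0, ..., R−1 do
    Server generates a permutation $\sigma_r: [N] \to [N]$.
    Client sets its initial model $v^{r,0} = v^r$.
    for j = 0, ..., N−1 do
        Server sends $v^{r,j}$ to client $\sigma_r(j)$.
        $v^{r,j+1} = \text{SGD-Update}(v^{r,j}, \eta, \sigma_r(j), K, \hat\alpha)$.
    Client does projection: $v^{r+1} = \mathcal P_{\mathcal W}(v^{r,N})$.
Output: $\hat v = v^R$.

SGD-Update(v, η, j, K, α):
    Initialize $v^0 = v$
    for t = 1, ..., K do
        $v^t = v^{t-1} - \eta\,\alpha(j)N\,\nabla f_j(v^{t-1}; \xi^{t-1})$
    Output: $v^K$

Proposition 1. Assume a sequence $\{w_t\}_{t=1}^K$ is obtained by $w_t = w_{t-1} - \eta\alpha N\nabla f(w_{t-1}; \xi_{t-1})$, $t = 1, \ldots, K$. Then
$$w_{t+1} - w_0 = -\sum_{t'=0}^{t}\Big(\prod_{t''=t}^{t'+1}(I - \alpha N\eta H_{t''})\Big)\eta\alpha N\nabla f(w_0) - \sum_{t'=0}^{t}\Big(\prod_{t''=t}^{t'+1}(I - \alpha N\eta H_{t''})\Big)\eta\alpha N\delta_{t'},\qquad 0 \le t \le K - 1,$$
where $\delta_t := \nabla f(w_t; \xi_t) - \nabla f(w_t)$, and by convention $\prod_{j=a}^{b}A_j = I$ if $a < b$.

Proof.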
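According to the update rule, we have

$$w_{t+1} - w_0 = w_t - w_0 - \eta\alpha N\nabla f(w_t;\xi_t) = w_t - w_0 - \eta\alpha N\nabla f(w_0) - \eta\alpha N\left(\nabla f(w_t) - \nabla f(w_0)\right) - \eta\alpha N\delta_t.$$

Since $f$ is $L$-smooth and $\mu$-strongly convex, by the mean value theorem there is a matrix $H_t$ satisfying $\mu I\preceq H_t\preceq LI$ such that $\nabla f(w_t) - \nabla f(w_0) = H_t(w_t - w_0)$. Hence

$$w_{t+1} - w_0 = (I - \eta\alpha NH_t)(w_t - w_0) - \eta\alpha N\nabla f(w_0) - \eta\alpha N\delta_t.$$

Unrolling the recursion from $t$ to $0$ concludes the proof.

The following lemma establishes the update rule of the model between epochs $r$ and $r+1$. For notational convenience, whenever there is no confusion, we drop the superscript $r$ in $\sigma_r$.

Lemma 7 (One epoch updating rule). Let $v^r$ and $v^{r+1} = P_W(v^{r,N})$ be two iterates generated by Shuffling Local SGD (Algorithm A2). Then the following update rule holds:

$$v^{r,N} = v^r - \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})\left(Q_j\nabla f_{\sigma(j)}(v^r) + \delta_j\right),$$

where

$$Q_j = \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(j))NH_{t'}\right)\eta\hat\alpha(\sigma(j))N,\qquad \delta_j = \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \hat\alpha(\sigma(j))N\eta H_{t'}\right)\eta\hat\alpha(\sigma(j))N\delta^t_{\sigma(j)},$$

and by convention we define $\prod_{j=a}^{b}A_j = I$ if $a < b$.

Proof. According to Proposition 1, we have

$$v^{r,j+1} = v^{r,j} - \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \hat\alpha(\sigma(j))N\eta H_{t'}\right)\eta\hat\alpha(\sigma(j))N\nabla f_{\sigma(j)}(v^{r,j}) - \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \hat\alpha(\sigma(j))N\eta H_{t'}\right)\eta\hat\alpha(\sigma(j))N\delta^t_{\sigma(j)}.$$

Plugging in our definitions of $Q_j$ and $\delta_j$ yields

$$v^{r,j+1} - v^r = v^{r,j} - v^r - Q_j\nabla f_{\sigma(j)}(v^{r,j}) - \delta_j.$$

Following the same reasoning as in the proof of Proposition 1 (writing $\nabla f_{\sigma(j)}(v^{r,j}) = \nabla f_{\sigma(j)}(v^r) + H_j(v^{r,j} - v^r)$ and unrolling over $j$) concludes the proof.

Lemma 8 (Summation by parts). Let $A_j$ and $B_j$ be complex-valued matrices of compatible dimensions. Then the following identity holds:

$$\sum_{j=1}^N A_jB_j = A_N\sum_{j=1}^N B_j - \sum_{n=1}^{N-1}(A_{n+1} - A_n)\sum_{j=1}^n B_j.$$

Proposition 2 (Spectral bound of polynomial expansion). Given collections of matrices $\{A_t\}$ and $\{B_t\}$ such that $\|A_t\|\le L$ and $\|B_t\|\le L$, the following bounds hold:

$$\left\|\prod_{t=l}^{h}(I - aA_t) - I\right\| \le \sum_{m=1}^{h-l+1}\left(\frac{e(h-l+1)}{m}\right)^m(aL)^m,$$

$$\left\|\prod_{t=l}^{h}(I - aA_t) - \prod_{t=l}^{h}(I - bB_t)\right\| \le \sum_{m=1}^{h-l+1}\left(\frac{e(h-l+1)}{m}\right)^m\left((aL)^m + (bL)^m\right).$$

Proof. We start with the first statement. Expanding the product yields

$$\prod_{t=l}^{h}(I - aA_t) = I + \sum_{m=1}^{h-l+1}(-1)^ma^m\sum_{S\subseteq\{l,\dots,h\},\,|S| = m}\prod_{t\in S}A_t,$$

where each product over $S$ keeps the original order of the factors. Hence

$$\left\|\prod_{t=l}^{h}(I - aA_t) - I\right\| \le \sum_{m=1}^{h-l+1}a^m\sum_{|S| = m}\prod_{t\in S}\|A_t\| \le \sum_{m=1}^{h-l+1}\binom{h-l+1}{m}(aL)^m.$$

According to the upper bound for binomial coefficients, $\binom{h-l+1}{m}\le\left(\frac{e(h-l+1)}{m}\right)^m$, we have

$$\left\|\prod_{t=l}^{h}(I - aA_t) - I\right\| \le \sum_{m=1}^{h-l+1}\left(\frac{e(h-l+1)}{m}\right)^m(aL)^m,$$

which proves the first statement.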
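The first bound of Proposition 2 can be sanity-checked numerically. The following Python sketch (standard library only) draws random symmetric 2×2 matrices with eigenvalues in $[0, L]$, so that $\|A_t\|\le L$, and compares $\|\prod_t(I - aA_t) - I\|$ with the intermediate bound $\sum_{m=1}^{n}\binom nm(aL)^m = (1+aL)^n - 1$ and with the final binomial-coefficient bound, where $n$ is the number of factors. The helper names and the restriction to 2×2 matrices are illustrative choices, not part of the paper.

```python
import math
import random


def matmul(A, B):
    """Product of two 2x2 matrices stored as nested lists."""
    return [[A[0][0] * B[0][0] + A[0][1] * B[1][0], A[0][0] * B[0][1] + A[0][1] * B[1][1]],
            [A[1][0] * B[0][0] + A[1][1] * B[1][0], A[1][0] * B[0][1] + A[1][1] * B[1][1]]]


def spectral_norm(M):
    """||M||_2 = sqrt of the largest eigenvalue of the symmetric 2x2 matrix M^T M."""
    p = M[0][0] ** 2 + M[1][0] ** 2
    q = M[0][0] * M[0][1] + M[1][0] * M[1][1]
    r = M[0][1] ** 2 + M[1][1] ** 2
    lam_max = (p + r) / 2 + math.sqrt(((p - r) / 2) ** 2 + q ** 2)
    return math.sqrt(lam_max)


def random_sym(L, rng):
    """Symmetric matrix R diag(l1, l2) R^T with eigenvalues in [0, L], hence ||A|| <= L."""
    theta = rng.uniform(0.0, math.pi)
    l1, l2 = rng.uniform(0.0, L), rng.uniform(0.0, L)
    c, s = math.cos(theta), math.sin(theta)
    return [[l1 * c * c + l2 * s * s, (l1 - l2) * c * s],
            [(l1 - l2) * c * s, l1 * s * s + l2 * c * c]]


def product_minus_identity_norm(mats, a):
    """||prod_t (I - a A_t) - I|| for the given ordered list of 2x2 matrices."""
    P = [[1.0, 0.0], [0.0, 1.0]]
    for A in mats:
        P = matmul(P, [[1.0 - a * A[0][0], -a * A[0][1]],
                       [-a * A[1][0], 1.0 - a * A[1][1]]])
    P[0][0] -= 1.0
    P[1][1] -= 1.0
    return spectral_norm(P)


def expansion_bound(n, a, L):
    """sum_{m=1}^n C(n, m) (aL)^m, which equals (1 + aL)^n - 1."""
    return sum(math.comb(n, m) * (a * L) ** m for m in range(1, n + 1))


def binomial_bound(n, a, L):
    """sum_{m=1}^n (e n / m)^m (aL)^m, obtained from C(n, m) <= (e n / m)^m."""
    return sum((math.e * n / m) ** m * (a * L) ** m for m in range(1, n + 1))
```

For example, with `mats = [random_sym(2.0, random.Random(0)) for _ in range(10)]`, `product_minus_identity_norm(mats, 0.05)` stays below `expansion_bound(10, 0.05, 2.0)`, which in turn stays below `binomial_bound(10, 0.05, 2.0)`.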
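For the second statement, the same product expansion yields

$$\left\|\prod_{t=l}^{h}(I - aA_t) - \prod_{t=l}^{h}(I - bB_t)\right\| = \left\|\sum_{m=1}^{h-l+1}(-1)^ma^m\sum_{|S| = m}\prod_{t\in S}A_t - \sum_{m=1}^{h-l+1}(-1)^mb^m\sum_{|S| = m}\prod_{t\in S}B_t\right\| \le \sum_{m=1}^{h-l+1}\left(\frac{e(h-l+1)}{m}\right)^m\left((aL)^m + (bL)^m\right).$$

The following concentration result is the key to bounding the variance of the shuffled updates. The original result holds for the average of gradients; we later generalize it to an arbitrary weighted sum of gradients.

Lemma 9 ([61, Theorem 2]). Suppose $n\ge2$. Let $g_1,g_2,\dots,g_n\in\mathbb{R}^d$ satisfy $\|g_j\|\le G$ for all $j$, and let $\bar g = \frac1n\sum_{j=1}^ng_j$. Let $\sigma\in S_n$ be a uniform random permutation of $n$ elements. Then, for $i\le n$, with probability at least $1-p$, we have

$$\left\|\frac1i\sum_{j=1}^{i}g_{\sigma(j)} - \bar g\right\| \le G\sqrt{\frac{8\log(1/p)}{i}}.$$

Lemma 10 (Concentration of partial sums of gradients). Given a uniformly random permutation $\sigma$ and a simplex vector $\hat\alpha$, if $\sup_{v\in W}\|\nabla f_j(v)\|\le G$ for each $j$, then with probability at least $1-p$ the following statement holds:

$$\left\|\sum_{j=0}^{n-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\| \le G\sqrt{8n\log(1/p)} + \frac nN\|\nabla\Phi(\hat\alpha,v^r)\|.$$

Proof. The proof works by rewriting the weighted sum of vectors as an average of rescaled vectors, whose mean over all clients is $\frac1N\sum_{j=0}^{N-1}\hat\alpha(j)N\nabla f_j(v^r) = \nabla\Phi(\hat\alpha,v^r)$:

$$\begin{aligned}\left\|\sum_{j=0}^{n-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\| &= \frac1N\left\|\sum_{j=0}^{n-1}\hat\alpha(\sigma(j))N\nabla f_{\sigma(j)}(v^r)\right\|\\ &\le \frac nN\left\|\frac1n\sum_{j=0}^{n-1}\hat\alpha(\sigma(j))N\nabla f_{\sigma(j)}(v^r) - \nabla\Phi(\hat\alpha,v^r)\right\| + \frac nN\|\nabla\Phi(\hat\alpha,v^r)\|\\ &\le G\sqrt{8n\log(1/p)} + \frac nN\|\nabla\Phi(\hat\alpha,v^r)\|,\end{aligned}$$

where the last step applies Lemma 9 and drops the factor $\frac1N\le1$.

Proposition 3 (Spectral norm bound of $Q$). Let $Q_j$ be defined in (11). Then the following bound on the spectral norm of $Q_j$ holds for all $j\in[N]$:

$$\|Q_j\| \le \eta\hat\alpha(\sigma(j))NK(1+\eta NL)^K.$$

Proof. The proof follows from the definition of $Q_j$, the triangle inequality, and the submultiplicativity of the spectral norm:

$$\|Q_j\| \le \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left\|I - \eta\hat\alpha(\sigma(j))NH_{t'}\right\|\eta\hat\alpha(\sigma(j))N \le \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(1 + \eta\hat\alpha(\sigma(j))NL\right)\eta\hat\alpha(\sigma(j))N \le \eta\hat\alpha(\sigma(j))NK(1+\eta NL)^K.$$

The last step uses $\hat\alpha(\sigma(j))\le1$; we also choose $\eta$ such that $\eta NL\le1$.

The following lemma bounds the cumulative update between two epochs, namely $v^{r+1} - v^r$. In particular, Lemma 11 below shows that (a) in Shuffling Local SGD, the update from $v^r$ to $v^{r+1}$ approximates $NK$ steps of gradient descent with $\hat\alpha(j)N\nabla f_{\sigma(j)}(v^r)$, that is, the bias is controlled; and (b) the update itself is bounded and can be related to the norm of the full gradient.

Lemma 11.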
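During the dynamics of Algorithm A2, the following statements hold with probability at least $1-p$:

(a)
$$\left\|\sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \eta NK\sum_{j=0}^{N-1}\hat\alpha(j)\nabla f_j(v^r)\right\|^2 \le 10\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\|\nabla\Phi(\hat\alpha,v^r)\|^2 + 128\eta^2N^3K^2\left(\frac{e}{4R-e}\right)^2G^2\log(1/p);$$

(b) for any $N'$ such that $0\le N' < N$,
$$\left\|\sum_{j=0}^{N'}Q_j\nabla f_{\sigma(j)}(v^r)\right\| \le 3e\eta NK\left(\|\nabla\Phi(\hat\alpha,v^r)\| + G\sqrt{8N\log(1/p)}\right),$$

where

$$Q_j = \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(j))NH_{t'}\right)\eta\hat\alpha(\sigma(j))N. \tag{11}$$

Proof. We start with statement (a). Let $A_j = \frac{Q_j}{\hat\alpha(\sigma(j))}$ and $B_j = \hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)$. Applying summation by parts (Lemma 8) yields

$$\sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) = \frac{Q_{N-1}}{\hat\alpha(\sigma(N-1))}\sum_{j=0}^{N-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r) - \sum_{n=0}^{N-2}\left(\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right)\sum_{j=0}^{n}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r).$$

Since $\sum_{j=0}^{N-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r) = \sum_{j=0}^{N-1}\hat\alpha(j)\nabla f_j(v^r) = \nabla\Phi(\hat\alpha,v^r)$, we obtain

$$\left\|\sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \eta NK\sum_{j=0}^{N-1}\hat\alpha(j)\nabla f_j(v^r)\right\|^2 \le 2\underbrace{\left\|\left(\frac{Q_{N-1}}{\hat\alpha(\sigma(N-1))} - \eta NKI\right)\sum_{j=0}^{N-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|^2}_{T_1} + 2\underbrace{\left\|\sum_{n=0}^{N-2}\left(\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right)\sum_{j=0}^{n}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|^2}_{T_2}.$$

According to Proposition 2, we have

$$\left\|\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(j))NH_{t'}\right) - I\right\| \le \sum_{m=1}^{K-1-t}\left(\frac{e(K-1-t)}{m}\right)^m\left(\eta\hat\alpha(\sigma(j))NL\right)^m.$$

Since we choose $\eta\le\frac{1}{4NKRL}$, we have

$$\left\|\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(j))NH_{t'}\right) - I\right\| \le \sum_{m=1}^{K-1-t}\left(\frac{e}{4Rm}\right)^m \le \frac{e}{4R-e}, \tag{12}$$

where we use the fact that $\sum_{m=1}^{K-1-t}\left(\frac{e}{4Rm}\right)^m \le \sum_{m=1}^{\infty}\left(\frac{e}{4R}\right)^m = \frac{e}{4R}\cdot\frac{1}{1-e/(4R)} = \frac{e}{4R-e}$. Since $\frac{Q_{N-1}}{\hat\alpha(\sigma(N-1))} - \eta NKI = \sum_{t=0}^{K-1}\left(\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(N-1))NH_{t'}\right) - I\right)\eta N$, we thus have

$$T_1 \le \eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\|\nabla\Phi(\hat\alpha,v^r)\|^2.$$

For $T_2$, we first bound $\left\|\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right\|$:

$$\begin{aligned}\left\|\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right\| &= \left\|\sum_{t=0}^{K-1}\left(\prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(n+1))NH_{t'}\right) - \prod_{t'=K-1}^{t+1}\left(I - \eta\hat\alpha(\sigma(n))NH_{t'}\right)\right)\eta N\right\|\\ &\le \eta N\sum_{t=0}^{K-1}\sum_{m=1}^{K-1-t}\left(\frac{e(K-1-t)}{m}\right)^m\left(\left(\eta\hat\alpha(\sigma(n))NL\right)^m + \left(\eta\hat\alpha(\sigma(n+1))NL\right)^m\right),\end{aligned}$$

where we invoke Proposition 2 in the last step. Given that $\eta\le\frac{1}{4NKRL}$ and $\hat\alpha^m\le\hat\alpha$ for $\hat\alpha\in[0,1]$, we have

$$\left\|\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right\| \le \eta N\sum_{t=0}^{K-1}\sum_{m=1}^{K-1-t}\left(\frac{e}{4Rm}\right)^m\left(\hat\alpha(\sigma(n)) + \hat\alpha(\sigma(n+1))\right) \le \eta NK\frac{e}{4R-e}\left(\hat\alpha(\sigma(n)) + \hat\alpha(\sigma(n+1))\right),$$

where we use the reasoning in (12). Hence, for $T_2$,

$$\sqrt{T_2} \le \sum_{n=0}^{N-2}\eta NK\frac{e}{4R-e}\left(\hat\alpha(\sigma(n)) + \hat\alpha(\sigma(n+1))\right)\left(G\sqrt{8(n+1)\log(1/p)} + \frac{n+1}{N}\|\nabla\Phi(\hat\alpha,v^r)\|\right) \le \eta NK\frac{2e}{4R-e}\left(G\sqrt{8N\log(1/p)} + \|\nabla\Phi(\hat\alpha,v^r)\|\right),$$

where we invoke Lemma 10 and, at the last step, use $\sum_{n}\left(\hat\alpha(\sigma(n)) + \hat\alpha(\sigma(n+1))\right)\le2$ and $n+1\le N$.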
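Lemma 9, and the permutation invariance used above to replace the full weighted sum by $\nabla\Phi(\hat\alpha,v^r)$, can be illustrated with a small simulation. The sketch below (Python, standard library only) draws vectors with $\|g_j\|\le G$ in $\mathbb{R}^2$, samples uniform permutations, and records how often the deviation of the prefix mean from $\bar g$ exceeds $G\sqrt{8\log(1/p)/i}$; it also computes weighted prefix sums $\sum_{j=0}^{n-1}\hat\alpha(\sigma(j))g_{\sigma(j)}$. The function names, the dimension, and the sampling scheme are illustrative assumptions, and an empirical frequency is only a sanity check, not a proof of the high-probability bound.

```python
import math
import random


def norm(x):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(t * t for t in x))


def random_bounded_vectors(n, G, rng):
    """n vectors in R^2 with Euclidean norm at most G."""
    vecs = []
    for _ in range(n):
        r, phi = rng.uniform(0.0, G), rng.uniform(0.0, 2 * math.pi)
        vecs.append([r * math.cos(phi), r * math.sin(phi)])
    return vecs


def weighted_prefix_sum(alpha, grads, perm, n):
    """sum_{j=0}^{n-1} alpha[perm[j]] * grads[perm[j]]; weights are indexed by client."""
    d = len(grads[0])
    s = [0.0] * d
    for j in range(n):
        k = perm[j]
        for c in range(d):
            s[c] += alpha[k] * grads[k][c]
    return s


def lemma9_violation_rate(grads, G, p, i, trials, rng):
    """Fraction of random permutations with ||mean of first i - full mean|| > G*sqrt(8 log(1/p)/i)."""
    n, d = len(grads), len(grads[0])
    g_bar = [sum(g[c] for g in grads) / n for c in range(d)]
    bound = G * math.sqrt(8 * math.log(1 / p) / i)
    perm = list(range(n))
    violations = 0
    for _ in range(trials):
        rng.shuffle(perm)
        prefix_mean = [sum(grads[perm[j]][c] for j in range(i)) / i for c in range(d)]
        if norm([prefix_mean[c] - g_bar[c] for c in range(d)]) > bound:
            violations += 1
    return violations / trials
```

Comparing `weighted_prefix_sum(alpha, grads, perm, len(alpha))` for two different permutations shows that the full weighted sum does not depend on the order, while prefix sums with `n < len(alpha)` do.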
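So we can conclude that

$$T_2 \le 2\eta^2N^2K^2\left(\frac{2e}{4R-e}\right)^2\left(8G^2N\log(1/p) + \|\nabla\Phi(\hat\alpha,v^r)\|^2\right).$$

Putting the bounds on $T_1$ and $T_2$ together concludes the proof of (a).

Now we prove (b). Once again, by the summation by parts identity,

$$\sum_{j=0}^{N'}Q_j\nabla f_{\sigma(j)}(v^r) = \frac{Q_{N'}}{\hat\alpha(\sigma(N'))}\sum_{j=0}^{N'}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r) - \sum_{n=0}^{N'-1}\left(\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right)\sum_{j=0}^{n}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r).$$

Taking the norm of both sides yields

$$\left\|\sum_{j=0}^{N'}Q_j\nabla f_{\sigma(j)}(v^r)\right\| \le \underbrace{\left\|\frac{Q_{N'}}{\hat\alpha(\sigma(N'))}\sum_{j=0}^{N'}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|}_{B} + \underbrace{\left\|\sum_{n=0}^{N'-1}\left(\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right)\sum_{j=0}^{n}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|}_{C}.$$

Plugging in our bounds on $\|Q_{N'}\|$ (Proposition 3) and on $\left\|\frac{Q_{n+1}}{\hat\alpha(\sigma(n+1))} - \frac{Q_n}{\hat\alpha(\sigma(n))}\right\|$ yields

$$B \le \eta NK(1+\eta NL)^K\left(G\sqrt{8(N'+1)\log(1/p)} + \frac{N'+1}{N}\|\nabla\Phi(\hat\alpha,v^r)\|\right),$$

where at the last step we invoke Lemma 10. For $C$, we use similar reasoning:

$$C \le \sum_{n=0}^{N'-1}\eta NK\frac{e}{4R-e}\left(\hat\alpha(\sigma(n+1)) + \hat\alpha(\sigma(n))\right)\left(G\sqrt{8(n+1)\log(1/p)} + \frac{n+1}{N}\|\nabla\Phi(\hat\alpha,v^r)\|\right) \le 2\eta NK\frac{e}{4R-e}\left(G\sqrt{8N\log(1/p)} + \|\nabla\Phi(\hat\alpha,v^r)\|\right).$$

Putting these pieces together, and using $(1+\eta NL)^K\le e$ and $4R - e\ge1$, yields

$$\left\|\sum_{j=0}^{N'}Q_j\nabla f_{\sigma(j)}(v^r)\right\| \le 3e\eta NK\left(\|\nabla\Phi(\hat\alpha,v^r)\| + G\sqrt{8N\log(1/p)}\right).$$

Lemma 12. During the dynamics of Algorithm A2, the following statement holds with probability at least $1-p$:

$$\left\|\sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|^2 \le 18e^6\eta^4N^4K^4L^4\left(\|\nabla\Phi(\hat\alpha,v^r)\|^2 + 8G^2N\log(1/p)\right).$$

Proof. We first apply the triangle inequality and the submultiplicativity of the spectral norm:

$$\begin{aligned}&\left\|\sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|\\ &\le \sum_{n=0}^{N-2}\left\|\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right\|\left\|\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|\\ &\le \sum_{n=0}^{N-2}\left(1 + \eta NK + \eta^2N^2KL\right)^{2N}\eta NLK(1+\eta NL)^KL\,\hat\alpha(\sigma(n+1))\left\|\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|\\ &\le e^2\eta NKL^2\sum_{n=0}^{N-2}\hat\alpha(\sigma(n+1))\left\|\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|.\end{aligned}$$

We proceed by applying the bound from Lemma 11 (b):

$$\left\|\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\| \le 3e\eta NK\left(\|\nabla\Phi(\hat\alpha,v^r)\| + G\sqrt{8N\log(1/p)}\right).$$

Therefore, using $\sum_n\hat\alpha(\sigma(n+1))\le1$,

$$\left\|\sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\| \le 3e^3\eta^2N^2K^2L^2\left(\|\nabla\Phi(\hat\alpha,v^r)\| + G\sqrt{8N\log(1/p)}\right).$$

Squaring both sides and using $(x+y)^2\le2x^2 + 2y^2$ concludes the proof.

Lemma 13 (Noise bound). During the dynamics of Algorithm A2, the following bound on the gradient noise holds with probability at least $1-p$:

$$\left\|\sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})\delta_j\right\| \le e^2\eta NK\max_{j,t}\|\delta^t_{\sigma(j)}\|,$$

where $\delta_j = \sum_{t=0}^{K-1}\prod_{t'=K-1}^{t+1}\left(I - \hat\alpha(\sigma(j))N\eta H_{t'}\right)\eta\hat\alpha(\sigma(j))N\delta^t_{\sigma(j)}$.

Proof.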
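By the triangle inequality and the submultiplicativity of the spectral norm, and using $\eta\le\frac{1}{4NKRL}$,

$$\begin{aligned}\left\|\sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})\delta_j\right\| &\le \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}\|I - Q_{j'}H_{j'}\|\,\|\delta_j\| \le \sum_{j=0}^{N-1}\underbrace{\prod_{j'=N-1}^{j+1}\left(1 + \eta\hat\alpha(\sigma(j'))NK(1+\eta NL)^KL\right)}_{\le e}\|\delta_j\|\\ &\le \sum_{j=0}^{N-1}e\cdot\eta\hat\alpha(\sigma(j))NK\underbrace{(1+\eta NL)^K}_{\le e}\max_t\|\delta^t_{\sigma(j)}\| \le e^2\eta NK\max_{j,t}\|\delta^t_{\sigma(j)}\|.\end{aligned}$$

B.4 Proof of Theorem 2

Proof. For notational convenience, let us define

$$g^r := \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})Q_j\nabla f_{\sigma(j)}(v^r),\qquad \delta^r := \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})\delta_j.$$

Then we recall the update rule of $v$ (Lemma 7): $v^{r+1} = P_W(v^r - g^r - \delta^r)$. Hence, since $P_W$ is nonexpansive and $v^*(\hat\alpha)\in W$,

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\hat\alpha)\|^2 &= \mathbb{E}\|P_W(v^r - g^r - \delta^r) - v^*(\hat\alpha)\|^2 \le \mathbb{E}\|v^r - g^r - \delta^r - v^*(\hat\alpha)\|^2\\ &\le \mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 - 2\mathbb{E}\langle g^r, v^r - v^*(\hat\alpha)\rangle + \mathbb{E}\|g^r\|^2 + \mathbb{E}\|\delta^r\|^2\\ &= \mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 - 2\mathbb{E}\langle\eta NK\nabla\Phi(\hat\alpha,v^r), v^r - v^*(\hat\alpha)\rangle - 2\mathbb{E}\langle g^r - \eta NK\nabla\Phi(\hat\alpha,v^r), v^r - v^*(\hat\alpha)\rangle + \mathbb{E}\|g^r\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}$$

Now, applying the strong convexity of $\Phi(\hat\alpha,\cdot)$ and Young's inequality yields

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\hat\alpha)\|^2 \le{}& (1 - \mu\eta NK)\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 - \eta NK\,\mathbb{E}[\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))]\\ &+ \frac{1}{\frac12\mu\eta NK}\mathbb{E}\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 + \frac12\mu\eta NK\,\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 + \mathbb{E}\|g^r\|^2 + \mathbb{E}\|\delta^r\|^2\\ \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 - \eta NK\,\mathbb{E}[\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))]\\ &+ \frac{1}{\frac12\mu\eta NK}\mathbb{E}\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 + 2\mathbb{E}\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 + 2\mathbb{E}\|\eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}$$

Since $\Phi(\hat\alpha,\cdot)$ is $L$-smooth, we have $\mathbb{E}\|\nabla\Phi(\hat\alpha,v^r)\|^2\le2L\,\mathbb{E}[\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))]$. Therefore,

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\hat\alpha)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 - (\eta NK - 4\eta^2N^2K^2L)\mathbb{E}[\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))]\\ &+ \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\mathbb{E}\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}\tag{14}$$

Now we examine the term $\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2$. First, according to summation by parts (Lemma 8) with $A_j := \prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})$ (so that $A_{N-1} = I$) and $B_j := Q_j\nabla f_{\sigma(j)}(v^r)$, we have

$$g^r = \sum_{j=0}^{N-1}A_jB_j = \sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r).$$

Hence

$$\begin{aligned}\|g^r - \eta NK\nabla\Phi(\hat\alpha,v^r)\|^2 &\overset{(1)}{\le} 2\left\|\sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \eta NK\sum_{j=0}^{N-1}\hat\alpha(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|^2\\ &\quad + 2\left\|\sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|^2\\ &\overset{(2)}{\le} \left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)\|\nabla\Phi(\hat\alpha,v^r)\|^2\\ &\quad + 256\eta^2N^3K^2\left(\frac{e}{4R-e}\right)^2G^2\log(1/p) + 244e^6\eta^4N^4K^4L^4G^2N\log(1/p)\\ &\overset{(3)}{\le} \left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)2L\left(\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))\right)\\ &\quad + \left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p),\end{aligned}$$

where in (1) we apply Jensen's inequality, in (2) we plug in Lemma 11 (a) and Lemma 12, and in (3) we use the $L$-smoothness of $\Phi$. Plugging the above bound back into (14), and bounding $\mathbb{E}\|\delta^r\|^2$ via Lemma 13, yields

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\hat\alpha)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 + \eta^2N^2K^2e^4\delta^2\\ &- \underbrace{\left(\eta NK - 4\eta^2N^2K^2L - \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)2L\right)}_{T_1}\mathbb{E}[\Phi(\hat\alpha,v^r) - \Phi(\hat\alpha,v^*(\hat\alpha))]\\ &+ \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p).\end{aligned}$$

Since we choose $\eta = \frac{4\log(NKR)}{\mu NKR}$ and a large enough number of epochs, $R\ge\max\{\cdots,\,16\log(NKR),\,64\kappa\log(\cdots)\}$, we know that $T_1\ge0$. We thus have

$$\mathbb{E}\|v^{r+1} - v^*(\hat\alpha)\|^2 \le \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\hat\alpha)\|^2 + \eta^2N^2K^2e^4\delta^2 + \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p).$$

Unrolling the recursion from $r = R$ to $0$ and using $\sum_{r\ge0}\left(1 - \frac12\mu\eta NK\right)^r\le\frac{2}{\mu\eta NK}$:

$$\mathbb{E}\|v^R - v^*(\hat\alpha)\|^2 \le \Big(1 - \frac12\mu\eta NK\Big)^R\mathbb{E}\|v^0 - v^*(\hat\alpha)\|^2 + \frac{2e^4\eta NK\delta^2}{\mu} + \frac1\mu\Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(488e^6\eta^3N^3K^3L^4 + 512\eta NK\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p).$$

Plugging in our choice of $\eta$ concludes the proof:

$$\mathbb{E}\|v^R - v^*(\hat\alpha)\|^2 \le \tilde O\left(\frac{\mathbb{E}\|v^0 - v^*(\hat\alpha)\|^2}{(NKR)^2} + \frac{\delta^2}{\mu^2R} + \left(\frac{L^4}{\mu^4} + \frac{1}{\mu^2}\right)\frac{G^2N\log(1/p)}{R^2}\right).$$

Finally, according to Lemma 2, we can complete the proof:

$$\begin{aligned}\Phi(\alpha_i^*,\hat v_i) - \Phi(\alpha_i^*,v_i^*) &\le 2L\|v_i^R - v^*(\hat\alpha_i)\|^2 + \left(2\kappa_\Phi^2L + 4NG^2\right)\|\hat\alpha_i - \alpha_i^*\|^2\\ &\le \tilde O\left(\frac{L\,\mathbb{E}\|v_i^0 - v^*(\hat\alpha_i)\|^2}{(NKR)^2} + \frac{L\delta^2}{\mu^2R} + \left(\frac{L^5}{\mu^4} + \frac{L}{\mu^2}\right)\frac{G^2N\log(1/p)}{R^2}\right) + \left(2\kappa_\Phi^2L + 4NG^2\right)O\left(\exp\left(-\frac{T_\alpha}{\kappa_g}\right) + \kappa_g^2\zeta_i(w^*)L^2D^2\cdots\right),\end{aligned}$$

where we plug in the convergence result from Theorem 1 at the last step.

C Proof of Convergence of Single Loop Algorithm

In this section, we prove the convergence of the single-loop PERM algorithm (Algorithm 2), in which learning the mixing parameters and learning the personalized models are coupled. Compared to Algorithm A2, the mixing parameters are also updated during the optimization of the model. As a result, we need to decouple the two updates, which makes the analysis more involved. We begin with some technical lemmas that support the proof of the main result.

C.1 Technical Lemmas

Proposition 4 (Basic properties of SGD on smooth strongly convex functions). Let $w^t$ be the $t$-th iterate of minibatch SGD on a smooth and strongly convex function $F$, with minibatch size $M$ and learning rate $\gamma$. Also assume that the variance is bounded by $\delta^2$.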
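Then the following statements hold after $T$ iterations of SGD:

$$\mathbb{E}\|\nabla F(w^T)\|^2 \le 2L(1-\mu\gamma)^T(F(w^0) - F(w^*)) + 2\gamma\kappa\delta^2, \tag{15}$$

$$\mathbb{E}\|w^{T+1} - w^T\|^2 \le 2\gamma^2L(1-\mu\gamma)^T(F(w^0) - F(w^*)) + 2\gamma^3\kappa\delta^2, \tag{16}$$

$$\mathbb{E}\|w^T - w^*\|^2 \le \frac2\mu(1-\mu\gamma)^T(F(w^0) - F(w^*)) + 2\gamma\delta^2. \tag{17}$$

Lemma 14 (Bounded iterate differences of $\alpha$). Let $\{\alpha_i^r\}$ be the iterates generated by Algorithm 2. Then, under the conditions of Theorem 3, the following statement holds:

$$\|\alpha_i^r - \alpha_i^{r-1}\|^2 \le 6\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2L^2\zeta_i(w^*)\left(\gamma^2L(1-\mu\gamma)^r(F(w^0) - F(w^*)) + \gamma^3\kappa\delta^2\right)\right).$$

Proof. Define $z^r = \left[\|\nabla f_i(w^r) - \nabla f_1(w^r)\|^2,\dots,\|\nabla f_i(w^r) - \nabla f_N(w^r)\|^2\right]$. According to the update rule of $\alpha$ in Algorithm 2 and Lemma 3, we have

$$\begin{aligned}\|\alpha_i^r - \alpha_i^{r-1}\|^2 &\le 3\|\alpha_i^r - \alpha^*_{g_i}(w^r)\|^2 + 3\|\alpha^*_{g_i}(w^{r-1}) - \alpha^*_{g_i}(w^r)\|^2 + 3\|\alpha^*_{g_i}(w^{r-1}) - \alpha_i^{r-1}\|^2\\ &\le 6(1 - \mu_g\eta_\alpha)^{T_\alpha} + 3\|\alpha^*_{g_i}(w^{r-1}) - \alpha^*_{g_i}(w^r)\|^2\\ &\le 6(1 - \mu_g\eta_\alpha)^{T_\alpha} + 3\kappa_g^2\|z^{r-1} - z^r\|^2\\ &\le 6(1 - \mu_g\eta_\alpha)^{T_\alpha} + 3\kappa_g^2\sum_{j=1}^N\Big(\|\nabla f_i(w^r) - \nabla f_j(w^r)\| + \|\nabla f_i(w^{r-1}) - \nabla f_j(w^{r-1})\|\Big)^2 4L^2\|w^r - w^{r-1}\|^2,\end{aligned}$$

where the last inequality follows from (8). Since $\|\nabla f_i(w^r) - \nabla f_j(w^r)\|\le\|\nabla f_i(w^*) - \nabla f_j(w^*)\| + 2L\|w^r - w^*\|$, we can conclude that

$$\begin{aligned}\|\alpha_i^r - \alpha_i^{r-1}\|^2 &\le 6(1 - \mu_g\eta_\alpha)^{T_\alpha} + 12\kappa_g^2L^2\sum_{j=1}^N\Big(8\|\nabla f_i(w^*) - \nabla f_j(w^*)\|^2 + 8L^2\|w^r - w^*\|^2 + 8L^2\|w^{r-1} - w^*\|^2\Big)\|w^r - w^{r-1}\|^2\\ &\le 6\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2L^2\zeta_i(w^*)\right)\|w^r - w^{r-1}\|^2\\ &\le 6\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2L^2\zeta_i(w^*)\left(\gamma^2L(1-\mu\gamma)^r(F(w^0) - F(w^*)) + \gamma^3\kappa\delta^2\right)\right),\end{aligned}$$

where at the last step we plug in Proposition 4 (16).

Lemma 15 (Convergence of $\alpha$). Let $\{\hat\alpha_i\}_{i=1}^N$ be the mixing parameters generated by Algorithm 2. Then, under the conditions of Theorem 3, the following statement holds:

$$\|\hat\alpha_i - \alpha_i^*\|^2 \le 2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2\zeta_i(w^*)L^2\left(\frac2\mu(1-\mu\gamma)^T + 2\gamma\delta^2\right)\right).$$

Proof. We use the following decomposition:

$$\begin{aligned}\|\hat\alpha_i - \alpha_i^*\|^2 = \|\alpha_i^R - \alpha^*_{g_i}(w^*)\|^2 &\le 2\|\alpha_i^R - \alpha^*_{g_i}(w^R)\|^2 + 2\|\alpha^*_{g_i}(w^R) - \alpha^*_{g_i}(w^*)\|^2\\ &\le 2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2\left(\zeta_i(w^*) + NL^2\|w^R - w^*\|^2\right)4L^2\|w^R - w^*\|^2\right)\\ &\le 2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2\zeta_i(w^*)L^2\left(\frac2\mu(1-\mu\gamma)^T + 2\gamma\delta^2\right)\right),\end{aligned}$$

where in the second inequality we apply Lemma 3, and in the third inequality we use Proposition 4 (17).

C.2 Proof of Theorem 3

Proof. According to Lemma 2, we have

$$\Phi(\alpha_i^*,\hat v_i) - \Phi(\alpha_i^*,v_i^*) \le 2L\|v_i^R - v^*(\hat\alpha_i)\|^2 + \left(2\kappa_\Phi^2L + 4NG^2\right)\|\hat\alpha_i - \alpha_i^*\|^2.$$

We first examine the convergence of $\|v_i^R - v^*(\hat\alpha_i)\|^2$.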
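Applying Young's inequality yields

$$\begin{aligned}\|v^{r+1} - v^*(\alpha^{r+1})\|^2 &\le \Big(1 + \frac{1}{4a-2}\Big)\|v^{r+1} - v^*(\alpha^r)\|^2 + (1 + 4a - 2)\|v^*(\alpha^{r+1}) - v^*(\alpha^r)\|^2\\ &\le \Big(1 + \frac{1}{4a-2}\Big)\|v^{r+1} - v^*(\alpha^r)\|^2 + (1 + 4a - 2)\kappa_\Phi^2\|\alpha^{r+1} - \alpha^r\|^2,\end{aligned}\tag{18}$$

where $a = \frac{1}{\mu\eta NK}$, and the last step holds because $v^*(\alpha)$ is $\kappa_\Phi$-Lipschitz, as proven in Lemma 2. Similar to the proof of Theorem 2, we first define

$$g^r := \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})Q_j\nabla f_{\sigma(j)}(v^r),\qquad \delta^r := \sum_{j=0}^{N-1}\prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})\delta_j,$$

where $Q_j$ and $\delta_j$ are now computed with the current mixing parameter $\alpha^r$. Then we recall the update rule of $v$: $v^{r+1} = P_W(v^r - g^r - \delta^r)$. Hence

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\alpha^r)\|^2 &= \mathbb{E}\|P_W(v^r - g^r - \delta^r) - v^*(\alpha^r)\|^2 \le \mathbb{E}\|v^r - g^r - \delta^r - v^*(\alpha^r)\|^2\\ &\le \mathbb{E}\|v^r - v^*(\alpha^r)\|^2 - 2\mathbb{E}\langle g^r, v^r - v^*(\alpha^r)\rangle + \mathbb{E}\|g^r\|^2 + \mathbb{E}\|\delta^r\|^2\\ &= \mathbb{E}\|v^r - v^*(\alpha^r)\|^2 - 2\mathbb{E}\langle\eta NK\nabla\Phi(\alpha^r,v^r), v^r - v^*(\alpha^r)\rangle - 2\mathbb{E}\langle g^r - \eta NK\nabla\Phi(\alpha^r,v^r), v^r - v^*(\alpha^r)\rangle + \mathbb{E}\|g^r\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}$$

Now, applying the strong convexity of $\Phi(\alpha^r,\cdot)$ and Young's inequality exactly as in the proof of Theorem 2 yields

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\alpha^r)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\alpha^r)\|^2 - \eta NK\,\mathbb{E}[\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))]\\ &+ \frac{1}{\frac12\mu\eta NK}\mathbb{E}\|g^r - \eta NK\nabla\Phi(\alpha^r,v^r)\|^2 + 2\mathbb{E}\|g^r - \eta NK\nabla\Phi(\alpha^r,v^r)\|^2 + 2\mathbb{E}\|\eta NK\nabla\Phi(\alpha^r,v^r)\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}$$

Since $\Phi(\alpha^r,\cdot)$ is $L$-smooth, we have $\mathbb{E}\|\nabla\Phi(\alpha^r,v^r)\|^2\le2L\,\mathbb{E}[\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))]$. Therefore,

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\alpha^r)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\alpha^r)\|^2 - (\eta NK - 4\eta^2N^2K^2L)\mathbb{E}[\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))]\\ &+ \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\mathbb{E}\|g^r - \eta NK\nabla\Phi(\alpha^r,v^r)\|^2 + \mathbb{E}\|\delta^r\|^2.\end{aligned}\tag{19}$$

Now we examine the term $\|g^r - \eta NK\nabla\Phi(\alpha^r,v^r)\|^2$ on the right-hand side of the above inequality. First, according to summation by parts (Lemma 8) with $A_j := \prod_{j'=N-1}^{j+1}(I - Q_{j'}H_{j'})$ and $B_j := Q_j\nabla f_{\sigma(j)}(v^r)$, we have

$$g^r = \sum_{j=0}^{N-1}A_jB_j = \sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r).$$

Hence

$$\begin{aligned}\|g^r - \eta NK\nabla\Phi(\alpha^r,v^r)\|^2 &\overset{(1)}{\le} 2\left\|\sum_{j=0}^{N-1}Q_j\nabla f_{\sigma(j)}(v^r) - \eta NK\sum_{j=0}^{N-1}\alpha^r(\sigma(j))\nabla f_{\sigma(j)}(v^r)\right\|^2\\ &\quad + 2\left\|\sum_{n=0}^{N-2}\left(\prod_{j'=N-1}^{n+2}(I - Q_{j'}H_{j'}) - \prod_{j'=N-1}^{n+1}(I - Q_{j'}H_{j'})\right)\sum_{j=0}^{n}Q_j\nabla f_{\sigma(j)}(v^r)\right\|^2\\ &\overset{(2)}{\le} \left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)\|\nabla\Phi(\alpha^r,v^r)\|^2\\ &\quad + 256\eta^2N^3K^2\left(\frac{e}{4R-e}\right)^2G^2\log(1/p) + 244e^6\eta^4N^4K^4L^4G^2N\log(1/p)\\ &\overset{(3)}{\le} \left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)2L\left(\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))\right)\\ &\quad + \left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p),\end{aligned}$$

where in (1) we apply Jensen's inequality, in (2) we plug in Lemma 11 (a) and Lemma 12 (with $\alpha^r$ in place of $\hat\alpha$), and in (3) we use the $L$-smoothness of $\Phi$. Plugging the above bound back into (19) yields

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\alpha^r)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\alpha^r)\|^2 + \eta^2N^2K^2e^4\delta^2\\ &- \left(\eta NK - 4\eta^2N^2K^2L - \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(20\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2 + 36e^6\eta^4N^4K^4L^4\right)2L\right)\mathbb{E}[\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))]\\ &+ \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p).\end{aligned}$$

Since we choose $\eta = \frac{4\log(NKR)}{\mu NKR}$, we have

$$\begin{aligned}\mathbb{E}\|v^{r+1} - v^*(\alpha^r)\|^2 \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\alpha^r)\|^2 + \eta^2N^2K^2e^4\delta^2 \underbrace{- \frac12\eta NK\,\mathbb{E}[\Phi(\alpha^r,v^r) - \Phi(\alpha^r,v^*(\alpha^r))]}_{\le0}\\ &+ \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p)\\ \le{}& \Big(1 - \frac12\mu\eta NK\Big)\mathbb{E}\|v^r - v^*(\alpha^r)\|^2 + \eta^2N^2K^2e^4\delta^2 + \Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p).\end{aligned}$$

Putting the above inequality back into (18), and using $\left(1 + \frac{1}{4a-2}\right)\left(1 - \frac{1}{2a}\right) = 1 - \frac{1}{4a}$ together with $1 + \frac{1}{4a-2}\le2$ for $a\ge\frac34$, yields

$$\begin{aligned}\|v^{r+1} - v^*(\alpha^{r+1})\|^2 \le{}& \Big(1 - \frac14\mu\eta NK\Big)\|v^r - v^*(\alpha^r)\|^2 + 2\eta^2N^2K^2e^4\delta^2 + (1 + 4a - 2)\kappa_\Phi^2\|\alpha^{r+1} - \alpha^r\|^2\\ &+ 2\Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p)\\ \le{}& \Big(1 - \frac14\mu\eta NK\Big)\|v^r - v^*(\alpha^r)\|^2 + 2\eta^2N^2K^2e^4\delta^2\\ &+ 2\Big(\frac{1}{\frac12\mu\eta NK} + 2\Big)\left(244e^6\eta^4N^4K^4L^4 + 256\eta^2N^2K^2\left(\frac{e}{4R-e}\right)^2\right)G^2N\log(1/p)\\ &+ (1 + 4a - 2)\kappa_\Phi^2\left(6\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + \kappa_g^2L^2\zeta_i(w^*)\left(\gamma^2L(1-\mu\gamma)^rD_G + \gamma^3\kappa\delta^2\right)\right),\end{aligned}$$

where at the second inequality we plug in Lemma 14. Unrolling the recursion from $r = R$ to $0$, using $\sum_{r\ge0}\left(1 - \frac14\mu\eta NK\right)^r\le\frac{4}{\mu\eta NK}$, and plugging in $\eta = \frac{4\log(NKR^3)}{\mu NKR}$ yields

$$\|v^R - v^*(\alpha^R)\|^2 \le O\left(\frac{\|v^0 - v^*(\alpha^0)\|^2}{R^2} + \left(\frac{L^4}{\mu^4} + \frac{1}{\mu^2}\right)\frac{G^2N\log(1/p)}{R^2} + \frac{\delta^2\log(NKR^3)}{\mu^2R} + \frac{\kappa_\Phi^2R^2}{\log^2(NKR^3)}\left(\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + \kappa_g^2L^2\zeta_i(w^*)\left(\gamma^2LD_G + \gamma^3\kappa\delta^2\right)\right)\right).$$

Plugging in $\gamma = \log(NKR^3)/(\cdots)$ yields

$$\|v^R - v^*(\alpha^R)\|^2 \le \tilde O\left(\frac{\|v^0 - v^*(\alpha^0)\|^2}{R^2} + \left(\frac{L^4}{\mu^4} + \frac{1}{\mu^2}\right)\frac{G^2N\log(1/p)}{R^2} + \frac{\delta^2}{\mu^2R} + \frac{\kappa_\Phi^2\kappa^2\kappa_g^2L^2\zeta_i(w^*)D_G}{R} + \kappa_\Phi^2R^2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + \kappa_\Phi^2\kappa_g^2\kappa^2\zeta_i(w^*)\delta^2\cdots\right).$$

Since $\hat v_i = v^R$ and $\hat\alpha_i = \alpha^R$, this gives the convergence of $\|\hat v_i - v^*(\hat\alpha_i)\|^2$.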
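The constants used when plugging the one-epoch bound back into (18) can be checked with exact rational arithmetic. The sketch below uses Python's `fractions` module to verify that $\left(1 + \frac{1}{4a-2}\right)\left(1 - \frac{1}{2a}\right) = 1 - \frac{1}{4a}$ for $a = \frac{1}{\mu\eta NK}\ge\frac34$, and that unrolling $x_{r+1} = \rho x_r + c$ with $\rho = 1 - \frac{1}{4a}$ gives $x_R = \rho^Rx_0 + c\frac{1-\rho^R}{1-\rho}\le\rho^Rx_0 + 4ac$, which is the geometric-sum factor appearing in the unrolled bound. The function names are illustrative.

```python
from fractions import Fraction


def combined_factor(a):
    """(1 + 1/(4a - 2)) * (1 - 1/(2a)): Young's-inequality factor times the one-epoch contraction."""
    a = Fraction(a)
    return (1 + 1 / (4 * a - 2)) * (1 - 1 / (2 * a))


def unroll(x0, rho, c, R):
    """Iterates x_{r+1} = rho * x_r + c for R steps (the recursion taken with equality)."""
    x = Fraction(x0)
    for _ in range(R):
        x = rho * x + c
    return x


def unrolled_closed_form(x0, rho, c, R):
    """Closed form rho^R * x0 + c * (1 - rho^R) / (1 - rho) of the recursion above."""
    return rho ** R * x0 + c * (1 - rho ** R) / (1 - rho)
```

Because every quantity is a `Fraction`, the equalities are checked exactly rather than up to floating-point error.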
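Plugging this convergence rate, together with the convergence of $\|\hat\alpha_i - \alpha_i^*\|^2$ from Lemma 15,

$$\|\hat\alpha_i - \alpha_i^*\|^2 \le 2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + O\left(\kappa_g^2\zeta_i(w^*)L^2\left(\frac2\mu(1-\mu\gamma)^R + 2\gamma\delta^2\right)\right) \le O\left(\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + \kappa_g^2\zeta_i(w^*)L^2\cdots\right),$$

into Lemma 2 leads to

$$\begin{aligned}\Phi(\alpha_i^*,\hat v_i) - \Phi(\alpha_i^*,v_i^*) \le{}& 2L\|\hat v_i - v^*(\hat\alpha_i)\|^2 + \left(2\kappa_\Phi^2L + 4NG^2\right)\|\hat\alpha_i - \alpha_i^*\|^2\\ \le{}& \tilde O\left(\frac{L\|v^0 - v^*(\alpha^0)\|^2}{R^2} + \left(\frac{L^5}{\mu^4} + \frac{L}{\mu^2}\right)\frac{G^2N\log(1/p)}{R^2} + \frac{L\delta^2}{\mu^2R} + \frac{\kappa_\Phi^2\kappa^2\kappa_g^2L^3\zeta_i(w^*)D_G}{R} + \kappa_\Phi^2LR^2\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + L\kappa_\Phi^2\kappa_g^2\kappa^2\zeta_i(w^*)\delta^2\cdots\right)\\ &+ \left(2\kappa_\Phi^2L + 4NG^2\right)O\left(\left(1 - \frac{1}{\kappa_g}\right)^{T_\alpha} + \kappa_g^2\zeta_i(w^*)\kappa L\cdots\right),\end{aligned}$$

thus completing the proof.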
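To make the procedure analyzed in Appendix B.3 concrete, the following Python sketch runs Algorithm A2 from a single client's point of view on scalar quadratic components $f_j(v) = \frac{h_j}{2}(v - c_j)^2$, for which $v^*(\hat\alpha) = \sum_j\hat\alpha(j)h_jc_j\big/\sum_j\hat\alpha(j)h_j$. Each epoch draws a permutation, runs SGD-Update ($K$ steps with step size $\eta\hat\alpha(j)N$) on each component in the permuted order, and then projects onto an interval $W$. The quadratic components, the interval $W$, the optional Gaussian gradient noise, and all parameter values are assumptions made for illustration, not the paper's experimental setup. With a constant step size, the epoch-end iterate settles in a small neighborhood of $v^*(\hat\alpha)$ rather than exactly at it.

```python
import random


def sgd_update(v, eta, j, K, alpha, N, grad, rng, noise_std):
    """K local steps on component j: v_t = v_{t-1} - eta * alpha[j] * N * (noisy) grad f_j(v_{t-1})."""
    for _ in range(K):
        noise = rng.gauss(0.0, noise_std) if noise_std > 0 else 0.0
        v = v - eta * alpha[j] * N * (grad(j, v) + noise)
    return v


def shuffling_local_sgd(v0, eta, K, R, alpha, grad, W, seed=0, noise_std=0.0):
    """One client's model under Shuffling Local SGD (Algorithm A2), projecting onto W = (lo, hi)."""
    N = len(alpha)
    rng = random.Random(seed)
    lo, hi = W
    v = v0
    for _ in range(R):
        order = list(range(N))
        rng.shuffle(order)       # server draws the permutation sigma_r
        for j in order:          # model visits clients sigma_r(0), ..., sigma_r(N-1)
            v = sgd_update(v, eta, j, K, alpha, N, grad, rng, noise_std)
        v = min(max(v, lo), hi)  # v^{r+1} = P_W(v^{r,N})
    return v


def quadratic_problem(h, c):
    """Gradient oracle for the hypothetical components f_j(v) = h_j / 2 * (v - c_j)^2."""
    return lambda j, v: h[j] * (v - c[j])


def weighted_minimizer(alpha, h, c):
    """Minimizer of Phi(alpha, v) = sum_j alpha_j f_j(v) for the quadratics above."""
    num = sum(a * hj * cj for a, hj, cj in zip(alpha, h, c))
    den = sum(a * hj for a, hj in zip(alpha, h))
    return num / den
```

For instance, with `h = [1.0, 2.0, 3.0]`, `c = [0.0, 1.0, 2.0]` and `alpha = [0.2, 0.3, 0.5]`, calling `shuffling_local_sgd(0.0, 0.002, 5, 500, alpha, quadratic_problem(h, c), (-10.0, 10.0))` returns a value close to `weighted_minimizer(alpha, h, c)`; shrinking `eta` tightens the neighborhood at the cost of more epochs.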