# FedDAR: Federated Domain-Aware Representation Learning

Published as a conference paper at ICLR 2023

Aoxiao Zhong$^{1,3}$, Hao He$^{2}$, Zhaolin Ren$^{1}$, Na Li$^{1}$, and Quanzheng Li$^{3}$

$^{1}$Harvard University, $^{2}$Massachusetts Institute of Technology, $^{3}$Massachusetts General Hospital/Harvard Medical School

aoxiaozhong@g.harvard.edu, haohe@mit.edu, zhaolinren@g.harvard.edu, nali@seas.harvard.edu, li.quanzheng@mgh.harvard.edu

## ABSTRACT

Cross-silo federated learning (FL) has become a promising tool in machine learning applications for healthcare. It allows hospitals and institutions to train models with sufficient data while the data is kept private. To make FL models robust to heterogeneous data among FL clients, most efforts focus on personalizing models for clients. However, the latent relationships between clients' data are ignored. In this work, we focus on a special non-iid FL problem, called Domain-mixed FL, where each client's data distribution is assumed to be a mixture of several predefined domains. Recognizing the diversity of domains and the similarity within domains, we propose a novel method, FedDAR, which learns a domain-shared representation and domain-wise personalized prediction heads in a decoupled manner. For simplified linear regression settings, we theoretically prove that FedDAR enjoys a linear convergence rate. For general settings, we perform extensive empirical studies on both synthetic and real-world medical datasets, which demonstrate its superiority over prior FL methods. Our code is available at https://github.com/zlz0414/FedDAR.

## 1 INTRODUCTION

Federated learning (FL) (McMahan et al., 2017a) is a machine learning approach that allows many clients (e.g., mobile devices or organizations) to collaboratively train a model without sharing their data. It has great potential to resolve a common dilemma in real-world machine learning applications, especially in healthcare.
A robust and generalizable model in medical applications usually requires a large amount of diverse training data. However, collecting a large-scale centralized dataset can be expensive or even impractical due to regulatory, ethical, and legal constraints, as well as data privacy and protection requirements (Rieke et al., 2020).

While promising, applying FL to real-world problems poses many technical challenges. One prominent challenge is data heterogeneity. Many FL algorithms assume that data across clients are independently and identically distributed (iid), but this assumption rarely holds in the real world. It has been shown that non-iid data distributions can cause standard FL strategies such as FedAvg to fail (Jiang et al., 2019; Sattler et al., 2020; Kairouz et al., 2019; Li et al., 2020). Since a single model that performs well on all clients may not exist, FL algorithms need to personalize the model for different data distributions. Prior theoretical work (Marfoq et al., 2021) shows that it is impossible to improve performance on all clients without making assumptions about the clients' data distributions. Past work on personalized FL (Marfoq et al., 2021; Sattler et al., 2020; Ghosh et al., 2020; Mansour et al., 2020; Deng et al., 2020) makes its own assumptions and tailors its methods to them.

In this paper, we propose a new and more realistic assumption: each client's data distribution is a mixture of several predefined domains. We call this problem setting Domain-mixed FL. It is inspired by the fact that the diversity of medical data can be attributed to some known notion of domains, e.g., different demographic/ethnic groups of patients (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004), different manufacturers or protocols/workflows of image scanners (Mårtensson et al., 2020; Ciompi et al., 2017), and so on.

∗Equal contribution.
It is necessary to address the ubiquitous issue of domain shifts among ethnic groups (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004) or vendors (Yan et al., 2019; Garrucho et al., 2022; Guan & Liu, 2021) in healthcare data. Despite these domain shifts, the same domain at different clients is usually considered to have the same distribution. The data heterogeneity between FL clients therefore comes from the distinct mixtures of diverse domains at the clients. These factors motivate us to personalize the model for each domain instead of each client. Although our method is inspired by healthcare applications, where the domain-shift issue is well known and domain labels are basic and accessible, we believe it can be applied more generally to other areas such as finance or recommendation systems, where users with different demographics are involved (Ding et al., 2021; Asuncion & Newman, 2007). However, this would require a deep understanding of the data and background knowledge to verify the data distribution assumption as well as the accessibility of domain labels.

FedEM (Marfoq et al., 2021) and FedMinMax (Papadaki et al., 2021) make a data distribution assumption similar to ours. However, FedEM assumes the domains are unknown and tries to learn a linear combination of several shared component models with personalized mixture weights through an EM-like algorithm. FedMinMax does not acknowledge the shift between domains and still aims to learn one shared model across domains by adapting minimax optimization to the FL setting.

Our Contributions. We formulate the proposed problem setting, Domain-mixed FL. Through our analysis, we find that prior FL methods, both generic methods like FedAvg (McMahan et al., 2017a) and personalized methods like FedRep (Collins et al., 2021), are sub-optimal in our setting. To address this issue, we propose a new algorithm, Federated Domain-Aware Representation Learning (FedDAR).
FedDAR learns a shared model for all clients, with embedded domain-wise personalized modules. The model contains two parts: an encoder shared across all domains and a multi-headed predictor whose heads are associated with domains. For an input from a specific domain, the model extracts a representation via the shared encoder and then uses the corresponding head to make the prediction. FedDAR decouples the learning of the encoder and the heads by alternating between their updates. This allows clients to run many local updates on the heads without overfitting on domains with limited data samples, which also leads to faster convergence and better-performing models. FedDAR also adopts different aggregation strategies for the two parts. We use a weighted average to aggregate the local updates of the encoder; with additional sample re-weighting, the overall training objective weights each domain equally to encourage fairness among domains. For the heads, we propose a novel second-order aggregation algorithm to improve the optimality of the aggregated heads. We theoretically show that our method enjoys desirable properties such as linear convergence and small sample complexity in a linear case. Through extensive experiments on both synthetic and real-world datasets, we demonstrate that FedDAR significantly improves performance over state-of-the-art personalized FL methods. To the best of our knowledge, our paper is among the first efforts in domain-wise personalized federated learning, and it achieves superior performance.

## 2 RELATED WORK

Besides the literature discussed above, other work on personalization and fairness in federated learning is also closely related to ours.

Personalized Federated Learning.
Personalized federated learning has been studied from a variety of perspectives: i) local fine-tuning (Wang et al., 2019; Yu et al., 2020); ii) meta-learning (Chen et al., 2018; Fallah et al., 2020; Jiang et al., 2019; Khodak et al., 2019); iii) local/global model interpolation (Deng et al., 2020; Corinzia et al., 2019; Mansour et al., 2020); iv) clustered FL, which partitions clients into clusters and learns an optimal model for each cluster (Sattler et al., 2020; Mansour et al., 2020; Ghosh et al., 2020; Zhu et al., 2021); v) multi-task learning (MTL) (Vanhaesebrouck et al., 2017; Smith et al., 2017; Zantedeschi et al., 2020; Hanzely & Richtárik, 2020; Hanzely et al., 2020; T Dinh et al., 2020; Huang et al., 2021; Li et al., 2021a); vi) local representations or heads for clients (Arivazhagan et al., 2019; Liang et al., 2020; Collins et al., 2021; Luo et al., 2022); and vii) personalized models through hypernetworks or super models (Shamsian et al., 2021; Chen & Chao, 2021; Xu et al., 2022). The personalization module in our approach is similar to vi) and to (Zhu et al., 2021) in that it uses a multi-branch network. However, we personalize the model for domains instead of clients.

Fairness in Federated Learning. There are two commonly used definitions of fairness in existing FL work. One is client fairness, usually formulated as client parity (CP), which requires clients to have similar performance; a few works have studied it (Li et al., 2021a; 2019; Mohri et al., 2019; Yue et al., 2021; Zhang et al., 2020). The other is group fairness. In the centralized setting, the fundamental tradeoff between group fairness and accuracy has been studied (Menon & Williamson, 2018; Wick et al., 2019; Zhao & Gordon, 2019), and various fair training algorithms have been proposed (Roh et al., 2020; Jiang & Nachum, 2020; Zafar et al., 2017; Zemel et al., 2013; Hardt et al., 2016).
Since the notions of group fairness are the same in the FL setting, most existing FL work adapts methods from the centralized setting (Zeng et al., 2021; Du et al., 2021; Gálvez et al., 2021; Chu et al., 2021; Cui et al., 2021). Our method is not designed specifically for particular group fairness notions such as demographic parity. Instead, we aim to achieve the best possible performance for each domain through personalization, acknowledging the differences between data domains. Moreover, our concept of data domains is not limited to demographic groups; it can be applied to any other mixture of domain data, as long as our assumptions hold.

## 3 PROBLEM: DOMAIN-MIXED FEDERATED LEARNING

Notations. Federated learning involves multiple clients. We denote the number of clients by $n$ and use $i \in [n] \triangleq \{1, 2, \dots, n\}$ to index each client. Client $i$ has a local data distribution $\mathcal{D}_i$, which induces a local learning objective, i.e., the expected risk $\mathcal{R}_i(f) = \mathbb{E}_{(x_i, y_i)\sim\mathcal{D}_i}[\ell(f(x_i), y_i)]$, where $f: \mathcal{X} \to \mathcal{Y}$ is the model mapping an input $x \in \mathcal{X}$ to a predicted label $f(x) \in \mathcal{Y}$, and $\ell: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}$ is a generic loss function. In practice, client $i \in [n]$ has a finite number $L_i$ of data samples, i.e., $S_i = \{(x_i^j, y_i^j)\}_{j=1}^{L_i}$, and $L = \sum_{i=1}^n L_i$ denotes the total number of data samples.

Problem Formulation of Domain-mixed Federated Learning. We introduce a new formulation of the FL problem by assuming that each client's local data distribution is a weighted mixture of $M$ domain-specific distributions. Specifically, we use $\{\tilde{\mathcal{D}}_m\}_{m=1}^M$ to denote the data distributions of $M$ predefined domains. For client $i$, its local data distribution is $\mathcal{D}_i = \sum_m \pi_{i,m}\tilde{\mathcal{D}}_m$, where the mixing coefficient $\pi_{i,m}$ is the probability that a data sample of client $i$ comes from domain $m$. Taking medical applications as an example, different hospitals are clients and different ethnic groups are domains: each ethnic group has different health data, while each hospital's data is a mixture of data from the ethnic groups.
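The generative view of $\mathcal{D}_i = \sum_m \pi_{i,m}\tilde{\mathcal{D}}_m$ is: first pick a domain $z$ with probability $\pi_{i,z}$, then draw $(x, y)$ from $\tilde{\mathcal{D}}_z$. The sketch below illustrates this under stated assumptions; the per-domain samplers (`make_linear_domain`, a noise-free $y = \text{slope}\cdot x$ domain) are hypothetical toy stand-ins, not part of FedDAR.

```python
import random

def sample_client_data(pi, domain_samplers, num_samples, rng):
    """Draw (x, y, z) triplets from D_i = sum_m pi[m] * D_m.

    pi: mixing coefficients pi_{i,m} of client i (non-negative, summing to 1).
    domain_samplers: hypothetical callables; domain_samplers[m](rng) returns one
        (x, y) pair drawn from the domain distribution D_m.
    """
    domains = range(len(pi))
    data = []
    for _ in range(num_samples):
        z = rng.choices(domains, weights=pi, k=1)[0]  # pick domain z with prob. pi[z]
        x, y = domain_samplers[z](rng)                # then sample from D_z
        data.append((x, y, z))                        # the domain label z is kept
    return data

def make_linear_domain(slope):
    # Hypothetical toy domain: x ~ N(0, 1), y = slope * x (noise-free).
    def sampler(rng):
        x = rng.gauss(0.0, 1.0)
        return x, slope * x
    return sampler

rng = random.Random(0)
samplers = [make_linear_domain(2.0), make_linear_domain(-1.0)]
client_data = sample_client_data([0.8, 0.2], samplers, 1000, rng)
frac_domain0 = sum(1 for _, _, z in client_data if z == 0) / len(client_data)
```

Keeping $z$ in each triplet reflects the assumption, stated next, that the domain of every sample is known.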
Further, the domains of the data samples are assumed to be known. We use a triplet of variables $(x, y, z)$ to represent the input features, label, and domain. The goal of our problem is to learn a model $f(x, z)$ that performs well in every domain, as captured by the following learning objective:

$$\min_f \mathcal{R}(f) := \frac{1}{M}\sum_{m=1}^{M}\mathcal{R}_m\big(f(\cdot, m)\big), \tag{1}$$

where $\mathcal{R}_m(f(\cdot, m)) = \mathbb{E}_{(x, y)\sim\tilde{\mathcal{D}}_m}[\ell(f(x, m), y)]$. Our problem focuses on the setting where each domain has a different conditional label distribution, i.e., $P_m(y \mid x)$ differs across domains $m$.

### 3.1 COMPARISON WITH PRIOR DOMAIN-UNAWARE FL PROBLEM FORMULATIONS

Our FL problem introduces the concept of domains and focuses on the model's performance in each domain. Many prior FL formulations do not recognize the existence of domains. For example, the original federated learning algorithms such as FedAvg (McMahan et al., 2017a) and FedProx (Li et al., 2020) learn a globally shared model by minimizing the averaged risk, i.e., $\min_f \frac{1}{n}\sum_i \mathcal{R}_i(f)$. Some variants consider fairness across clients; to do so, they optimize the worst client's performance instead of the averaged performance, i.e., $\min_f \max_i \mathcal{R}_i(f)$. Further, personalized FL algorithms such as FedRep (Collins et al., 2021) customize the model's prediction for each client, with the objective $\min_{f_i: i\in[n]} \frac{1}{n}\sum_{i=1}^n \mathcal{R}_i(f_i)$. All the FL algorithms mentioned above lead to sub-optimal solutions to our problem, since they do not make domain-specific predictions. We illustrate this point with the following toy example of linear regression. We assume the data in the $m$-th domain is generated via the following procedure: $x \in \mathbb{R}^d$ is sampled i.i.d. from a distribution $p(x)$ with mean zero and covariance $I_d$, and the label $y \in \mathbb{R}$ obeys $y = x^\top B^* w_m^*$, where $B^* \in \mathbb{R}^{d\times k}$ is the ground-truth linear embedding shared by all domains and $w_m^* \in \mathbb{R}^k$ is the linear head specific to domain $m$.
Under this setting, $\tilde{\mathcal{D}}_m$ stands for data $(x, y)$ where $x \sim p(x)$ and $y = x^\top B^* w_m^*$. For each client, the local data distribution $\mathcal{D}_i$ is a mixture of data from different domains with mixing coefficients $\pi_{i,m}$, i.e., $\mathcal{D}_i = \sum_m \pi_{i,m}\tilde{\mathcal{D}}_m$.

Algorithm 1: FEDDAR

Input: data $S_{1:n}$; numbers of local updates $\tau_h$ for the heads and $\tau_\phi$ for the representation; number of communication rounds $T$; learning rate $\eta$.

- Initialize the representation and heads $\phi^0, h_1^0, \dots, h_M^0$.
- For $t = 1, 2, \dots, T$:
  - Server sends $\phi^{t-1}, h_1^{t-1}, \dots, h_M^{t-1}$ to the $n$ clients.
  - For each client $i = 1, 2, \dots, n$ in parallel:
    - Initialize $h_{i,m}^{t,0} \leftarrow h_m^{t-1}$ for all $m \in [M]$.
    - For $s = 1$ to $\tau_h$: $h_{i,m}^{t,s} \leftarrow \mathrm{GRD}\big(\hat{\mathcal{R}}_{i,m}(h_{i,m}^{t,s-1}\circ\phi^{t-1}),\ h_{i,m}^{t,s-1},\ \eta\big)$ for all $m \in [M]$.
    - Send the updated heads $h_{i,m}^{t,\tau_h}$ and Hessians $H_{\mathcal{R}_{i,m}}(h_{i,m}^{t,\tau_h})$ to the server.
  - Server aggregates the heads of each domain $m \in [M]$: $h_m^t \leftarrow \mathrm{HEADAGG}\big(\{h_{i,m}^{t,\tau_h}, H_{\mathcal{R}_{i,m}}(h_{i,m}^{t,\tau_h})\}_{i=1}^n\big)$ via Equation 8.
  - Server sends $h_1^t, \dots, h_M^t$ to the $n$ clients.
  - For each client $i = 1, 2, \dots, n$ in parallel:
    - For $s = 1$ to $\tau_\phi$: $\phi_i^{t,s} \leftarrow \mathrm{GRD}\big(\hat{\mathcal{R}}_i(\phi_i^{t,s-1}, \{h_m^t\}_{m=1}^M),\ \phi_i^{t,s-1},\ \eta\big)$.
    - Send the updated representation $\phi_i^t = \phi_i^{t,\tau_\phi}$ to the server.
  - Server computes the new representation by averaging: $\phi^t \leftarrow \sum_{i=1}^n \frac{L_i}{L}\phi_i^t$.
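To make the toy example concrete before the formal objectives, here is a minimal numeric sketch in the scalar case $d = k = 1$. Assumptions: only the product $B w$ matters there, so $B$ is fixed to $1$; $\mathbb{E}[x] = 0$ and $\mathbb{E}[x^2] = 1$, so $\mathbb{E}_{\tilde{\mathcal{D}}_m}(y - xw)^2 = (w_m^* - w)^2$; the mixing matrix `pi` and heads `w_true` are hypothetical. The sketch evaluates the client-averaged squared error of Equations 2 and 3 below at each method's optimum.

```python
def client_avg_risk(pi, w_true, head):
    """(1/2n) * sum_{i,m} pi[i][m] * (w_true[m] - head(i, m))^2 for d = k = 1, B = 1.

    head(i, m) is the head used for a domain-m input at client i.
    """
    n = len(pi)
    return sum(p * (w_true[m] - head(i, m)) ** 2
               for i, row in enumerate(pi) for m, p in enumerate(row)) / (2 * n)

pi = [[0.5, 0.5], [0.9, 0.1]]  # hypothetical mixing coefficients of 2 clients
w_true = [1.0, -1.0]           # hypothetical ground-truth domain heads
n = len(pi)

# FedAvg (Eq. 2): a single head; the pi-weighted mean of w_true minimizes the risk.
w_avg = sum(p * w_true[m] for row in pi for m, p in enumerate(row)) / n
# FedRep (Eq. 3): one head per client; client i's optimum is sum_m pi[i][m] * w_true[m].
w_client = [sum(p * w_true[m] for m, p in enumerate(row)) for row in pi]

risk_fedavg = client_avg_risk(pi, w_true, lambda i, m: w_avg)        # 0.42
risk_fedrep = client_avg_risk(pi, w_true, lambda i, m: w_client[i])  # 0.34
risk_feddar = client_avg_risk(pi, w_true, lambda i, m: w_true[m])    # 0.0
```

Only the domain-wise heads reach zero error; the client-wise heads help only for the client whose mixture is close to a single domain.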
FedAvg learns a single model $(B, w)$ across all clients via the following objective:

$$\min_{B, w}\ \frac{1}{2n}\sum_{i\in[n]}\mathbb{E}_{(x,y)\sim\mathcal{D}_i}\big(y - x^\top B w\big)^2 = \sum_{i\in[n],\, m\in[M]}\frac{\pi_{i,m}}{2n}\,\mathbb{E}_{(x,y)\sim\tilde{\mathcal{D}}_m}\big(y - x^\top B w\big)^2 \tag{2}$$

FedRep learns a shared representation $B$ and a separate head $w_i$ for each client $i$, rather than for each domain $m$:

$$\min_{B, w_1, \dots, w_n}\ \frac{1}{2n}\sum_{i\in[n]}\mathbb{E}_{(x,y)\sim\mathcal{D}_i}\big(y - x^\top B w_i\big)^2 = \sum_{i\in[n],\, m\in[M]}\frac{\pi_{i,m}}{2n}\,\mathbb{E}_{(x,y)\sim\tilde{\mathcal{D}}_m}\big(y - x^\top B w_i\big)^2 \tag{3}$$

FedDAR. In contrast, in the linear case our proposed method FedDAR, introduced next, learns a shared representation $B$ and a separate head $w_m$ for each domain $m$:

$$\min_{B, w_1, \dots, w_M}\ \frac{1}{2M}\sum_{m\in[M]}\sum_{i\in[n]}\frac{\pi_{i,m}}{\sum_{i'\in[n]}\pi_{i',m}}\,\mathbb{E}_{(x,y)\sim\tilde{\mathcal{D}}_m}\big(y - x^\top B w_m\big)^2 \tag{4}$$

From the above formulations, we can see that FedAvg and FedRep cannot achieve zero error in our domain-mixed FL problem: whenever a client mixes domains with different ground-truth heads $w_m^*$, a single head $w$ (FedAvg) or a client-specific head $w_i$ (FedRep) must fit several distinct linear targets $x^\top B^* w_m^*$ at once, whereas FedDAR attains zero error at $B = B^*$, $w_m = w_m^*$.

## 4 PROPOSED METHOD: FEDDAR

To solve the Domain-mixed FL problem, we propose a new method called Federated Domain-Aware Representation Learning (FedDAR). In the following, we first introduce the model and learning objective, and then the details of the federated optimization algorithm.

### 4.1 ALGORITHM OVERVIEW

Our model consists of a shared encoder $\phi(\cdot;\theta)$ and $M$ domain-specific heads $h_m(\cdot; w_m)$, parameterized by neural networks with weights $\theta$ and $w_m$, $m \in [M]$. Following our problem formulation in Equation 1, our algorithm aims to solve the following optimization:

$$\min_{\phi, h_1, \dots, h_M}\ \mathcal{R}(\phi, h_1, \dots, h_M) := \frac{1}{M}\sum_{m=1}^M \mathcal{R}_m(h_m\circ\phi) \tag{5}$$

We decouple the training of the encoder and the heads: the learning alternates between the encoder and the heads. The learning is done federatedly in two conventional steps: (1) local updates and (2) aggregation at the server. Algorithm 1 shows the pseudocode.

Empirical Objectives with Re-weighting. Empirically, the objectives are estimated from the finite data samples at each client. We use $S_{i,m}$ to denote the set of samples from domain $m$ at client $i$, with $L_{i,m} := |S_{i,m}|$ denoting its size.
Further, $L_i := \sum_{m=1}^M L_{i,m}$ is the number of samples at client $i$, while $L_m := \sum_{i=1}^n L_{i,m}$ is the total number of samples belonging to domain $m$ across all clients. We denote the empirical risk at client $i$ specific to domain $m$ as $\hat{\mathcal{R}}_{i,m}(h_m\circ\phi) := \frac{1}{L_{i,m}}\sum_{(x,y)\in S_{i,m}}\ell(h_m\circ\phi(x), y)$. The empirical risk at client $i$ is designed as $\hat{\mathcal{R}}_i(\phi, h_1, \dots, h_M) = \sum_m \frac{L_{i,m}}{L_i} u_m \hat{\mathcal{R}}_{i,m}(h_m\circ\phi)$, where $u_m = \frac{L}{L_m M}$ re-weights the risk of each domain. Combining this with the commonly used weighted-average FL objective $\hat{\mathcal{R}}(\phi, h_1, \dots, h_M) = \sum_{i=1}^n \frac{L_i}{L}\hat{\mathcal{R}}_i(\phi, h_1, \dots, h_M)$, the overall empirical risk becomes

$$\hat{\mathcal{R}}(\phi, h_1, \dots, h_M) := \sum_{i=1}^n \frac{L_i}{L}\hat{\mathcal{R}}_i(\phi, h_1, \dots, h_M) = \frac{1}{M}\sum_{m=1}^M \hat{\mathcal{R}}_m(h_m\circ\phi), \tag{6}$$

where $\hat{\mathcal{R}}_m(h_m\circ\phi) := \sum_{i=1}^n \frac{L_{i,m}}{L_m}\hat{\mathcal{R}}_{i,m}(h_m\circ\phi)$. This is consistent with Equation 5.

### 4.2 LOCAL UPDATES AT CLIENTS

In each communication round, clients use gradient-based methods to optimize the representation $\phi(\cdot;\theta)$ and the local heads $h_m(\cdot; w_m)$, $m \in [M]$, alternately. We use $t$ to denote the current round. For a module $f$, $f^{t-1}$ denotes its optimized version after $t-1$ rounds. Each round has multiple gradient descent iterations, and we use $f^{t,s}$ to denote the module in round $t$ after $s$ iterations. Since the updates are made locally and clients maintain their own copies of both modules, we use the subscript $i$ to index the local copy at client $i$, e.g., $f_i^{t,s}$. We use GRD to denote a generic gradient-based optimization step, which takes three inputs (objective function, variables, and learning rate) and maps them to a new module with updated variables. For example, vanilla gradient descent has the form $\mathrm{GRD}(\mathcal{L}(f_w), f_w, \eta) = f_{w - \eta\nabla_w\mathcal{L}(f_w)}$. For the heads, client $i$ performs $\tau_h$ local gradient-based updates to obtain the optimal head given the current shared encoder $\phi^{t-1}$: for $s \in [\tau_h]$, client $i$ updates $h_{i,m}^{t,s} \leftarrow \mathrm{GRD}(\hat{\mathcal{R}}_{i,m}(h_{i,m}^{t,s-1}\circ\phi^{t-1}), h_{i,m}^{t,s-1}, \eta)$. For the shared encoder, the clients execute $\tau_\phi$ local updates: for $s \in [\tau_\phi]$, client $i$ updates its local copy of the encoder via $\phi_i^{t,s} \leftarrow \mathrm{GRD}(\hat{\mathcal{R}}_i(\phi_i^{t,s-1}, \{h_m^t\}_{m=1}^M), \phi_i^{t,s-1}, \eta)$.
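The local updates, together with the re-weighting $u_m = \frac{L}{L_m M}$ and weighted-average (WA) aggregation, can be sketched end to end on a toy problem. This is a minimal pure-Python sketch under loud assumptions, not the authors' implementation: a hypothetical scalar model with encoder $\phi(x) = \theta x$ and heads $h_m(r) = w_m r$, squared loss, vanilla gradient descent as GRD, WA (rather than second-order) head aggregation, and noise-free toy data.

```python
# Toy sketch of FedDAR rounds (Algorithm 1) with weighted-average (WA) aggregation.
# Hypothetical scalar model: phi(x) = theta * x, h_m(r) = w[m] * r; squared loss.

def head_grad(theta, w_m, data):
    # d/dw_m of the client's domain risk: mean over data of (w_m*theta*x - y)^2
    return sum(2 * (w_m * theta * x - y) * theta * x for x, y in data) / len(data)

def encoder_grad(theta, w, client, u):
    # d/dtheta of the re-weighted client risk sum_m (L_im / L_i) * u[m] * R_im
    L_i = sum(len(data) for data in client.values())
    grad = 0.0
    for m, data in client.items():
        g_m = sum(2 * (w[m] * theta * x - y) * w[m] * x for x, y in data) / len(data)
        grad += len(data) / L_i * u[m] * g_m
    return grad

def feddar_wa(clients, M, theta, w, rounds, tau_h, tau_phi, lr):
    """clients: list of dicts {domain m: [(x, y), ...]}; every domain must occur somewhere."""
    L_m = [sum(len(c.get(m, [])) for c in clients) for m in range(M)]
    L_i = [sum(len(data) for data in c.values()) for c in clients]
    L = sum(L_i)
    u = [L / (L_m[m] * M) for m in range(M)]  # u_m = L / (L_m * M)
    w = list(w)
    for _ in range(rounds):
        # Heads: tau_h local GD steps with the encoder frozen, then WA per domain.
        new_w = []
        for m in range(M):
            agg = 0.0
            for c in clients:
                if m not in c:  # this client holds no samples of domain m
                    continue
                w_im = w[m]
                for _ in range(tau_h):
                    w_im -= lr * head_grad(theta, w_im, c[m])
                agg += len(c[m]) / L_m[m] * w_im
            new_w.append(agg)
        w = new_w
        # Encoder: tau_phi local GD steps with the new heads frozen, then WA.
        theta_next = 0.0
        for c, n_i in zip(clients, L_i):
            theta_i = theta
            for _ in range(tau_phi):
                theta_i -= lr * encoder_grad(theta_i, w, c, u)
            theta_next += n_i / L * theta_i
        theta = theta_next
    return theta, w

def global_risk(clients, M, theta, w):
    # Equation 6: (1/M) * sum_m of the pooled empirical risk of domain m
    total = 0.0
    for m in range(M):
        pooled = [s for c in clients for s in c.get(m, [])]
        total += sum((w[m] * theta * x - y) ** 2 for x, y in pooled) / len(pooled)
    return total / M

# Usage with hypothetical noise-free data: domain m follows y = slopes[m] * x.
xs = [-1.0, -0.5, 0.5, 1.0]
slopes = [2.0, -1.0]
clients = [
    {0: [(x, slopes[0] * x) for x in xs], 1: [(x, slopes[1] * x) for x in xs]},
    {0: [(x, slopes[0] * x) for x in xs]},  # client 1 only holds domain 0
]
before = global_risk(clients, 2, 1.0, [0.1, 0.1])
theta, w = feddar_wa(clients, 2, 1.0, [0.1, 0.1], rounds=100, tau_h=5, tau_phi=1, lr=0.1)
after = global_risk(clients, 2, theta, w)
```

Because client 1 holds only domain 0, the weights $u_m$ keep the domain-1 risk from being drowned out in the averaged encoder objective, so each domain counts equally as in Equation 6.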
The re-weighting described in the previous section is implemented by weighting each sample by $u_m$ when computing the loss.

### 4.3 AGGREGATION AT SERVER

We introduce two strategies: (1) weighted average (WA) and (2) second-order aggregation (SA). With weighted average, the aggregated model parameters are the average of the local models' parameters weighted by the number of data samples. Specifically, for the shared encoder we have $\theta^t = \sum_{i=1}^n \frac{L_i}{L}\theta_i^t$, where $\theta_i^t$ denotes client $i$'s encoder weights after its local updates in round $t$. Similarly, for each head we have $w_m^t = \sum_{i=1}^n \frac{L_{i,m}}{L_m} w_{m,i}^t$.

Second-order aggregation is a more elaborate strategy. Ideally, we want the head aggregation to produce the globally optimal model given a set of locally optimal models:

$$w^* \in \arg\min_w J(w) := \sum_{i=1}^n \alpha_i J_i(w), \quad \text{given } w_i^* = \arg\min_w J_i(w)\ \ \forall i \in [n], \tag{7}$$

where $J_i$ is the $i$-th client's virtual objective, $\alpha_i := L_i/L$ is the importance of client $i$, and $L_i$ is its number of data samples. We call $J_i$ the virtual objective to distinguish it from the real learning objective $\mathcal{R}_i$: it is defined as an objective with respect to which the local updates yield the optimal solution. It is needed because the local updates between two aggregations are not guaranteed to optimize the head to optimality with respect to the real objective. For example, if each local update is a single gradient descent step with learning rate $\eta$, i.e., $w_i^{t+1} = w^t - \eta\nabla_w\mathcal{R}_i(w^t)$, then the virtual objective becomes $J_i(w) = \mathcal{R}_i(w^t) + (w - w^t)^\top\nabla_w\mathcal{R}_i(w^t) + \frac{1}{2\eta}\|w - w^t\|_2^2$, which satisfies $w_i^{t+1} \in \arg\min_w J_i(w)$. Such virtual objectives lead the solution of Problem 7 to the weighted average $w^* = \sum_{i=1}^n \alpha_i w_i^{t+1}$ (setting $\sum_i \alpha_i\nabla J_i(w) = 0$ and using $\sum_i \alpha_i = 1$ gives $w = w^t - \eta\sum_i\alpha_i\nabla_w\mathcal{R}_i(w^t)$), which is the simple averaging strategy when all $\alpha_i$ are equal. In practice, however, the local updates are usually more elaborate, which makes the virtual objective closer to the true objective. We consider the case where the virtual objective is the second-order Taylor expansion of the true objective, i.e., $J_i(w) = \mathcal{R}_i(w^t) + (w - w^t)^\top\nabla_w\mathcal{R}_i(w^t) + \frac{1}{2}(w - w^t)^\top H_{\mathcal{R}_i}(w^t)(w - w^t)$, where $H_{\mathcal{R}_i}$ is the Hessian matrix.
Then each round of local updates is equivalent to a Newton-like step, $w_i^{t+1} = w^t - H_{\mathcal{R}_i}(w^t)^{-1}\nabla_w\mathcal{R}_i(w^t)$, while $w^{t+1} = w^t - H_{\mathcal{R}}(w^t)^{-1}\nabla_w\mathcal{R}(w^t)$ is the desired global optimum. Using the facts that $\nabla_w\mathcal{R}(w) = \sum_{i\in[n]}\alpha_i\nabla_w\mathcal{R}_i(w)$ and $H_{\mathcal{R}}(w) = \sum_{i\in[n]}\alpha_i H_{\mathcal{R}_i}(w)$, we can obtain $w^{t+1}$ from the local $w_i^{t+1}$ via the following equation, which we call second-order aggregation:

$$w^{t+1} = H_{\mathcal{R}}(w^t)^{-1}\sum_{i\in[n]}\alpha_i H_{\mathcal{R}_i}(w^t)\, w_i^{t+1} \tag{8}$$

Specifically, to implement second-order aggregation, in each round the local clients first optimize the model locally for several epochs; we then compute the Hessian matrix of each local model and send it to the server for aggregation. Note that sending the Hessian incurs a communication cost quadratic in the number of weights. In practice, the predictive head is usually small, e.g., a linear layer with hundreds of neurons, so aggregating the Hessian of the head's parameters is acceptable. Below we provide two instances of our second-order aggregation with a linear head.

1. Linear regression, where $\mathcal{R}_i(w) = \frac{1}{L_i}\sum_{j=1}^{L_i}(w^\top x_i^j - y_i^j)^2$ is itself quadratic, so its second-order Taylor expansion is the objective itself, i.e., $J_i(w) = \mathcal{R}_i(w)$. In this case, $H_{\mathcal{R}_i}(w) = \frac{2}{L_i}X_i^\top X_i$, where $X_i = [x_i^1, \dots, x_i^{L_i}]^\top$ is the data matrix of client $i$.
2. Binary classification, where $\mathcal{R}_i(w) = -\frac{1}{L_i}\sum_{j=1}^{L_i}\big[y_i^j\log\sigma(w^\top x_i^j) + (1 - y_i^j)\log(1 - \sigma(w^\top x_i^j))\big]$ and $\sigma$ is the sigmoid function. Let $\mu_i^j \triangleq \sigma(w^\top x_i^j)$ denote the model's output. The gradient and Hessian are $\nabla_w\mathcal{R}_i(w) = \frac{1}{L_i}\sum_j(\mu_i^j - y_i^j)x_i^j = \frac{1}{L_i}X_i^\top(\mu_i - y_i)$ and $H_{\mathcal{R}_i}(w) = \frac{1}{L_i}X_i^\top S_i X_i$, where $S_i \triangleq \mathrm{diag}\big(\mu_i^1(1 - \mu_i^1), \dots, \mu_i^{L_i}(1 - \mu_i^{L_i})\big)$. Similar formulas can be derived for multi-class classification; see the textbook (Murphy, 2022) for the exact equations.

Remark. In practice, when the dimension of $w$ is larger than the number of samples of a certain domain, the Hessian may have small singular values, which cause numerical instability.
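For the linear-regression instance above, Equation 8 can be checked numerically. The pure-Python sketch below assumes each client's local updates reach the exact local optimum (valid here because $J_i = \mathcal{R}_i$ is quadratic), uses hypothetical toy data for one domain held by two clients, and uses hypothetical helper names (`solve`, `gram`, `local_optimum`, `hessian`). It shows that second-order aggregation reproduces the centralized least-squares head, while a plain weighted average of the local heads does not.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting (small dense A)."""
    n = len(b)
    aug = [list(row) + [rhs] for row, rhs in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if abs(aug[piv][col]) < 1e-12:
            raise ValueError("matrix is (numerically) singular")
        aug[col], aug[piv] = aug[piv], aug[col]
        for r in range(col + 1, n):
            f = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= f * aug[col][c]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        z[r] = (aug[r][n] - sum(aug[r][c] * z[c] for c in range(r + 1, n))) / aug[r][r]
    return z

def gram(X):
    """X^T X for a data matrix X given as a list of rows x_j^T."""
    d = len(X[0])
    return [[sum(row[a] * row[b] for row in X) for b in range(d)] for a in range(d)]

def local_optimum(X, y):
    """Exact minimizer of R_i(w) = (1/L_i) sum_j (w^T x_j - y_j)^2 (normal equations)."""
    d = len(X[0])
    xty = [sum(row[a] * t for row, t in zip(X, y)) for a in range(d)]
    return solve(gram(X), xty)

def hessian(X):
    """H_{R_i} = (2 / L_i) X_i^T X_i, which does not depend on w for the squared loss."""
    return [[2.0 / len(X) * g for g in row] for row in gram(X)]

def second_order_aggregate(heads, hessians, alphas):
    """Equation 8: w = H^{-1} sum_i alpha_i H_i w_i, with H = sum_i alpha_i H_i."""
    d = len(heads[0])
    H = [[sum(a * Hi[r][c] for a, Hi in zip(alphas, hessians)) for c in range(d)]
         for r in range(d)]
    rhs = [sum(a * sum(Hi[r][c] * wi[c] for c in range(d))
               for a, Hi, wi in zip(alphas, hessians, heads)) for r in range(d)]
    return solve(H, rhs)

def weighted_average(heads, alphas):
    """WA baseline: sum_i alpha_i w_i."""
    return [sum(a * wi[r] for a, wi in zip(alphas, heads)) for r in range(len(heads[0]))]

# Hypothetical toy data for one domain held by two clients (rows are x_j^T).
clients = [
    ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [1.0, 2.0, 2.0]),
    ([[1.0, 0.0], [1.0, 2.0]], [0.0, 3.0]),
]
L = sum(len(X) for X, _ in clients)
alphas = [len(X) / L for X, _ in clients]          # alpha_i = L_i / L
heads = [local_optimum(X, y) for X, y in clients]  # locally optimal heads w_i
hessians = [hessian(X) for X, _ in clients]
w_sa = second_order_aggregate(heads, hessians, alphas)  # [0.4, 22/15]
w_wa = weighted_average(heads, alphas)                  # [0.4, 1.6]
X_all = [row for X, _ in clients for row in X]
y_all = [t for _, y in clients for t in y]
w_central = local_optimum(X_all, y_all)                 # centralized least squares
```

If a domain has fewer samples than the dimension of $w$, `gram(X)` becomes singular and `solve` raises; nearly singular Hessians make the aggregation numerically fragile in the same way.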
To mitigate this issue, one can either directly set the representation dimension $k$ to a smaller number or add a (fully connected) projection layer on top of a pretrained encoder to compress the representations into a lower-dimensional space.

### 4.4 THEORETICAL RESULT OF FEDDAR

For the simplified linear regression setting of domain-mixed FL in (4) (cf. details in Appendix A), we give below the sample complexity required for an adapted version of our algorithm (Algorithm 2 in the appendix) to enjoy linear convergence. Due to space limits, we only provide an informal statement to highlight the result; the formal statement and the proof are deferred to the appendix.

Theorem 4.1 (Sample complexity of FedDAR convergence in the linear case, informal). Consider the linear setting for domain-mixed FL in (4). At each iteration, suppose that the number of samples used by each of the $n$ clients to update the encoder is $\Omega\big(\frac{dk^2}{n}\big)$, and that the aggregate number of samples used in the update of the domain-specific heads is $\Omega(k)$. Then, for a suitably chosen step size, the distance between the encoder $B^t$ output by Algorithm 2 and the true encoder $B^*$ converges to zero at a linear rate.

Remark. As our algorithm converges linearly to the true encoder, its per-iteration sample complexity gives a good estimate of its overall sample complexity. Since we expect the output of the encoder to be significantly lower-dimensional than the input (i.e., $k \ll d$), our result indicates that the sample complexity of Algorithm 2 is dominated by $\Omega\big(\frac{d}{n}\big)$, implying that the complexity decreases significantly as the number of clients $n$ increases. Moreover, a key implication of our result is the capacity of our algorithm to accommodate data imbalance across domains.
We note that our approach requires $\Omega(dk^2)$ samples per iteration for the update of the shared representation $B \in \mathbb{R}^{d\times k}$, while needing only $\Omega(k)$ samples per iteration for the update of each domain head. In particular, domains with more data can contribute disproportionately to the $\Omega(dk^2)$ samples required to learn the common representation, while domains with less data need only provide $\Omega(k)$ samples to update their heads during the course of the algorithm. Whenever $k \ll d$, which we believe is a reasonable assumption for many practical applications (e.g., medical imaging), the requirement of $\Omega(k)$ samples per domain is relatively mild. Conversely, forgoing the shared representation structure would require each domain to learn a separate $d$-dimensional classifier, requiring $\Omega(d)$ samples per domain, which can be challenging in problems with imbalanced domain data.

## 5 EXPERIMENTS

We validate our method's effectiveness on both synthetic and real datasets. We first experiment on exactly the synthetic setting described in our theoretical analysis to verify our theory. We then conduct experiments on a real dataset, FairFace (Kärkkäinen & Joo, 2019), with controlled domain distributions to investigate the robustness of our algorithm under different levels of heterogeneity. Finally, we compare our method with various baselines on a real federated learning benchmark, EXAM (Dayan et al., 2021), with real-world domain distributions, and conduct extensive ablation studies on it to discern the contribution of each component of our method. Full details of the experimental settings can be found in Appendix B.

### 5.1 SYNTHETIC DATA

We first run experiments on the linear regression problem analyzed in Appendix A. We generate (domain, data, label) samples as follows: $z_i \sim \mathcal{M}(\pi_i)$, $x_i \sim \mathcal{N}(0, I_d)$, $y_i \sim \mathcal{N}\big(w_{z_i}^{*\top}B^{*\top}x_i, \sigma\big)$, where $\sigma = 10^{-3}$ controls the label observation error and $\mathcal{M}(\pi_i)$ is a multinomial domain distribution with parameter $\pi_i = [\pi_{i,1}, \dots, \pi_{i,M}] \in \Delta^M$.
The parameters $\pi_i$ of the domain distributions are drawn from a Dirichlet distribution, i.e., $\pi_i \sim \mathrm{Dir}(\alpha p)$, where $p \in \Delta^M$ is a prior domain distribution over the $M$ domains and $\alpha > 0$ is a concentration parameter controlling the heterogeneity of domain distributions among clients. The largest heterogeneity is reached as $\alpha \to 0$, where each client contains data from only a single randomly selected domain. At the other extreme, as $\alpha \to \infty$, all clients have identical domain distributions equal to the prior $p$. We generate the ground-truth representation $B^* \in \mathbb{R}^{d\times k}$ and domain-specific heads $w_m^*$, $m \in [M]$, by sampling and normalizing Gaussian matrices.

Figure 1 shows the results of our experiments, where we set $n = 100$ clients, $M = 5$ domains, and representation dimension $k = 2$. We vary the number of training samples per client from 5 to 20. The results show that FedDAR-SA achieves errors four orders of magnitude smaller than all the baselines: (1) Local-Only, where each client trains a model using its own data; (2) FedAvg, which learns a single shared model; (3) FedRep, which learns a shared representation and client-specific heads; and (4) Separate FedAvg, which trains a separate model for each domain using FedAvg. These results demonstrate that our method overcomes the heterogeneity of domain distributions across clients. FedDAR-WA fails to converge in this setting, confirming the effectiveness of the proposed second-order aggregation.

### 5.2 REAL DATA WITH CONTROLLED DISTRIBUTION

Dataset and Model. We use FairFace (Kärkkäinen & Joo, 2019), a public face image dataset containing 7 race groups, which we treat as the domains. Each image is labeled with one of 9 age groups and a gender. We use the age label as the target to build a multi-class age classifier. We create an FL setting by dividing the training data among $n$ clients without duplication. Each client has a domain distribution $\pi_i \sim \mathrm{Dir}(\alpha p)$ sampled from a Dirichlet distribution.
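The Dirichlet mixing used in both experiments can be sketched as follows. Assumptions: $\mathrm{Dir}(\alpha p)$ is sampled by normalizing independent $\mathrm{Gamma}(\alpha p_m, 1)$ draws; the uniform prior over 7 domains and the helper names are hypothetical; and the sketch only plans how many of each client's $L_i$ samples come from each domain, leaving out the actual assignment of images without duplication.

```python
import random

def sample_dirichlet(concentration, rng):
    """Draw from Dir(concentration) by normalizing independent Gamma(c_m, 1) draws."""
    g = [rng.gammavariate(c, 1.0) for c in concentration]
    total = sum(g)
    if total == 0.0:
        # All draws underflowed (possible for very small alpha): put all mass on one
        # domain, which matches the alpha -> 0 limit described in the text.
        g = [0.0] * len(concentration)
        g[rng.randrange(len(concentration))] = 1.0
        total = 1.0
    return [v / total for v in g]

def partition_clients(n_clients, samples_per_client, alpha, prior, rng):
    """For each client: pi_i ~ Dir(alpha * p), then L_i domain draws z ~ M(pi_i)."""
    M = len(prior)
    plan = []
    for _ in range(n_clients):
        pi_i = sample_dirichlet([alpha * p for p in prior], rng)
        zs = rng.choices(range(M), weights=pi_i, k=samples_per_client)
        plan.append((pi_i, [zs.count(m) for m in range(M)]))  # per-domain counts
    return plan

rng = random.Random(0)
prior = [1.0 / 7] * 7  # hypothetical uniform prior p over 7 domains
plan = partition_clients(5, 500, alpha=0.5, prior=prior, rng=rng)
plan_near_iid = partition_clients(5, 500, alpha=100.0, prior=prior, rng=rng)
```

With a large $\alpha$ every $\pi_i$ stays close to $p$, while a small $\alpha$ concentrates each client's mass on a few domains.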
The total number of samples at each client, $L_i = 500$, is the same in all experiments. We control the heterogeneity of domain distributions by varying $\alpha$. The label distributions are uniform across all clients.

Figure 1: Performance under different numbers of training samples per client (x-axis: number of training samples per user, 5, 10, 20; y-axis: average MSE; methods: Local Only, FedAvg, FedRep, Sep_FedAvg, FedDAR-SA). The error bars show the standard error from three independent runs.

Figure 2: Age classification accuracy as a function of the representation dimension $k$ (x-axis: $k \in \{4, 8, 16, 32, 64\}$; y-axis: average domain test accuracy; methods: FedDAR-WA, FedDAR-SA).

Table 1: Min, max, and average test accuracy (%) of age classification across the 7 domains (race groups) on FairFace, with $n = 5$ clients and $L_i = 500$ samples per client.

| Method | α=0.1 Max | α=0.1 Min | α=0.1 Avg | α=0.5 Max | α=0.5 Min | α=0.5 Avg | α=1 Max | α=1 Min | α=1 Avg | α=100 Max | α=100 Min | α=100 Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Separate FedAvg | 41.8 ± 2.2 | 4.1 ± 1.3 | 19.3 ± 0.5 | 42.4 ± 0.6 | 9.7 ± 0.6 | 23.5 ± 1.2 | 41.4 ± 0.7 | 9.3 ± 0.8 | 23.4 ± 0.8 | 41.4 ± 0.4 | 8.5 ± 1.6 | 22.9 ± 1.0 |
| FedAvg (w/o reweighting) | 43.4 ± 1.7 | 37.1 ± 1.1 | 40.0 ± 0.6 | 45.7 ± 0.5 | 38.8 ± 0.2 | 41.6 ± 0.3 | 45.2 ± 0.8 | 38.4 ± 0.8 | 41.1 ± 0.4 | 44.0 ± 0.7 | 38.6 ± 0.7 | 40.8 ± 0.7 |
| FedAvg (w/ reweighting) | 42.8 ± 0.6 | 37.7 ± 0.8 | 39.9 ± 0.2 | 45.8 ± 0.8 | 39.2 ± 0.2 | 41.6 ± 0.3 | 44.9 ± 0.5 | 38.3 ± 0.5 | 40.9 ± 0.3 | 43.3 ± 1.3 | 38.4 ± 1.5 | 40.5 ± 1.3 |
| FedAvg + Multi-head | 46.8 ± 0.9 | 32.4 ± 2.8 | 39.8 ± 1.0 | 49.1 ± 0.1 | 34.9 ± 1.4 | 40.0 ± 1.0 | 51.1 ± 0.4 | 34.7 ± 0.2 | 40.3 ± 0.6 | 49.6 ± 0.3 | 36.4 ± 0.9 | 39.8 ± 0.7 |
| FedDAR-WA | 46.5 ± 1.9 | 34.0 ± 1.9 | 40.5 ± 0.3 | 49.9 ± 1.2 | 40.0 ± 0.2 | 42.9 ± 0.2 | 49.2 ± 0.3 | 40.0 ± 0.6 | 42.7 ± 0.5 | 48.6 ± 0.5 | 40.2 ± 0.5 | 42.5 ± 0.4 |
| FedDAR-SA | 48.2 ± 1.0 | 39.3 ± 0.8 | 42.4 ± 0.4 | 50.0 ± 0.3 | 40.6 ± 0.4 | 43.4 ± 0.1 | 49.0 ± 0.5 | 41.2 ± 0.3 | 43.7 ± 0.5 | 49.0 ± 0.3 | 41.2 ± 0.4 | 43.6 ± 0.5 |

Implementation and Evaluation. We use an ImageNet (Deng et al., 2009) pre-trained ResNet-34 (He et al., 2016) for all experiments on this dataset. All methods are trained for $T = 100$ communication rounds.
We use the Adam optimizer with a learning rate of $1\times10^{-4}$ for the first 60 rounds and $1\times10^{-5}$ for the last 40 rounds.

Metrics and Results. Our evaluation metric is the classification accuracy for each race group on the whole FairFace validation set. We do not use an extra local validation set for each client, since we assume that the data distribution within each domain is consistent across clients. In Table 1, we report the accuracy averaged over the final 10 communication rounds, following common practice (Collins et al., 2021). The results show that FedDAR achieves the best performance among all compared methods. Note that FedAvg + Multi-head also uses Equation 5 as its objective, for a fair comparison.

Effect of k. A limitation of using FedDAR-SA instead of FedDAR-WA is the need to tune the representation dimension $k$. Figure 2 shows the average domain test accuracy for different values of $k$: FedDAR-SA achieves better accuracy with a properly chosen $k$. We use $k = 8$ for all FedDAR-SA results in Table 1.

Robustness to Varying Levels of Heterogeneity. Across the values of $\alpha$, the performance of FedDAR-SA remains stable regardless of how heterogeneous the domain mixtures are, whereas the baselines' accuracy decreases as $\alpha$ becomes smaller.

### 5.3 REAL DATA WITH REAL-WORLD DATA DISTRIBUTION

Dataset and Model. We use the EXAM dataset (Dayan et al., 2021) from a large-scale, real-world healthcare FL study. We use part of the dataset, comprising 6 clients with a total of 7,681 cases, and treat race groups as domains. The data were collected from suspected COVID-19 patients at emergency department (ED) visits and include both chest X-rays (CXR) and electronic medical records (EMR). We adopt the same data preprocessing procedure and model as Dayan et al. (2021). Our task is to predict whether the patient received oxygen therapy higher than high-flow oxygen within 72 hours, which indicates severe symptoms.
Baselines. We compare against (1) methods that learn one global model, namely FedAvg (McMahan et al., 2017a), FedProx (Li et al., 2020), and FedMinMax (Papadaki et al., 2021), along with their locally fine-tuned variants; (2) training M separate models with FedAvg; (3) training one global model with FedAvg first, then fine-tuning it on the M domains separately with FedAvg; and (4) client-wise personalized FL approaches: FedRep (Collins et al., 2021), FedPer (Arivazhagan et al., 2019), LG-FedAvg (Liang et al., 2020), and FedBN (Li et al., 2021b).

Implementation and Evaluation. We apply 5-fold cross-validation. All models are trained for T = 20 communication rounds with the Adam optimizer and a learning rate of $10^{-4}$. The models are evaluated by aggregating predictions on the local validation sets and then calculating the area under the curve (AUC) for each domain. We also report the AUCs averaged over the clients' local validation sets.

Table 3: AUC results on the EXAM dataset with race group as the domain. Numbers are the means and standard deviations of the metrics from 5-fold cross-validation.
| Method | White | Black | Asian | Latino | Other | Min | Avg | Client Avg |
|---|---|---|---|---|---|---|---|---|
| Local | .761±.023 | .815±.055 | .838±.039 | .889±.076 | .840±.038 | .759±.026 | .829±.032 | .795±.023 |
| Separate FedAvg | .796±.022 | .694±.015 | .788±.047 | .649±.133 | .826±.046 | .606±.080 | .751±.026 | .759±.027 |
| FedAvg | .830±.027 | .854±.045 | .887±.022 | .834±.102 | .900±.038 | .773±.049 | .861±.019 | .856±.020 |
| FedAvg + FT | .783±.044 | .835±.025 | .892±.015 | .817±.136 | .892±.048 | .727±.093 | .844±.024 | .845±.016 |
| FedAvg + separate FT | .832±.032 | .846±.043 | .903±.025 | .869±.099 | .911±.026 | .784±.054 | .872±.017 | .863±.024 |
| FedProx | .834±.017 | .864±.056 | .903±.035 | .880±.085 | .912±.030 | .808±.030 | .879±.023 | .868±.012 |
| FedProx + FT | .806±.023 | .842±.049 | .910±.025 | .925±.085 | .898±.031 | .798±.025 | .876±.010 | .858±.014 |
| FedMinMax | .839±.027 | .867±.054 | .894±.039 | .916±.053 | .903±.034 | .823±.032 | .884±.020 | .872±.016 |
| FedBN | .787±.027 | .840±.063 | .883±.039 | .867±.090 | .852±.043 | .766±.013 | .846±.020 | .856±.010 |
| FedRep | .837±.020 | .869±.050 | .888±.042 | .913±.083 | .910±.028 | .812±.028 | .884±.025 | .867±.013 |
| FedPer | .835±.025 | .865±.073 | .909±.037 | .916±.036 | .911±.031 | .813±.047 | .887±.021 | .873±.011 |
| LG-FedAvg | .830±.029 | .858±.052 | .906±.032 | .902±.050 | .903±.033 | .814±.034 | .880±.019 | .867±.017 |
| FedDAR-WA | .884±.007 | .896±.017 | .902±.034 | .952±.041 | .928±.022 | .872±.015 | .912±.004 | .898±.006 |
| FedDAR-SA | .888±.004 | .895±.038 | .928±.032 | .939±.046 | .948±.016 | .868±.020 | .919±.014 | .912±.001 |

Average Performance Across Domains and Clients. Table 3 shows the average AUCs across domains and across clients. Both of our methods, FedDAR-WA and FedDAR-SA, achieve significantly better performance than all baselines under both the domain-wise and the client-wise metrics. The gap between our domain-wise personalized approach and the client-wise personalized baselines demonstrates the value of learning domain-wise personalized models when the data are diverse across domains. Fine-tuning methods yield worse results mainly because of the imbalanced label distribution.
Each local training dataset does not have enough positive cases for proper fine-tuning.

Fairness Across Domains. The per-domain AUCs in Table 3 show that our proposed FedDAR method uniformly increases the AUC for each domain. The column of minimum AUC among domains also verifies that our method indeed improves fairness across domains.

Table 2: Ablation results of the contribution of different components in FedDAR: re-weighting (RW), multi-head (MH), domain as input feature (DI), alternating update (Alter), and projection (Proj); AGG is the aggregation method.

| AGG | Domain Avg / Min | Client Avg |
|---|---|---|
| N/A | .861 / .773 | .856 |
| N/A | .881 / .824 | .873 |
| N/A | .880 / .825 | .866 |
| WA | .885 / .834 | .870 |
| WA | .877 / .817 | .870 |
| SA | .878 / .826 | .871 |
| N/A | .867 / .806 | .852 |
| WA | .912 / .872 | .898 |
| WA | .918 / .863 | .904 |
| SA | .919 / .868 | .912 |

Ablation Studies. i) Re-weighting (RW): The first two rows in Table 2 show that adding sample re-weighting significantly improves fairness across domains; the minimum AUC among domains improves by a large margin (> 0.05). ii) Multi-head (MH), domain as input feature (DI), and alternating update (Alter): Comparing the three blocks in Table 2, we see that adding multi-head alone does not improve results. We conjecture that alternating update prevents the heads from overfitting when samples are limited. This is also reflected in Table 1, where FedAvg + Multi-head tends to perform poorly on certain underrepresented domains, especially when domain distributions are highly heterogeneous (small α). Meanwhile, using domain labels directly as input features is not as effective as multi-head, and is not compatible with alternating update. iii) Projection (Proj) and aggregation method (AGG): The results in Table 2 show that second-order aggregation with projection of the features gives the best result.

6 CONCLUSIONS

We propose a novel personalized federated learning framework that assumes each client's data distribution is a mixture of domains.
Our approach, FedDAR, achieves balanced performance across domains by learning a global representation and domain-specific heads, despite the heterogeneity of domain distributions across clients. Its effectiveness is supported by both theoretical and empirical justifications. It has been tested on facial age classification and medical imaging FL datasets and can be easily extended to other complicated tasks. However, our method has some limitations: i) it requires domain information for all samples; ii) it does not consider heterogeneity of label distributions; iii) it can incur an expensive communication cost from sending Hessian matrices, especially when the output dimension is large. We plan to address these limitations in future work, along with other research directions such as improving fairness across domains and exploring settings where domains are structured, hierarchical, continuously indexed (Wang et al., 2020; Nasery et al., 2021) or multi-dimensional (characterized by multiple factors).

ACKNOWLEDGMENTS

This work has been supported by NIH 1R01HL159183.

REFERENCES

Manoj Ghuhan Arivazhagan, Vinay Aggarwal, Aaditya Kumar Singh, and Sunav Choudhary. Federated learning with personalization layers. arXiv preprint arXiv:1912.00818, 2019.
Arthur Asuncion and David Newman. UCI machine learning repository, 2007.
Fei Chen, Zhenhua Dong, Zhenguo Li, and Xiuqiang He. Federated meta-learning for recommendation. arXiv preprint arXiv:1802.07876, 2018.
Hong-You Chen and Wei-Lun Chao. On bridging generic and personalized federated learning for image classification. In International Conference on Learning Representations, 2021.
Jung Hee Cheon, Andrey Kim, Miran Kim, and Yongsoo Song. Homomorphic encryption for arithmetic of approximate numbers. In International Conference on the Theory and Application of Cryptology and Information Security, pp. 409-437. Springer, 2017.
Lingyang Chu, Lanjun Wang, Yanjie Dong, Jian Pei, Zirui Zhou, and Yong Zhang. FedFair: Training fair models in cross-silo federated learning. arXiv preprint arXiv:2109.05662, 2021.
Francesco Ciompi, Oscar Geessink, Babak Ehteshami Bejnordi, Gabriel Silva De Souza, Alexi Baidoshvili, Geert Litjens, Bram Van Ginneken, Iris Nagtegaal, and Jeroen Van Der Laak. The importance of stain normalization in colorectal tissue classification with convolutional networks. In 2017 IEEE 14th International Symposium on Biomedical Imaging (ISBI 2017), pp. 160-163. IEEE, 2017.
Liam Collins, Hamed Hassani, Aryan Mokhtari, and Sanjay Shakkottai. Exploiting shared representations for personalized federated learning. In International Conference on Machine Learning, pp. 2089-2099. PMLR, 2021.
Luca Corinzia, Ami Beuret, and Joachim M Buhmann. Variational federated multi-task learning. arXiv preprint arXiv:1906.06268, 2019.
Sen Cui, Weishen Pan, Jian Liang, Changshui Zhang, and Fei Wang. Addressing algorithmic disparity and performance inconsistency in federated learning. Advances in Neural Information Processing Systems, 34, 2021.
Ittai Dayan, Holger R Roth, Aoxiao Zhong, Ahmed Harouni, Amilcare Gentili, Anas Z Abidin, Andrew Liu, Anthony Beardsworth Costa, Bradford J Wood, Chien-Sung Tsai, et al. Federated learning for predicting clinical outcomes in patients with COVID-19. Nature Medicine, 27(10):1735-1743, 2021.
Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248-255. IEEE, 2009.
Yuyang Deng, Mohammad Mahdi Kamani, and Mehrdad Mahdavi. Adaptive personalized federated learning. arXiv preprint arXiv:2003.13461, 2020.
Frances Ding, Moritz Hardt, John Miller, and Ludwig Schmidt. Retiring adult: New datasets for fair machine learning. Advances in Neural Information Processing Systems, 34:6478-6490, 2021.
Wei Du, Depeng Xu, Xintao Wu, and Hanghang Tong. Fairness-aware agnostic federated learning. In Proceedings of the 2021 SIAM International Conference on Data Mining (SDM), pp. 181-189. SIAM, 2021.
Alireza Fallah, Aryan Mokhtari, and Asuman Ozdaglar. Personalized federated learning: A meta-learning approach. arXiv preprint arXiv:2002.07948, 2020.
Borja Rodríguez Gálvez, Filip Granqvist, Rogier van Dalen, and Matt Seigel. Enforcing fairness in private federated learning via the modified method of differential multipliers. In NeurIPS 2021 Workshop Privacy in Machine Learning, 2021.
Yaroslav Ganin and Victor Lempitsky. Unsupervised domain adaptation by backpropagation. In International Conference on Machine Learning, pp. 1180-1189. PMLR, 2015.
Lidia Garrucho, Kaisar Kushibar, Socayna Jouide, Oliver Diaz, Laura Igual, and Karim Lekadir. Domain generalization in deep learning-based mass detection in mammography: A large-scale multi-center study. arXiv preprint arXiv:2201.11620, 2022.
Avishek Ghosh, Jichan Chung, Dong Yin, and Kannan Ramchandran. An efficient framework for clustered federated learning. Advances in Neural Information Processing Systems, 33:19586-19597, 2020.
Gene H Golub and Charles F Van Loan. Matrix computations. JHU Press, 2013.
Hao Guan and Mingxia Liu. Domain adaptation for medical image analysis: a survey. IEEE Transactions on Biomedical Engineering, 69(3):1173-1185, 2021.
Filip Hanzely and Peter Richtárik. Federated learning of a mixture of global and local models. arXiv preprint arXiv:2002.05516, 2020.
Filip Hanzely, Slavomír Hanzely, Samuel Horváth, and Peter Richtárik. Lower bounds and optimal algorithms for personalized federated learning. Advances in Neural Information Processing Systems, 33:2304-2315, 2020.
Moritz Hardt, Eric Price, and Nati Srebro. Equality of opportunity in supervised learning. Advances in Neural Information Processing Systems, 29, 2016.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778, 2016.
Yutao Huang, Lingyang Chu, Zirui Zhou, Lanjun Wang, Jiangchuan Liu, Jian Pei, and Yong Zhang. Personalized cross-silo federated learning on non-IID data. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pp. 7865-7873, 2021.
Jonathan J. Hull. A database for handwritten text recognition research. IEEE Transactions on Pattern Analysis and Machine Intelligence, 16(5):550-554, 1994.
Heinrich Jiang and Ofir Nachum. Identifying and correcting label bias in machine learning. In International Conference on Artificial Intelligence and Statistics, pp. 702-712. PMLR, 2020.
Yihan Jiang, Jakub Konečný, Keith Rush, and Sreeram Kannan. Improving federated learning personalization via model agnostic meta learning. arXiv preprint arXiv:1909.12488, 2019.
Peter Kairouz, H Brendan McMahan, Brendan Avent, Aurélien Bellet, Mehdi Bennis, Arjun Nitin Bhagoji, Kallista Bonawitz, Zachary Charles, Graham Cormode, Rachel Cummings, et al. Advances and open problems in federated learning. arXiv preprint arXiv:1912.04977, 2019.
Peter Kairouz, Brendan McMahan, Shuang Song, Om Thakkar, Abhradeep Thakurta, and Zheng Xu. Practical and private (deep) learning without sampling or shuffling. In International Conference on Machine Learning, pp. 5213-5225. PMLR, 2021.
Kimmo Kärkkäinen and Jungseock Joo. FairFace: Face attribute dataset for balanced race, gender, and age. arXiv preprint arXiv:1908.04913, 2019.
Mikhail Khodak, Maria-Florina Balcan, and Ameet Talwalkar. Adaptive gradient-based meta-learning methods. arXiv preprint arXiv:1906.02717, 2019.
Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.
Tian Li, Maziar Sanjabi, Ahmad Beirami, and Virginia Smith. Fair resource allocation in federated learning. arXiv preprint arXiv:1905.10497, 2019.
Tian Li, Anit Kumar Sahu, Ameet Talwalkar, and Virginia Smith. Federated learning: Challenges, methods, and future directions. IEEE Signal Processing Magazine, 37(3):50-60, 2020.
Tian Li, Shengyuan Hu, Ahmad Beirami, and Virginia Smith. Ditto: Fair and robust federated learning through personalization. In International Conference on Machine Learning, pp. 6357-6368. PMLR, 2021a.
Xiaoxiao Li, Meirui Jiang, Xiaofei Zhang, Michael Kamp, and Qi Dou. FedBN: Federated learning on non-IID features via local batch normalization. In International Conference on Learning Representations, 2021b. URL https://openreview.net/forum?id=6YEQUn0QICG.
Paul Pu Liang, Terrance Liu, Liu Ziyin, Nicholas B Allen, Randy P Auerbach, David Brent, Ruslan Salakhutdinov, and Louis-Philippe Morency. Think locally, act globally: Federated learning with local and global representations. arXiv preprint arXiv:2001.01523, 2020.
Zhengquan Luo, Yunlong Wang, Zilei Wang, Zhenan Sun, and Tieniu Tan. Disentangled federated learning for tackling attributes skew via invariant aggregation and diversity transferring. arXiv preprint arXiv:2206.06818, 2022.
Yishay Mansour, Mehryar Mohri, Jae Ro, and Ananda Theertha Suresh. Three approaches for personalization with applications to federated learning. arXiv preprint arXiv:2002.10619, 2020.
Othmane Marfoq, Giovanni Neglia, Laetitia Kameni, and Richard Vidal. Personalized federated learning through local memorization. arXiv preprint arXiv:2111.09360, 2021.
Gustav Mårtensson, Daniel Ferreira, Tobias Granberg, Lena Cavallin, Ketil Oppedal, Alessandro Padovani, Irena Rektorova, Laura Bonanni, Matteo Pardini, Milica G Kramberger, et al. The reliability of a deep learning model in clinical out-of-distribution MRI data: a multicohort study. Medical Image Analysis, 66:101714, 2020.
Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pp. 1273-1282. PMLR, 2017a.
H Brendan McMahan, Daniel Ramage, Kunal Talwar, and Li Zhang. Learning differentially private recurrent language models. arXiv preprint arXiv:1710.06963, 2017b.
Aditya Krishna Menon and Robert C Williamson. The cost of fairness in binary classification. In Conference on Fairness, Accountability and Transparency, pp. 107-118. PMLR, 2018.
Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In International Conference on Machine Learning, pp. 4615-4625. PMLR, 2019.
Kevin P Murphy. Probabilistic machine learning: an introduction. MIT Press, 2022.
Anshul Nasery, Soumyadeep Thakur, Vihari Piratla, Abir De, and Sunita Sarawagi. Training for the future: A simple gradient interpolation loss to generalize along time. Advances in Neural Information Processing Systems, 34, 2021.
Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y Ng. Reading digits in natural images with unsupervised feature learning. 2011.
NHS. Health survey for England - 2004, health of ethnic minorities, 2004. URL https://digital.nhs.uk/data-and-information/publications/statistical/health-survey-for-england/health-survey-for-england-2004-health-of-ethnic-minorities-main-report.
Afroditi Papadaki, Natalia Martinez, Martin Bertran, Guillermo Sapiro, and Miguel Rodrigues. Federating for learning group fair models. arXiv preprint arXiv:2110.01999, 2021.
Meghna Ranganathan and Raj Bhopal. Exclusion and inclusion of nonwhite ethnic minority groups in 72 North American and European cardiovascular cohort studies. PLoS Medicine, 3(3):e44, 2006.
Nicola Rieke, Jonny Hancox, Wenqi Li, Fausto Milletari, Holger R Roth, Shadi Albarqouni, Spyridon Bakas, Mathieu N Galtier, Bennett A Landman, Klaus Maier-Hein, et al. The future of digital health with federated learning. NPJ Digital Medicine, 3(1):1-7, 2020.
Yuji Roh, Kangwook Lee, Steven Euijong Whang, and Changho Suh. FairBatch: Batch selection for model fairness. arXiv preprint arXiv:2012.01696, 2020.
Felix Sattler, Klaus-Robert Müller, and Wojciech Samek. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints. IEEE Transactions on Neural Networks and Learning Systems, 32(8):3710-3722, 2020.
Aviv Shamsian, Aviv Navon, Ethan Fetaya, and Gal Chechik. Personalized federated learning using hypernetworks. In International Conference on Machine Learning, pp. 9489-9502. PMLR, 2021.
Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, and Ameet S Talwalkar. Federated multi-task learning. Advances in Neural Information Processing Systems, 30, 2017.
Ala Szczepura. Access to health care for ethnic minority populations. Postgraduate Medical Journal, 81(953):141-147, 2005.
Canh T Dinh, Nguyen Tran, and Josh Nguyen. Personalized federated learning with Moreau envelopes. Advances in Neural Information Processing Systems, 33:21394-21405, 2020.
Nilesh Tripuraneni, Chi Jin, and Michael Jordan. Provable meta-learning of linear representations. In International Conference on Machine Learning, pp. 10434-10443. PMLR, 2021.
Paul Vanhaesebrouck, Aurélien Bellet, and Marc Tommasi. Decentralized collaborative learning of personalized models over networks. In Artificial Intelligence and Statistics, pp. 509-517. PMLR, 2017.
Roman Vershynin. High-dimensional probability: An introduction with applications in data science, volume 47. Cambridge University Press, 2018.
Hao Wang, Hao He, and Dina Katabi. Continuously indexed domain adaptation. arXiv preprint arXiv:2007.01807, 2020.
Kangkang Wang, Rajiv Mathews, Chloé Kiddon, Hubert Eichner, Françoise Beaufays, and Daniel Ramage. Federated evaluation of on-device personalization. arXiv preprint arXiv:1910.10252, 2019.
Michael Wick, Jean-Baptiste Tristan, et al. Unlocking fairness: a trade-off revisited. Advances in Neural Information Processing Systems, 32, 2019.
An Xu, Wenqi Li, Pengfei Guo, Dong Yang, Holger Roth, Ali Hatamizadeh, Can Zhao, Daguang Xu, Heng Huang, and Ziyue Xu. Closing the generalization gap of cross-silo federated medical image segmentation. arXiv preprint arXiv:2203.10144, 2022.
Wenjun Yan, Yuanyuan Wang, Shengjia Gu, Lu Huang, Fuhua Yan, Liming Xia, and Qian Tao. The domain shift problem of medical image segmentation and vendor-adaptation by Unet-GAN. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 623-631. Springer, 2019.
Tao Yu, Eugene Bagdasaryan, and Vitaly Shmatikov. Salvaging federated learning by local adaptation. arXiv preprint arXiv:2002.04758, 2020.
Xubo Yue, Maher Nouiehed, and Raed Al Kontar. GIFAIR-FL: An approach for group and individual fairness in federated learning. arXiv preprint arXiv:2108.02741, 2021.
Muhammad Bilal Zafar, Isabel Valera, Manuel Gomez Rodriguez, and Krishna P Gummadi. Fairness constraints: Mechanisms for fair classification. In Artificial Intelligence and Statistics, pp. 962-970. PMLR, 2017.
Valentina Zantedeschi, Aurélien Bellet, and Marc Tommasi. Fully decentralized joint learning of personalized models and collaboration graphs. In International Conference on Artificial Intelligence and Statistics, pp. 864-874. PMLR, 2020.
Rich Zemel, Yu Wu, Kevin Swersky, Toni Pitassi, and Cynthia Dwork. Learning fair representations. In International Conference on Machine Learning, pp. 325-333. PMLR, 2013.
Yuchen Zeng, Hongxu Chen, and Kangwook Lee. Improving fairness via federated learning. arXiv preprint arXiv:2110.15545, 2021.
Daniel Yue Zhang, Ziyi Kou, and Dong Wang. FairFL: A fair federated learning approach to reducing demographic bias in privacy-sensitive classification models. In 2020 IEEE International Conference on Big Data (Big Data), pp. 1051-1060. IEEE, 2020.
Han Zhao and Geoff Gordon. Inherent tradeoffs in learning fair representations. Advances in Neural Information Processing Systems, 32, 2019.
Chen Zhu, Zheng Xu, Mingqing Chen, Jakub Konečný, Andrew Hard, and Tom Goldstein. Diurnal or nocturnal? Federated learning of multi-branch networks from periodically shifting distributions. In International Conference on Learning Representations, 2021.
Ligeng Zhu, Zhijian Liu, and Song Han. Deep leakage from gradients. Advances in Neural Information Processing Systems, 32, 2019.

A FEDDAR FOR LINEAR REPRESENTATION

We retain the setup for linear regression considered at the start of Section 3.1. We additionally define $W^* := [w^*_1, \dots, w^*_M]^\top \in \mathbb{R}^{M\times k}$ as the concatenation of the domain-specific heads. For notational convenience, we let $(x_{i,m}, y_{i,m})$ denote an (input, output) sample coming from client $i$ and the $m$-th domain. To measure the distance between two matrices $A, B$ with the same dimensions, we use the principal angle distance (Golub & Van Loan, 2013), given by $\mathrm{dist}(A, B) := \|A_\perp^\top B\|_2$, where $A_\perp$ denotes a matrix whose columns form an orthonormal basis for the orthogonal complement of the range of $A$. To simplify the analysis, we further make the following assumptions.

Assumption A.1 (Sub-Gaussianity). For each $m \in [M]$ and $i \in [n]$, the samples $x_{i,m} \in \mathbb{R}^d$ are independent, mean zero, have covariance $I_d$, and have sub-Gaussian norm 1, i.e., for every $v \in \mathbb{R}^d$, $\mathbb{E}[\exp(v^\top x_{i,m})] \le \exp(\|v\|_2^2/2)$.

Assumption A.2 (Domain diversity). Let $\sigma_{\min,*} := \sigma_{\min}\big(\tfrac{1}{\sqrt{M}} W^*\big)$, i.e., $\sigma_{\min,*}$ is the minimum singular value of the (scaled) head matrix. Then $\sigma_{\min,*} > 0$.

Assumption A.3 (Ground truth normalization). The true domain parameters satisfy 1 k for each $m \in [M]$, and $B^*$ has orthonormal columns.
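The principal angle distance can be computed without forming $A_\perp$ explicitly: since $A_\perp A_\perp^\top = I - P_A$, where $P_A$ is the orthogonal projector onto the range of $A$, we have $\|A_\perp^\top \hat B\|_2 = \|(I - P_A)\hat B\|_2$ for any $\hat B$ with orthonormal columns spanning the range of $B$; this value is the sine of the largest principal angle between the two subspaces. The dependency-free sketch below (matrices as lists of column vectors, Gram-Schmidt for the bases, power iteration for the spectral norm) is illustrative only; in practice one would use a linear-algebra library's QR and SVD, and the function names are hypothetical.

```python
import math

def orthonormal_basis(cols):
    """Gram-Schmidt on column vectors; (near-)dependent columns are dropped."""
    basis = []
    for col in cols:
        w = [float(a) for a in col]
        for q in basis:
            proj = sum(a * b for a, b in zip(q, w))
            w = [a - proj * b for a, b in zip(w, q)]
        norm = math.sqrt(sum(a * a for a in w))
        if norm > 1e-12:
            basis.append([a / norm for a in w])
    return basis

def principal_angle_distance(a_cols, b_cols, iters=200):
    """dist(A, B) = ||A_perp^T B_hat||_2 = ||(I - P_A) B_hat||_2, with B_hat an orthonormal basis of range(B)."""
    qa = orthonormal_basis(a_cols)
    qb = orthonormal_basis(b_cols)
    # Columns of (I - P_A) B_hat: remove from each B_hat column its component in range(A).
    resid = []
    for q in qb:
        r = list(q)
        for u in qa:
            c = sum(x * y for x, y in zip(u, q))
            r = [x - c * y for x, y in zip(r, u)]
        resid.append(r)
    k = len(resid)
    gram = [[sum(x * y for x, y in zip(resid[i], resid[j])) for j in range(k)] for i in range(k)]
    # Largest eigenvalue of the PSD k x k Gram matrix R^T R by power iteration;
    # the spectral norm of R is its square root.
    v = [1.0 / (i + 1) for i in range(k)]
    lam = 0.0
    for _ in range(iters):
        w = [sum(gram[i][j] * v[j] for j in range(k)) for i in range(k)]
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0  # range(B) lies inside range(A)
        v = [x / norm for x in w]
        gv = [sum(gram[i][j] * v[j] for j in range(k)) for i in range(k)]
        lam = sum(a * b for a, b in zip(v, gv))
    return math.sqrt(max(lam, 0.0))
```

For example, the planes spanned by $\{e_1, e_2\}$ and $\{e_1, (0, \cos\theta, \sin\theta)\}$ in $\mathbb{R}^3$ share one direction and are tilted by $\theta$ in the other, so their distance is $\sin\theta$; any basis of the same subspace gives distance 0.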
All the above assumptions aim to simplify the theoretical analysis while imposing only mild constraints on the data distribution and the parameters of the target functions. Similar assumptions have been adopted in prior work (Collins et al., 2021).

A.2 FEDDAR ADAPTED TO LINEAR REGRESSION

We analyze an adapted version of our FedDAR algorithm. Since the linear regression problem has an analytic solution, to ease the analysis we update the heads $\{w_m\}_{m=1}^M$ at the server in closed form using local gradient information. Meanwhile, we update the representation $B$ by taking a step along the averaged local gradients. Algorithm 2 shows the procedure of this adapted version. The local objective for the $i$-th client in the $m$-th domain at the $t$-th iteration, $f^t_{i,m}(w_m, B^t)$, is defined as
$$f^t_{i,m}(w_m, B^t) := \frac{1}{2}\sum_{j=1}^{L^t_{i,m}} \big(y^j_{i,m} - w_m^\top (B^t)^\top x^j_{i,m}\big)^2,$$
where $L^t_{i,m}$ is the number of samples from domain $m$ at client $i$. We assume that in each iteration the data points $\{(x^j_{i,m}, y^j_{i,m})\}_{j\in[L^t_{i,m}]}$ are all newly sampled from the distribution. We denote $L := \sum_m L^t_{i,m}$. Since the objective is quadratic, its gradient with respect to either $w_m$ or $B$ is linear, of the form $A_{i,m} w_m - a_{i,m}$ or $C_{i,m} B - c_{i,m}$, which we write down explicitly in Appendix B. After every global update of the representation $B$, we apply an additional QR decomposition to normalize it to have orthonormal columns.

A.3 CONVERGENCE ANALYSIS

We first provide a brief proof sketch. Overall, our approach largely follows that of Collins et al. (2021), with a few differences needed to handle the spreading of a domain's data across different clients. We also tighten the analysis compared with Collins et al. (2021), so that each domain needs only $O(k)$ samples, as opposed to $O(k^2)$ samples in Collins et al. (2021) (where the requirement is for each client to have $O(k^2)$ samples, since they consider the case where each client has a separate head).
This can yield a significant improvement when $k$ is moderately large and there is data imbalance.

Algorithm 2: FedDAR for linear regression
Input: step size $\eta$; number of rounds $T$.
Client initialization: each agent $i \in [n]$ collects $L_0$ samples and sends $Z_i := \sum_{j=1}^{L_0} (y^{0,j}_i)^2\, x^{0,j}_i (x^{0,j}_i)^\top$ to the server.
Server initialization: finds $U D U^\top \leftarrow$ rank-$k$ SVD$\big(\frac{1}{nL_0}\sum_{i=1}^n Z_i\big)$; sets $B^0 \leftarrow U$.
for $t = 0, 1, \dots, T$ do
    Server sends the current $B^t$ to the clients.
    Client computation for $W^{t+1}$:
    for client $i \in [n]$ do
        Selects $L$ new samples $\{(x^j_i, y^j_i)\}$.
        Computes $\nabla_{w_m} f^t_{i,m}(w_m, B^t) = A^t_{i,m} w_m - a^t_{i,m}$ for each domain $m \in [M]$.
        Sends $(A^t_{i,m}, a^t_{i,m}, L^t_{i,m})$ back to the server.
    end for
    Server update for $W^{t+1}$:
    Server chooses $w^{t+1}_m \in \big\{w_m \in \mathbb{R}^k : \nabla_{w_m} \frac{1}{\sum_i L^t_{i,m}} \sum_{i=1}^n f^t_{i,m}(w_m, B^t) = 0\big\}$ for all $m \in [M]$, i.e., $w^{t+1}_m$ satisfies $\big(\sum_i A^t_{i,m}\big) w^{t+1}_m = \sum_i a^t_{i,m}$.
    Sends $W^{t+1} = [w^{t+1}_1, \dots, w^{t+1}_M]^\top \in \mathbb{R}^{M\times k}$ to the clients.
    Client computation for $B^{t+1}$:
    for client $i \in [n]$ do
        Selects $L$ new samples $\{(x^j_i, y^j_i)\}$.
        Computes $\nabla_B f^t_{i,m}(w^{t+1}_m, B^t) = C^t_{i,m} B^t - c^t_{i,m}$ for each $m \in [M]$.
        Sends $\big(\nabla_B f^t_{i,m}(w^{t+1}_m, B^t), L^t_{i,m}\big)$ back to the server.
    end for
    Server update for $B^{t+1}$:
    Server computes $\bar{B}^{t+1} \leftarrow B^t - \eta\,\frac{1}{M}\sum_{m=1}^M \frac{1}{\sum_i L^t_{i,m}} \sum_{i=1}^n \nabla_B f^t_{i,m}(w^{t+1}_m, B^t)$.
    Server performs the QR decomposition $\hat{B}^{t+1}, R^{t+1} \leftarrow \mathrm{QR}(\bar{B}^{t+1})$.
    Server updates $B^{t+1} \leftarrow \hat{B}^{t+1}$.
end for

1. First, in Lemma A.5, we show that our estimated weight matrix $W^{t+1} \in \mathbb{R}^{M\times k}$ (our estimate at time $t+1$ of the true domain weight matrix $W^*$) satisfies the relationship
$$W^{t+1} = W^*(B^*)^\top B^t + F^t,$$
where $B^t$ is our estimate at time $t$ of the true (high- to low-dimensional) representation embedding $B^*$, and $F^t$ is an error term that we can show to be bounded (sufficiently small) in terms of $\mathrm{dist}(B^t, B^*)$, with the scale of the error depending on the (random) number of samples $L_m$ seen at time $t$ for each domain $m \in [M]$. Bounding $F^t$ in our setting requires some care, since the samples for each domain are spread over many clients.
We refer the reader to Section A.4.2 for the details.

2. Second, we show in Section A.4.3 that the update for $B^t$ satisfies the relationship (see Equation 21)
$$\bar{B}^{t+1} = B^t - \frac{\eta}{M}(Q^t)^\top W^{t+1} - \frac{\eta}{M} H_Q,$$
where $Q^t := W^{t+1}(B^t)^\top - W^*(B^*)^\top$, and $H_Q$ denotes an error term which can be shown to be bounded (sufficiently small) in terms of $\mathrm{dist}(B^t, B^*)$. Further simplifying, and ignoring the QR normalization in this sketch, we have
$$\begin{aligned}
\mathrm{dist}(B^{t+1}, B^*) &\approx \|(B^*_\perp)^\top \bar{B}^{t+1}\|_2 \\
&= \Big\|(B^*_\perp)^\top\Big(B^t - \frac{\eta}{M} B^t (W^{t+1})^\top W^{t+1} + \frac{\eta}{M} B^*(W^*)^\top W^{t+1} - \frac{\eta}{M} H_Q\Big)\Big\|_2 \\
&= \Big\|(B^*_\perp)^\top B^t\Big(I_k - \frac{\eta}{M}(W^{t+1})^\top W^{t+1}\Big) - \frac{\eta}{M}(B^*_\perp)^\top H_Q\Big\|_2 \\
&\le \mathrm{dist}(B^t, B^*)\Big(1 - \frac{\eta}{M}\sigma^2_{\min}(W^{t+1})\Big) + \frac{\eta}{M}\|H_Q\|_2 .
\end{aligned}$$
By upper bounding $H_Q$ in terms of $\mathrm{dist}(B^t, B^*)$ and providing an appropriate lower bound on $\sigma_{\min}\big((W^{t+1})^\top W^{t+1}\big)$, we can then show that, by picking $\eta$ sufficiently small and under other suitable assumptions, the quantity $\mathrm{dist}(B^t, B^*)$ decays at a linear rate (see Equation 29). Again, a difference from the analysis in Collins et al. (2021) is how we bound $H_Q$, since the samples for each domain are spread over different clients.

3. An issue created by the spreading of a domain's samples across different clients is how to pick an appropriate sample size such that the $F^t$ and $H_Q$ terms can be suitably bounded. In Lemma A.10, we prescribe a suitable sample size for each client such that each domain gets sufficient samples (with high probability) for the purposes of our analysis.

We now proceed with a detailed analysis. We first present a theorem stating that our adapted FedDAR (Algorithm 2) enjoys linear convergence. The theorem is followed by remarks that highlight key points of our convergence result.

Theorem A.4 (Algorithm 2 convergence). Define $E_0 := 1 - \mathrm{dist}^2(B^0, B^*)$, $\sigma_{\max,*} := \sigma_{\max}\big(\tfrac{1}{\sqrt{M}}W^*\big)$, and $\sigma_{\min,*} := \sigma_{\min}\big(\tfrac{1}{\sqrt{M}}W^*\big)$. Let $\kappa := \sigma_{\max,*}/\sigma_{\min,*}$. Suppose
$$L \ge \Omega\left(\max\left\{\frac{dk^2\kappa^4}{nE_0^2},\ \frac{k^2\kappa^4}{E_0^2\,\min_{m\in[M]}\sum_{i=1}^n \pi_{i,m}}\right\}\right). \tag{9}$$
Then, for any $T$ and any $\eta \le 1/(4\sigma^2_{\max,*})$, with probability at least $1 - Te^{-80}$,
$$\mathrm{dist}(B^T, B^*) \le \big(1 - \eta E_0\sigma^2_{\min,*}/2\big)^{T/2}\,\mathrm{dist}(B^0, B^*). \tag{10}$$

Linear convergence speed: The convergence of $B^T$ to $B^*$ is linear, assuming that (1) $\sigma_{\min}\big(\tfrac{1}{\sqrt{M}}W^*\big) > 0$ and (2) $1 - \eta E_0\sigma^2_{\min,*} \in (0, 1)$.

Initialization of $B^0$: For our convergence result to be meaningful, we need $\mathrm{dist}(B^0, B^*)$ to be close to 0. We show in Appendix A that our algorithm's choice of the initial $B^0$ ensures that $\mathrm{dist}(B^0, B^*)$ is close enough to 0 while preserving privacy. When the number of samples is uniform across domains, this comes only at the cost of a logarithmic increase in sample complexity.

Sample complexity: The per-iteration sample complexity per client is $L$. In the requirement (9) on $L$, we need $L \ge \Omega(dk^2\kappa^4/n)$; this comes from the updates for $B^t \in \mathbb{R}^{d\times k}$. While $d$ could be large, a large number of clients $n$ helps to mitigate the resulting increase in sample complexity. We also need $L \ge \Omega\big(k^2\kappa^4/\sum_{i=1}^n \pi_{i,m}\big)$ for every domain $m \in [M]$; this requirement comes from the updates for $w^t_m$ for each of the $M$ domains.

A.4 PROOF OF THEOREM A.4

A.4.1 ANALYSIS FOR UPDATE OF $W^{t+1}$

Since we are analyzing the update step for an arbitrary iteration $t$, we drop all $t$ superscripts unless necessary. Let $L_m = \sum_{i=1}^n L_{i,m}$ denote the number of samples from domain $m \in [M]$ across the $n$ clients. Then, we can express $\nabla_{w_m}\sum_{i=1}^n f_{i,m}(w_m, B)$ as
$$\nabla_{w_m}\sum_{i=1}^n f_{i,m}(w_m, B) = \sum_{i=1}^n\sum_{j=1}^{L_{i,m}} \big(w_m^\top B^\top x^j_{i,m} - y^j_{i,m}\big)B^\top x^j_{i,m}.$$
Since $y^j_{i,m} = (w^*_m)^\top (B^*)^\top x^j_{i,m}$, it follows from Algorithm 2 that
$$\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B\,w^{t+1}_m = \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m. \tag{11}$$
Re-expressing, and assuming that $G_m := \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B$ is invertible, we have
$$w^{t+1}_m = B^\top B^* w^*_m + \Big(G_m^{-1}\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m - B^\top B^* w^*_m\Big). \tag{12}$$
Intuitively, if $L_m$ is large enough, then
$$\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} x^j_{i,m}(x^j_{i,m})^\top \approx I_d, \qquad\text{so}\qquad G_m^{-1}\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m \approx B^\top B^* w^*_m .$$
This then implies that
$$W^{t+1} = W^*(B^*)^\top B + F, \tag{13}$$
where the $m$-th row of $F$ is
$$F_m = G_m^{-1}\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m - B^\top B^* w^*_m .$$
Note the similarity of equation 18 to (17) in Collins et al. (2021). Following an analysis similar to that of Collins et al. (2021), we should also be able to bound the Frobenius norm of $F$ in terms of $\mathrm{dist}(B, B^*)$.
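Because the heads are solved in closed form, the split of $W^{t+1}$ into $W^*(B^*)^\top B$ plus $F$ in equation (13) is an exact algebraic identity for every sample, and all the randomness sits in $F$. The sketch below checks this numerically for the special case $k = 1$ (so $B$ is a unit vector $b$ and each head is a scalar), assuming Gaussian inputs and noiseless targets as in the analysis. The head is assembled from per-client sufficient statistics $A_i$ and $a_i$, as in Algorithm 2; all function names are hypothetical.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def head_update_k1(client_samples, b):
    """Closed-form head w_m^{t+1} for one domain when k = 1 (B is the unit vector b).

    client_samples: one list of (x, y) pairs per client, all from domain m.
    Each client would send A_i = sum_j (b.x_j)^2 and a_i = sum_j y_j (b.x_j);
    the server then solves (sum_i A_i) w = sum_i a_i.
    """
    A_sum = sum(dot(b, x) ** 2 for samples in client_samples for x, _ in samples)
    a_sum = sum(y * dot(b, x) for samples in client_samples for x, y in samples)
    return a_sum / A_sum

def error_term_k1(client_samples, b, b_star, w_star):
    """F_m = G_m^{-1} B^T X_m (I - B B^T) B* w*_m from Lemma A.5, specialised to k = 1."""
    pooled = [x for samples in client_samples for x, _ in samples]
    L_m = len(pooled)
    G_m = sum(dot(b, x) ** 2 for x in pooled) / L_m
    # (I - b b^T) b*: the part of the true representation orthogonal to the current one.
    c = dot(b, b_star)
    b_star_perp = [bs - c * bi for bs, bi in zip(b_star, b)]
    cross = sum(dot(b, x) * dot(x, b_star_perp) for x in pooled) / L_m
    return cross * w_star / G_m

def make_domain_data(b_star, w_star, n_clients, per_client, seed=0):
    """Noiseless samples y = w* (b*)^T x with x ~ N(0, I_d), split across clients."""
    rng = random.Random(seed)
    d = len(b_star)
    data = []
    for _ in range(n_clients):
        xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(per_client)]
        data.append([(x, w_star * dot(b_star, x)) for x in xs])
    return data

phi = 0.4
b_star = [1.0, 0.0, 0.0]
b = [math.cos(phi), math.sin(phi), 0.0]  # current representation, angle phi away from b_star
data = make_domain_data(b_star, w_star=1.5, n_clients=3, per_client=20)
w_new = head_update_k1(data, b)
F_m = error_term_k1(data, b, b_star, w_star=1.5)
# Row m of equation (13): w_new equals (b^T b*) w* + F_m up to floating-point rounding.
```

With more samples per domain, $F_m$ concentrates around zero, because the cross term averages products of two independent zero-mean directions ($b^\top x$ and the component of $x$ along $(I - bb^\top)b^*$); this is the mechanism that Lemmas A.6 to A.8 make quantitative.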
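Below, we formalize the argument. First, we have the following lemma.

Lemma A.5 (Update for $W^{t+1}$). For each time $t$, let $L^t_m := \sum_{i=1}^n L^t_{i,m}$ denote the number of samples from domain $m \in [M]$ across the $n$ clients at time $t$. For convenience, we drop the time index unless absolutely necessary. We define the terms
$$X_m := \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} x^j_{i,m}(x^j_{i,m})^\top, \qquad G_m := \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B .$$
Then, assuming that $G_m$ is invertible, the update for $W$ takes the form
$$W^{t+1} = W^*(B^*)^\top B + F, \tag{14}$$
where the $m$-th row of $F$ is
$$F_m := G_m^{-1} B^\top X_m (I - BB^\top) B^* w^*_m . \tag{15}$$

Proof. We can express $\nabla_{w_m}\sum_{i=1}^n f_{i,m}(w_m, B)$ as
$$\nabla_{w_m}\sum_{i=1}^n f_{i,m}(w_m, B) = \sum_{i=1}^n\sum_{j=1}^{L_{i,m}} \big(w_m^\top B^\top x^j_{i,m} - y^j_{i,m}\big)B^\top x^j_{i,m}.$$
Since $y^j_{i,m} = (w^*_m)^\top (B^*)^\top x^j_{i,m}$, it follows from Algorithm 2 that
$$\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B\,w^{t+1}_m = \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m. \tag{16}$$
Re-expressing, and assuming $G_m$ is invertible, we have
$$w^{t+1}_m = B^\top B^* w^*_m + \Big(G_m^{-1}\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m - B^\top B^* w^*_m\Big). \tag{17}$$
This then implies that
$$W^{t+1} = W^*(B^*)^\top B + F, \tag{18}$$
where the $m$-th row of $F$ is
$$\begin{aligned}
F_m &= G_m^{-1}\frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B^* w^*_m - B^\top B^* w^*_m \\
&= G_m^{-1} B^\top X_m B^* w^*_m - G_m^{-1} G_m B^\top B^* w^*_m \\
&= G_m^{-1} B^\top X_m B^* w^*_m - G_m^{-1} B^\top X_m B B^\top B^* w^*_m \\
&= G_m^{-1} B^\top X_m (I - BB^\top) B^* w^*_m .
\end{aligned}$$

A.4.2 BOUNDING $\|F\|_F$

We now bound the Frobenius norm of $F$. We begin by showing that $G_m^{-1}$ exists and bounding its spectral norm.

Lemma A.6. Let $L_{\min} := \min_{m\in[M]} L_m$. Let $\delta_k := 10C\sqrt{k\log(M)/L_{\min}}$ for some absolute constant $C$. Suppose that $0 \le \delta_k < 1$. Then, with probability at least $1 - e^{-80k\log(M)}$, $G_m^{-1}$ exists for each $m \in [M]$, and
$$\|G_m^{-1}\|_2 \le \frac{1}{1 - \delta_k} \qquad \forall m \in [M].$$

Proof. Note that $G_m = \frac{1}{L_m}\sum_{i=1}^n\sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B$. Let $v^j_{i,m} := B^\top x^j_{i,m}$. Since $B^\top B = I$, each $v^j_{i,m}$ is i.i.d. 1-sub-Gaussian. Then, applying the same argument as in Theorem 4.6.1 of Vershynin (2018), we have (cf. equation (4.22) in Vershynin (2018))
$$\sigma_{\min}(G_m) \ge 1 - \underbrace{C\left(\sqrt{\frac{k}{L_m}} + \frac{z}{\sqrt{L_m}}\right)}_{\delta_{k,m}}$$
with probability at least $1 - e^{-z^2}$ for $z \ge 0$ and some absolute constant $C$, assuming that $0 \le \delta_{k,m} \le 1$. Consider the choice $z = 9\sqrt{k\log(M)}$. Then,
$$\delta_{k,m} = C\left(\sqrt{\frac{k}{L_m}} + 9\sqrt{\frac{k\log(M)}{L_m}}\right) \le 10C\sqrt{\frac{k\log M}{L_m}} \le 10C\sqrt{\frac{k\log M}{L_{\min}}} .$$
Suppose we choose $L_{\min} \ge 1$ such that $\delta_{k,m} < 1$.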
Then, taking a union bound, with probability at least $1-Me^{-z^2}=1-M\exp(-81k\log(M))\ge1-\exp(-80k\log(M))$,
$$\sigma_{\min}(G_m)\ge1-\delta_{k,m}\ge1-10C\sqrt{\frac{k\log M}{L_{\min}}}>0\quad\forall m\in[M].\quad(20)$$
Therefore, with probability at least $1-\exp(-80k\log(M))$, $G_m^{-1}$ exists for every $m\in[M]$, and in addition $\|G_m^{-1}\|_2\le\frac{1}{1-\delta_k}$ for all $m\in[M]$.

We next bound the operator norm of the term $B^\top X_m(I-BB^\top)B^*$.

Lemma A.7. Let $L_{\min}:=\min_{m\in[M]}L_m$. Let $\delta_k:=10C\sqrt{k\log M/L_{\min}}$ for some absolute constant $C$. Suppose $L_{\min}$ is such that $0\le\delta_k<1$. Then, with probability at least $1-e^{-50k\log M}$,
$$\|B^\top X_m(I-BB^\top)B^*\|_2\le\mathrm{dist}(B^*,B)\,\delta_k.$$

Proof. We use an $\epsilon$-net argument, similar to the proof of Theorem 4.6.1 in (Vershynin, 2018). First, by Corollary 4.2.13 in (Vershynin, 2018), there exists a $1/4$-net $\mathcal{N}$ of the unit sphere $S^{k-1}$ with cardinality $|\mathcal{N}|\le9^k$. Using Lemma 4.4.1 in (Vershynin, 2018), we have
$$\|B^\top X_m(I-BB^\top)B^*\|_2\le2\max_{z\in\mathcal{N}}\big|\langle B^\top X_m(I-BB^\top)B^*z,z\rangle\big|.$$
To prove our result, by applying a union bound over $m\in[M]$, it suffices to show that with probability at least $1-e^{-51k\log M}$,
$$\max_{z\in\mathcal{N}}\big|\langle B^\top X_m(I-BB^\top)B^*z,z\rangle\big|\le\tfrac12\delta_{k,m}\,\mathrm{dist}(B^*,B),$$
where we recall that $\delta_{k,m}=C\big(\sqrt{k/L_m}+9\sqrt{k\log(M)/L_m}\big)\le\delta_k$. We assume that $L_{\min}:=\min_mL_m\ge1$ is chosen large enough that $\delta_{k,m}\le1$. For a fixed $z\in S^{k-1}$, observe that
$$\langle B^\top X_m(I-BB^\top)B^*z,z\rangle=\frac{1}{L_m}\sum_{i,j}\big\langle B^\top x^j_{i,m}(x^j_{i,m})^\top(I-BB^\top)B^*z,z\big\rangle=\frac{1}{L_m}\sum_{i,j}(z^\top u^j_{i,m})\big((v^j_{i,m})^\top z\big),$$
where we defined $u^j_{i,m}:=B^\top x^j_{i,m}$ and $v^j_{i,m}:=(B^*)^\top(I-BB^\top)x^j_{i,m}$. Since each $x^j_{i,m}$ is 1-subgaussian, $\|B\|_2=1$, and $\|(I-BB^\top)B^*\|_2=\mathrm{dist}(B^*,B)$, it follows that $z^\top u^j_{i,m}$ is subgaussian with norm at most 1, and $(v^j_{i,m})^\top z$ is subgaussian with norm at most $\mathrm{dist}(B^*,B)$. Thus, the random variable $\alpha^j_{i,m}:=(z^\top u^j_{i,m})((v^j_{i,m})^\top z)$ (for a fixed unit $z$) is subexponential with subexponential norm at most $\mathrm{dist}(B^*,B)$. Moreover, $\alpha^j_{i,m}$ is mean-zero, since
$$\mathbb{E}\big[u^j_{i,m}(v^j_{i,m})^\top\big]=\mathbb{E}\big[B^\top x^j_{i,m}(x^j_{i,m})^\top(I-BB^\top)B^*\big]=B^\top(I-BB^\top)B^*=0,$$
as $x^j_{i,m}$ is assumed to have identity covariance. Thus, the $\alpha^j_{i,m}$'s are i.i.d. mean-zero subexponential variables, each with subexponential norm at most $\mathrm{dist}(B^*,B)$.
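Hence, by Bernstein's inequality (cf. Corollary 2.8.3 in (Vershynin, 2018)),
$$\mathbb{P}\Big(\big|\langle B^\top X_m(I-BB^\top)B^*z,z\rangle\big|\ge\tfrac12\delta_{k,m}\,\mathrm{dist}(B^*,B)\Big)\le2\exp\Big(-c\min\Big(\frac{\delta_{k,m}^2}{4},\frac{\delta_{k,m}}{2}\Big)L_m\Big)=2\exp\Big(-\frac c4\delta_{k,m}^2L_m\Big)\le2\exp\Big(-\frac c4C^2\big(k+81k\log(M)\big)\Big).$$
Above, we used the assumption that $\delta_{k,m}\le1$ to simplify the minimum in the exponent, and the fact that $\delta_{k,m}^2L_m=C^2\big(\sqrt k+9\sqrt{k\log M}\big)^2\ge C^2(k+81k\log M)$. Taking a union bound over all $z\in\mathcal{N}$, it follows that
$$\mathbb{P}\Big(\|B^\top X_m(I-BB^\top)B^*\|_2\ge\delta_{k,m}\mathrm{dist}(B^*,B)\Big)\le\mathbb{P}\Big(2\max_{z\in\mathcal{N}}\big|\langle B^\top X_m(I-BB^\top)B^*z,z\rangle\big|\ge\delta_{k,m}\mathrm{dist}(B^*,B)\Big)\le2\cdot9^k\exp\Big(-\frac c4C^2\big(k+81k\log(M)\big)\Big)\le\exp(-51k\log M),$$
where the last inequality follows by picking $C$ large enough (but still an absolute constant). (The choice of 51 in the exponent is somewhat arbitrary; any choice smaller than 81 should work.) Applying a union bound over the domains $m\in[M]$ and using $\delta_{k,m}\le\delta_k$ completes the proof.

We are now ready to bound $\|F\|_F$.

Lemma A.8. Let $L_{\min}:=\min_{m\in[M]}L_m$. Let $\delta_k:=10C\sqrt{k\log(M)/L_{\min}}$ for some absolute constant $C$. Suppose that $0\le\delta_k<1$. Then, with probability at least $1-2e^{-50k\log(M)}$,
$$\|F\|_F\le\frac{\delta_k}{1-\delta_k}\,\mathrm{dist}(B^*,B)\,\|W^*\|_F.$$

Proof. By Lemma A.6 and Lemma A.7, with probability at least $1-2e^{-50k\log M}$, for every $m\in[M]$,
$$\big\|G_m^{-1}\big(B^\top X_m(I-BB^\top)B^*\big)\big\|_2\le\|G_m^{-1}\|_2\,\|B^\top X_m(I-BB^\top)B^*\|_2\le\frac{1}{1-\delta_k}\big(\delta_k\,\mathrm{dist}(B^*,B)\big).$$
The proof then follows by recalling that the $m$-th row of $F$ takes the form $F_m=G_m^{-1}B^\top X_m(I-BB^\top)B^*w^*_m$, so that $\|F_m\|_2\le\frac{\delta_k}{1-\delta_k}\mathrm{dist}(B^*,B)\|w^*_m\|_2$, and summing the squares over $m$.

A.4.3 ANALYSIS OF UPDATE FOR $B^{t+1}$

Similarly to (Collins et al., 2021), we define $Q^t=W^{t+1}(B^t)^\top-W^*(B^*)^\top$. Below, we drop the time index and use $B$, $Q$, $W$ to denote $B^t$, $Q^t$, and $W^{t+1}$, respectively. Based on Algorithm 2, the representation before orthonormalization, $\hat B^{t+1}$, satisfies
$$\begin{aligned}\hat B^{t+1}&=B-\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\big(w_m^\top B^\top x^j_{i,m}-y^j_{i,m}\big)x^j_{i,m}w_m^\top\\&=B-\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\Big(\big\langle A^j_{i,m},WB^\top\big\rangle-\big\langle A^j_{i,m},W^*(B^*)^\top\big\rangle\Big)(A^j_{i,m})^\top W,\qquad A^j_{i,m}:=e_m(x^j_{i,m})^\top\\&=B-\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\big\langle A^j_{i,m},Q\big\rangle(A^j_{i,m})^\top W\\&=B-\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top q_mw_m^\top\\&=B-\frac{\eta}{M}Q^\top W-\underbrace{\frac{\eta}{M}\sum_{m=1}^M\Big(\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top q_mw_m^\top-q_mw_m^\top\Big)}_{H_Q}.\end{aligned}\quad(21)$$
Above, $q_m\in\mathbb{R}^d$ denotes the $m$-th row of $Q$ (viewed as a column vector), i.e., $q_m=Bw_m-B^*w^*_m$, and $e_m$ is the $m$-th standard basis vector of $\mathbb{R}^M$. Note again that since $\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top\approx I_d$, the term $H_Q$ in equation 21 can be appropriately bounded.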
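The alternating structure analyzed here (exact per-domain heads, then one gradient step on $B$ as in the first line of equation 21, then QR) can be simulated in a few lines. This is a minimal sketch under stated assumptions, not the authors' code: the per-domain samples are pooled across clients (the server-side sum), fresh noiseless samples are drawn at every iteration, and $W^*$ is built with orthonormal columns scaled by $\sqrt M$ so that $\kappa=1$; all function names are hypothetical.

```python
import numpy as np

def orthonormal_basis(A):
    """Orthonormal basis of the column span of A via reduced QR."""
    Q, _ = np.linalg.qr(A)
    return Q

def subspace_dist(B, B_star):
    """dist(B*, B) = ||(I - B B^T) B*||_2 for orthonormal B and B*."""
    return np.linalg.norm(B_star - B @ (B.T @ B_star), 2)

def alternating_step(B, domain_data, eta):
    """Exact per-domain heads, one gradient step on B as in (21), then QR.

    domain_data: list of (X_m, y_m), the samples of domain m pooled over clients.
    """
    M = len(domain_data)
    heads = [np.linalg.lstsq(X @ B, y, rcond=None)[0] for X, y in domain_data]
    grad = np.zeros_like(B)
    for (X, y), w in zip(domain_data, heads):
        resid = X @ B @ w - y                      # w_m^T B^T x_j - y_j for every sample j
        grad += np.outer(X.T @ resid, w) / len(y)  # (1/L_m) sum_j resid_j x_j w_m^T
    return orthonormal_basis(B - (eta / M) * grad)  # B_hat^{t+1}, followed by QR

rng = np.random.default_rng(1)
d, k, M, L_m, T = 20, 3, 5, 200, 100
B_star = orthonormal_basis(rng.standard_normal((d, k)))
# Rows of W* are the domain heads; scaling by sqrt(M) makes sigma(W*/sqrt(M)) = 1 (kappa = 1).
W_star = np.sqrt(M) * orthonormal_basis(rng.standard_normal((M, k)))
eta = 1.0 / (4 * np.linalg.norm(W_star / np.sqrt(M), 2) ** 2)  # eta <= 1/(4 sigma_max^2)

B = orthonormal_basis(B_star + 0.1 * rng.standard_normal((d, k)))
dists = [subspace_dist(B, B_star)]
for _ in range(T):
    data = []
    for m in range(M):
        X = rng.standard_normal((L_m, d))          # fresh samples at every iteration
        data.append((X, X @ B_star @ W_star[m]))   # noiseless labels
    B = alternating_step(B, data, eta)
    dists.append(subspace_dist(B, B_star))
print(dists[0], dists[-1])
```

With noiseless labels, $B^*$ is a fixed point of this update, so the recorded distances should shrink toward zero; the step size follows the $\eta\le1/(4\bar\sigma^2_{\max,*})$ condition used in Lemma A.11.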
Note the resemblance of equation 21 to (53) in (Collins et al., 2021); the crucial difference is that we need to lower bound $\frac1M\sigma^2_{\min}(W^*)$ instead of $\frac1n\sigma^2_{\min}(W^*)$ as in (Collins et al., 2021). Thus, we can carry out the rest of the analysis along the lines of (Collins et al., 2021) and derive a result analogous to Theorem 1 in (Collins et al., 2021). We first bound the error term $H_Q$.

Lemma A.9. Let
$$H^t_Q:=\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top q_m(w^{t+1}_m)^\top-\frac{\eta}{M}(Q^t)^\top W^{t+1}.$$
Let $\gamma_k:=20k\sqrt{d/(cnL)}$ for some absolute constant $c$. Suppose that $0\le\gamma_k<k$. Then, for any $t$, with probability at least $1-\exp(-90d)-2e^{-50k\log M}$,
$$\|H^t_Q\|_2\le\eta\gamma_k\,\mathrm{dist}(B^*,B^t).$$

Proof. As before, we omit the time superscript $t$ where it is clear, for notational convenience. The proof is based on the argument of Lemma 5 in (Collins et al., 2021); again, the main tool is an $\epsilon$-net argument. We first bound $\|q_m\|_2$ and $\|w_m\|_2$.

Bounding $q_m$: With probability at least $1-2e^{-50k\log M}$, for each $m\in[M]$,
$$\begin{aligned}\|q_m\|_2&=\big\|B^t\big((B^t)^\top B^*w^*_m+F_m\big)-B^*w^*_m\big\|_2\le\big\|(B^t(B^t)^\top-I)B^*w^*_m\big\|_2+\|B^tF_m\|_2\\&\le\mathrm{dist}(B^t,B^*)\|w^*_m\|_2+\|F_m\|_2\le\sqrt k\,\mathrm{dist}(B^t,B^*)+\frac{\delta_k}{1-\delta_k}\mathrm{dist}(B^t,B^*)\|w^*_m\|_2\le2\sqrt k\,\mathrm{dist}(B^t,B^*).\end{aligned}$$
Above, we used the assumption that $\|w^*_m\|_2\le\sqrt k$, the orthonormality of $B^t$ (obtained as the orthogonal factor of a Gram-Schmidt procedure), the assumption that $0<\delta_k\le1/2$, and Lemma A.8, which bounds $\|F_m\|_2$ (for all $m$) with probability at least $1-2e^{-50k\log M}$.

Bounding $w_m$: For notational convenience, let $w_m$ denote $w^{t+1}_m$. For each $t$ and every $m\in[M]$,
$$\|w^{t+1}_m\|_2=\big\|(B^t)^\top B^*w^*_m+F_m\big\|_2\le\|w^*_m\|_2+\|F_m\|_2\le\|w^*_m\|_2+\frac{\delta_k}{1-\delta_k}\mathrm{dist}(B^t,B^*)\|w^*_m\|_2\le3\sqrt k$$
with probability at least $1-2e^{-50k\log M}$, where again we used Lemma A.8 to handle $\|F_m\|_2$, the assumption that $\delta_k<1/2$, and the fact that $\mathrm{dist}(B^t,B^*)\le2$. For the rest of the proof, we condition on the event
$$\mathcal E:=\Big\{\|q_m\|_2\le2\sqrt k\,\mathrm{dist}(B^t,B^*)\ \text{and}\ \|w_m\|_2\le3\sqrt k\quad\forall m\in[M]\Big\},$$
which holds with probability at least $1-2e^{-50k\log M}$.
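$\epsilon$-net argument to bound $H_Q$: Again, there exist a $1/4$-net $\mathcal N_k$ of the unit sphere $S^{k-1}$ and a $1/4$-net $\mathcal N_d$ of the unit sphere $S^{d-1}$ with cardinalities at most $9^k$ and $9^d$, respectively. By Equation 4.13 in (Vershynin, 2018), we have
$$\begin{aligned}\|H_Q\|_2&=\Big\|\frac{\eta}{M}\sum_{m=1}^M\Big(\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top q_mw_m^\top-q_mw_m^\top\Big)\Big\|_2\\&\le2\eta\max_{u\in\mathcal N_d,\,v\in\mathcal N_k}\frac1M\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\Big\langle\big(x^j_{i,m}(x^j_{i,m})^\top q_mw_m^\top-q_mw_m^\top\big)v,u\Big\rangle\\&=2\eta\max_{u\in\mathcal N_d,\,v\in\mathcal N_k}\frac1M\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\Big[(u^\top x^j_{i,m})\big((x^j_{i,m})^\top q_mw_m^\top v\big)-\langle q_mw_m^\top v,u\rangle\Big].\end{aligned}\quad(22)$$
Fix $u\in\mathcal N_d$ and $v\in\mathcal N_k$. Then $(u^\top x^j_{i,m})((x^j_{i,m})^\top q_mw_m^\top v)$ is subexponential with norm at most $\|q_m\|_2\|w_m\|_2\le6k\,\mathrm{dist}(B^t,B^*)$, since it is the product of the two subgaussian variables $u^\top x^j_{i,m}$ and $(x^j_{i,m})^\top q_mw_m^\top v$, whose subgaussian norms are bounded by 1 and $\|q_m\|_2\|w_m\|_2$, respectively. Note also that $\mathbb E\big[(u^\top x^j_{i,m})((x^j_{i,m})^\top q_mw_m^\top v)\big]=\langle q_mw_m^\top v,u\rangle$. Thus, by Bernstein's inequality, continuing from equation 22,
$$\mathbb P\Big(\frac1M\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\Big[(u^\top x^j_{i,m})\big((x^j_{i,m})^\top q_mw_m^\top v\big)-\langle q_mw_m^\top v,u\rangle\Big]\ge\rho\Big)\le\exp\Big(-cnL\min\Big(\frac{\rho^2}{(6k\,\mathrm{dist}(B^t,B^*))^2},\frac{\rho}{6k\,\mathrm{dist}(B^t,B^*)}\Big)\Big)\le\exp\Big(-cnL\Big(\frac{\rho}{k\,\mathrm{dist}(B^t,B^*)}\Big)^2\Big),$$
where we will choose $\rho$ such that $\rho/(k\,\mathrm{dist}(B^t,B^*))\le1$ to simplify the exponent, and $c$ is an absolute constant that may change from line to line. Above, we also used the fact that $\sum_{m=1}^ML_m=nL$ (recall that $L$ is the total number of samples per agent and there are $n$ agents). Consider the choice $\rho=10k\sqrt{d/(cnL)}\,\mathrm{dist}(B^t,B^*)$. Then
$$\mathbb P\Big(\frac1M\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}\Big[(u^\top x^j_{i,m})\big((x^j_{i,m})^\top q_mw_m^\top v\big)-\langle q_mw_m^\top v,u\rangle\Big]\ge\rho\Big)\le\exp(-100d).$$
Taking a union bound over all $u\in\mathcal N_d$ and $v\in\mathcal N_k$, it follows that
$$\mathbb P\big(\|H_Q\|_2\ge2\eta\rho\big)\le9^{d+k}\exp(-100d)\le\exp(-90d),$$
where we used the fact that $d\ge k$. Since $2\eta\rho=\eta\gamma_k\,\mathrm{dist}(B^t,B^*)$, this completes the proof.

A.4.4 COMBINING EARLIER ARGUMENTS: CONVERGENCE OF FEDDAR

As seen in Lemma A.8, we require $L_{\min}:=\min_{m\in[M]}L_m$ to be lower bounded. However, since $L_m$ is a random variable, we cannot lower bound it directly. Below, we provide a result that converts a lower bound on each client's sample size $L$ (a deterministic quantity we can control) into a high-probability lower bound on $L_{\min}$.

Lemma A.10. Let $L_{\min}:=\min_{m\in[M]}L_m$.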
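For any $\alpha>0$, suppose that for each $m\in[M]$,
$$L\ge\max\Big\{\frac{182\log M}{\sum_{i=1}^n\pi_{i,m}},\ \frac{16}{\sum_{i=1}^n\pi_{i,m}},\ \frac{2\alpha}{\sum_{i=1}^n\pi_{i,m}}\Big\}.$$
Then, with probability at least $1-\exp(-90)$, $L_{\min}\ge\alpha$.

Proof. Note that
$$L_m=\sum_{i=1}^n\sum_{j=1}^L\mathbb 1\big(\mathrm{domain}(x^j_i)=m\big),$$
which is a sum of $nL$ independent random variables bounded between 0 and 1. Moreover,
$$\mathbb E[L_m]=\sum_{i=1}^n\pi_{i,m}L,$$
where $\pi_{i,m}$ is the probability that a data point of client $i$ comes from domain $m$. Finally, $\mathbb E\big[\mathbb 1(\mathrm{domain}(x^j_i)=m)^2\big]=\pi_{i,m}$. Hence, by Bernstein's inequality, for any $s>0$,
$$\mathbb P\Big(L_m\le\sum_{i=1}^n\pi_{i,m}L-s\Big)\le\exp\Big(-\frac{s^2/2}{\sum_{i=1}^n\sum_{j=1}^L\pi_{i,m}+s/3}\Big).$$
Since we wish to take a union bound over the $M$ domains, we seek to choose $s$ and $L$ such that
$$\exp\Big(-\frac{s^2/2}{\sum_{i=1}^n\sum_{j=1}^L\pi_{i,m}+s/3}\Big)\le\exp(-91\log M),$$
so that
$$M\exp\Big(-\frac{s^2/2}{\sum_{i=1}^n\sum_{j=1}^L\pi_{i,m}+s/3}\Big)\le M\exp(-91\log M)=\exp(-90\log M).$$
To this end, we need
$$\frac{s^2/2}{\sum_{i=1}^n\sum_{j=1}^L\pi_{i,m}+s/3}\ge91\log M\iff\frac{s^2}{2}\ge91\log M\Big(\sum_{i=1}^n\pi_{i,m}L+\frac s3\Big)\Longleftarrow s\ge\sqrt{182\log M\sum_{i=1}^n\pi_{i,m}L+\Big(\frac{182\log M}{6}\Big)^2}+\frac{182\log M}{6}.$$
Suppose we pick $L$ such that
$$\sum_{i=1}^n\pi_{i,m}L\ge182\log M,$$
so that the lower bound on $s$ above is controlled by $\sum_{i=1}^n\pi_{i,m}L$. Then, by picking $s=2\sqrt{\sum_{i=1}^n\pi_{i,m}L}$, it follows that
$$\exp\Big(-\frac{s^2/2}{\sum_{i=1}^n\sum_{j=1}^L\pi_{i,m}+s/3}\Big)\le\exp(-91\log M),$$
such that for each $m\in[M]$,
$$\mathbb P\Big(L_m\le\sum_{i=1}^n\pi_{i,m}L-2\sqrt{\sum_{i=1}^n\pi_{i,m}L}\Big)\le\exp(-91\log M).$$
By choosing $L$ such that $\sqrt{\sum_{i=1}^n\pi_{i,m}L}\ge4$, we have $\sum_i\pi_{i,m}L-2\sqrt{\sum_i\pi_{i,m}L}\ge\frac12\sum_i\pi_{i,m}L$, and hence
$$\mathbb P\Big(L_m\le\frac12\sum_{i=1}^n\pi_{i,m}L\Big)\le\exp(-91\log M).$$
The result now follows by taking a union bound over $m\in[M]$ and choosing $L$ such that it also satisfies $\frac12\sum_{i=1}^n\pi_{i,m}L\ge\alpha$ for each $m$.

Lemma A.11 (Descent lemma). Define $E_0:=1-\mathrm{dist}^2(B^0,B^*)$, $\bar\sigma_{\max,*}:=\sigma_{\max}\big(\tfrac{1}{\sqrt M}W^*\big)$, and $\bar\sigma_{\min,*}:=\sigma_{\min}\big(\tfrac{1}{\sqrt M}W^*\big)$. Let $\kappa:=\bar\sigma_{\max,*}/\bar\sigma_{\min,*}$. Consider any iteration $t$. Suppose that
$$L\ge\frac{400k^2d}{cn}\Big(\min\Big\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\Big\}\Big)^{-2},\quad(23)$$
where $c>0$ is an absolute constant.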
Suppose also that, for each $m\in[M]$,
$$L\ge\max\Bigg\{\frac{182\log M}{\sum_{i=1}^n\pi_{i,m}},\ \frac{16}{\sum_{i=1}^n\pi_{i,m}},\ \frac{2\,(100C^2k\log M)\big(\min\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\}\big)^{-2}}{\sum_{i=1}^n\pi_{i,m}}\Bigg\},$$
which, by Lemma A.10, ensures that with probability at least $1-e^{-90}$,
$$L^t_{\min}\ge(100C^2k\log M)\Big(\min\Big\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\Big\}\Big)^{-2},\quad(24)$$
where $L^t_{\min}=\min_{m\in[M]}L^t_m$ denotes the minimum number of samples from any domain at iteration $t$, and $C>0$ is an absolute constant. Then, for any $\eta\le1/(4\bar\sigma^2_{\max,*})$, we have
$$\mathrm{dist}(B^{t+1},B^*)\le\big(1-\eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2}\mathrm{dist}(B^t,B^*),$$
with probability at least $1-e^{-80}$.

Proof. We begin with the observation that
$$W^{t+1}=W^*(B^*)^\top B^t+F^t,\qquad\hat B^{t+1}=B^t-\frac{\eta}{M}(Q^t)^\top W^{t+1}-H^t_Q,$$
where $Q^t=W^{t+1}(B^t)^\top-W^*(B^*)^\top$ and
$$H^t_Q:=\frac{\eta}{M}\sum_{m=1}^M\frac{1}{L_m}\sum_{i,j}x^j_{i,m}(x^j_{i,m})^\top q_m(w^{t+1}_m)^\top-\frac{\eta}{M}(Q^t)^\top W^{t+1}.$$
Above, $\hat B^{t+1}$ denotes the estimate of $B^*$ before we perform the QR decomposition. The updates for $W$ and $B$ are exactly analogous to the updates for $W$ and $B$ in the proof of Lemma 6 in (Collins et al., 2021). The only two differences are:

1. The definitions of $F$ in our paper and in (Collins et al., 2021) are slightly different. However, in both cases, $\|F\|_F\le\frac{\delta_k}{1-\delta_k}\mathrm{dist}(B^*,B)\|W^*\|_F$ for some $\delta_k\le1/2$ with high probability. In our case, this event holds with probability at least $1-2e^{-50k\log M}$, whilst in (Collins et al., 2021) it holds with probability at least $1-\exp(-110k^2\log n)$.
2. The update for $\hat B^{t+1}$ in (Collins et al., 2021) takes the form
$$\hat B^{t+1}=B^t-\frac{\eta}{rn}(Q^t)^\top W^{t+1}-\frac{\eta}{rn}\Big(\frac1m\mathcal A^\dagger\mathcal A(Q^t)-Q^t\Big)^\top W^{t+1},$$
where $0<r\le1$ is a ratio term used in (Collins et al., 2021), and $m$ here is the number of samples used by each learner in (Collins et al., 2021) (which differs from our use of $m$ as an index over the domains). However, with high probability,
$$\|H^t_Q\|_2\le\eta\gamma_k\,\mathrm{dist}(B^t,B^*),\qquad\Big\|\frac{\eta}{rn}\Big(\frac1m\mathcal A^\dagger\mathcal A(Q^t)-Q^t\Big)^\top W^{t+1}\Big\|_2\le\eta\gamma_k\,\mathrm{dist}(B^t,B^*),$$
where the definitions of $\gamma_k$ in the two papers differ, but both satisfy $\gamma_k\le k$.
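Due to these similarities between our updates for $W^{t+1}$ and $\hat B^{t+1}$ and those in (Collins et al., 2021), the proof of this lemma follows from the proof of Lemma 6 in (Collins et al., 2021), by plugging $\frac{\eta}{M}(Q^t)^\top W^{t+1}$ into the update for $\hat B^{t+1}$ in place of $\frac{\eta}{rn}(Q^t)^\top W^{t+1}$. In particular, following the same analysis as in (Collins et al., 2021), on the events of Lemma A.8 and Lemma A.9, and following the equation immediately after Equation (84) in (Collins et al., 2021), we have
$$\mathrm{dist}(B^{t+1},B^*)\le\frac{1}{\sqrt{1-4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}}}\Big(1-\eta\bar\sigma^2_{\min,*}E_0+2\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}\Big)\mathrm{dist}(B^t,B^*),$$
where in our case $\bar\delta_k=\delta_k+\gamma_k$. Then, by choosing
$$\bar\delta_k<\frac{16E_0}{25\cdot5\,\kappa^2},\quad(25)$$
it follows that $\bar\delta_k<1/5$, and so
$$\frac{1-\eta\bar\sigma^2_{\min,*}E_0+2\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}}{1-4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}}\le1-\eta E_0\bar\sigma^2_{\min,*}/2,$$
as in equation (85) in (Collins et al., 2021), such that
$$\mathrm{dist}(B^{t+1},B^*)\le\big(1-\eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2}\mathrm{dist}(B^t,B^*).$$
It remains to understand what the constraint on $\bar\delta_k$ in equation 25, together with the constraints on $\delta_k$ and $\gamma_k$ (in Lemmas A.8 and A.9, respectively), implies for our choice of the per-agent sample size $L$ and the domain sample sizes $L_m$ at each iteration. Observe that we need
$$\delta_k=10C\sqrt{\frac{k\log M}{L_{\min}}}<1,\quad(26)$$
$$\gamma_k=20k\sqrt{\frac{d}{cnL}}<k,\quad(27)$$
$$\bar\delta_k=\delta_k+\gamma_k=10C\sqrt{\frac{k\log M}{L_{\min}}}+20k\sqrt{\frac{d}{cnL}}\le\frac{16E_0}{25\cdot5\,\kappa^2},\quad(28)$$
where $c,C>0$ are absolute constants. By choosing
$$L_{\min}\ge(100C^2k\log M)\Big(\min\Big\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\Big\}\Big)^{-2}\quad\text{and}\quad L\ge\frac{400k^2d}{cn}\Big(\min\Big\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\Big\}\Big)^{-2},$$
we ensure that the requirements in equation 26, equation 27 and equation 28 are all satisfied. The final result then follows by applying Lemma A.10.

This yields the following convergence result, which is a more complete statement of Theorem A.4.

Theorem A.12 (Convergence result for Algorithm 2). Define $E_0:=1-\mathrm{dist}^2(B^0,B^*)$, $\bar\sigma_{\max,*}:=\sigma_{\max}\big(\tfrac{1}{\sqrt M}W^*\big)$ and $\bar\sigma_{\min,*}:=\sigma_{\min}\big(\tfrac{1}{\sqrt M}W^*\big)$. Let $\kappa:=\bar\sigma_{\max,*}/\bar\sigma_{\min,*}$. Suppose that
$$L\ge\frac{400k^2d}{cn}\Big(\min\Big\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\Big\}\Big)^{-2},$$
where $c>0$ is an absolute constant.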
Suppose also that, for each $m\in[M]$,
$$L\ge\max\Bigg\{\frac{182\log M}{\sum_{i=1}^n\pi_{i,m}},\ \frac{16}{\sum_{i=1}^n\pi_{i,m}},\ \frac{2\,(100C^2k\log M)\big(\min\{\frac12,\frac{8E_0}{25\cdot5\,\kappa^2}\}\big)^{-2}}{\sum_{i=1}^n\pi_{i,m}}\Bigg\}.$$
Then, for any $\eta\le1/(4\bar\sigma^2_{\max,*})$, we have
$$\mathrm{dist}(B^{t+1},B^*)\le\big(1-\eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2}\mathrm{dist}(B^t,B^*)$$
with probability at least $1-e^{-80}$. Consequently, for any $T$ and any $\eta\le1/(4\bar\sigma^2_{\max,*})$,
$$\mathrm{dist}(B^T,B^*)\le\big(1-\eta E_0\bar\sigma^2_{\min,*}/2\big)^{T/2}\mathrm{dist}(B^0,B^*),\quad(29)$$
with probability at least $1-Te^{-80}$.

Assuming that $\bar\sigma^2_{\min,*}>0$, the bound in Theorem 1 decays exponentially, so the total number of samples required per client to reach $\mathrm{dist}(B^T,B^*)\le\epsilon$ scales as $L\log(1/\epsilon)$. In addition, for the result to be meaningful, we implicitly assume that $E_0$ is close to 1, such that $0<1-\eta E_0\bar\sigma^2_{\min,*}<1$. To this end, it is possible to choose $B^0$ such that $\mathrm{dist}(B^0,B^*)$ is close enough to 0, with only a logarithmic increase in sample complexity when the number of samples is uniform across the domains. The argument follows the proof of Theorem 3 in (Tripuraneni et al., 2021).

Theorem A.13. Suppose Assumptions A.1, A.2 and A.3 all hold. Suppose also that $x^{0,j}_i\sim\mathcal N(0,I_d)$ independently for all $i\in[n]$. Suppose each client $i$ sends the server $Z_i:=\sum_{j=1}^{L_0}(y^{0,j}_i)^2x^{0,j}_i(x^{0,j}_i)^\top$, as well as the integer value of $L_i$, so that the server can compute $\hat Z:=\frac{1}{nL_0}\sum_{i=1}^nZ_i$. The server then computes $UDU^\top\leftarrow\text{rank-}k\ \mathrm{SVD}(\hat Z)$ and sets $B^0:=U$. Let
$$\bar\Lambda:=\frac{1}{nL_0}\sum_{i=1}^n\sum_{j=1}^{L_0}w^*_{m(i,j)}(w^*_{m(i,j)})^\top,$$
where $m(i,j)$ denotes the domain of the $j$-th sample of the $i$-th client. Let $\bar\sigma_{\min,*}:=\sigma_{\min}(\bar\Lambda)$ and $\bar\sigma_{\max,*}:=\sigma_{\max}(\bar\Lambda)$. Suppose that $L_0\ge c\,\mathrm{polylog}(d,nL_0)\,\bar\sigma_{\max,*}dk^2/(n\bar\sigma^2_{\min,*})$. Then, with probability at least $1-(nL_0)^{-100}$,
$$\mathrm{dist}(B^0,B^*)^2\le O\Big(\frac{\bar\sigma_{\max,*}k^2d}{\bar\sigma^2_{\min,*}nL_0}\Big).$$
In particular, when the number of samples is uniform across the domains,
$$\mathrm{dist}(B^0,B^*)^2\le O\Big(\frac{\kappa^4k^2d}{nL_0}\Big),$$
where we recall that $\kappa:=\bar\sigma_{\max,*}/\bar\sigma_{\min,*}$, with $\bar\sigma_{\max,*}:=\sigma_{\max}\big(\tfrac{1}{\sqrt M}W^*\big)$ and $\bar\sigma_{\min,*}:=\sigma_{\min}\big(\tfrac{1}{\sqrt M}W^*\big)$.

Proof. We omit the proof since it is a slight variant of Theorem 3 in (Tripuraneni et al., 2021).
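The spectral initialization in Theorem A.13 can be sketched as follows. Each client forms $Z_i$ locally and reports it together with its sample count; the server averages and keeps the top-$k$ eigenvectors. Since $\hat Z$ is symmetric positive semidefinite, its top-$k$ eigenvectors coincide with the rank-$k$ SVD factor $U$, so this sketch uses an eigendecomposition instead. Gaussian inputs and noiseless labels follow the theorem; the Dirichlet domain mixtures, the $\sqrt M$-scaled orthonormal $W^*$, and the function names are illustrative assumptions, not the authors' code.

```python
import numpy as np

def client_moment(X, y):
    """Z_i = sum_j y_j^2 x_j x_j^T, computed locally from one client's L_0 samples."""
    return (X * (y ** 2)[:, None]).T @ X

def spectral_init(client_data, k):
    """Server side: average the client moments and keep the top-k eigenvectors.

    Z_hat is symmetric PSD, so its top-k eigenvectors span the same subspace
    as the rank-k SVD factor U used in Theorem A.13.
    """
    n_total = sum(len(y) for _, y in client_data)   # from the reported sample counts
    Z_hat = sum(client_moment(X, y) for X, y in client_data) / n_total
    _, eigvecs = np.linalg.eigh(Z_hat)               # eigenvalues in ascending order
    return eigvecs[:, -k:]                           # B_0: d x k, orthonormal columns

rng = np.random.default_rng(2)
d, k, M, n, L0 = 20, 3, 5, 10, 5000
B_star, _ = np.linalg.qr(rng.standard_normal((d, k)))
W_star = np.sqrt(M) * np.linalg.qr(rng.standard_normal((M, k)))[0]  # rows are w*_m

clients = []
for _ in range(n):
    pi = rng.dirichlet(np.ones(M))                  # client-specific domain mixture
    domains = rng.choice(M, size=L0, p=pi)          # m(i, j) for each sample
    X = rng.standard_normal((L0, d))                # x_i^{0,j} ~ N(0, I_d)
    y = np.einsum("jd,dk,jk->j", X, B_star, W_star[domains])  # noiseless labels
    clients.append((X, y))

B0 = spectral_init(clients, k)
dist0 = np.linalg.norm((np.eye(d) - B0 @ B0.T) @ B_star, 2)
print(dist0)
```

The construction works because, for $x\sim\mathcal N(0,I_d)$ and $a=B^*w^*_m$, $\mathbb E[(a^\top x)^2xx^\top]=\|a\|_2^2I_d+2aa^\top$, so the top-$k$ eigenspace of $\mathbb E[\hat Z]$ is the column span of $B^*$ whenever $\bar\Lambda$ is positive definite.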
For completeness, note that when the number of samples is uniform across the domains, some algebra shows that
$$\mathrm{dist}(B^0,B^*)^2\le O\Big(\frac{\kappa^2k^2d}{\bar\sigma^2_{\min,*}nL_0}\Big).$$
However, since $\frac{kM}{4}\le\|W^*\|_F^2\le kM\bar\sigma^2_{\max,*}$, we have
$$\frac{1}{\bar\sigma^2_{\min,*}}=\kappa^2\frac{1}{\bar\sigma^2_{\max,*}}\le4\kappa^2,$$
which proves the last statement in the theorem.

B ADDITIONAL EXPERIMENTAL RESULTS

B.1 EXPERIMENTS ON FAIRFACE DATASET FOR GENDER CLASSIFICATION

Table 4: Min, max and average test accuracy (%) of gender classification across 7 domains (race groups) on FairFace, with number of clients n = 5 and number of samples at each client L_i = 500.

| Method | α = 0.1 Max | α = 0.1 Min | α = 0.1 Avg | α = 0.5 Max | α = 0.5 Min | α = 0.5 Avg | α = 1 Max | α = 1 Min | α = 1 Avg | α = 100 Max | α = 100 Min | α = 100 Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 92.0 | 71.7 | 83.9 | 89.8 | 77.6 | 84.5 | 91.0 | 77.4 | 84.2 | 90.5 | 77.1 | 84.7 |
| FedAvg + Multi-head | 90.2 | 48.7 | 78.9 | 89.2 | 77.8 | 84.1 | 91.6 | 76.8 | 83.9 | 91.1 | 77.5 | 84.5 |
| FedDAR-WA | 89.8 | 53.4 | 80.9 | 91.5 | 76.7 | 84.3 | 91.2 | 76.1 | 84.3 | 90.0 | 76.8 | 84.1 |
| FedDAR-SA | 92.2 | 73.4 | 85.1 | 91.3 | 78.1 | 85.2 | 91.4 | 78.2 | 85.1 | 92.2 | 78.1 | 85.6 |

We also conduct experiments for gender classification on FairFace with the same settings. The best representation dimension for this task is k = 2, probably due to the smaller diversity across the domains. The results in Table 4 show a similar trend to those in Table 5.2.

B.2 EXPERIMENTS ON DIGITS DATASET

Table 5: Min, max and average test accuracy (%) of digits classification across 5 domains, with number of clients n = 5 and number of samples at each client L_i = 500.

| Method | α = 0.1 Max | α = 0.1 Min | α = 0.1 Avg | α = 0.5 Max | α = 0.5 Min | α = 0.5 Avg | α = 1 Max | α = 1 Min | α = 1 Avg | α = 100 Max | α = 100 Min | α = 100 Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 97.1 | 60.7 | 80.6 | 97.2 | 64.3 | 81.7 | 96.1 | 74.8 | 85.2 | 96.8 | 71.0 | 85.1 |
| FedAvg + Multi-head | 94.3 | 26.5 | 55.9 | 94.3 | 44.8 | 68.3 | 94.1 | 56.7 | 74.6 | 95.0 | 52.3 | 74.5 |
| FedDAR-WA | 97.3 | 52.3 | 79.8 | 97.3 | 64.7 | 83.1 | 96.6 | 74.5 | 86.3 | 97.1 | 70.6 | 86.3 |

We perform additional experiments on the digits dataset with five data domains exhibiting feature shift (Li et al., 2021b). Details are described in the following paragraphs.
From Table 5, FedDAR-WA consistently outperforms FedAvg in average accuracy, except when the domain distributions are extremely heterogeneous (α = 0.1). In that case, each client tends to have data from only one domain, which makes it difficult for the proposed method to learn a good domain-specific head for the domain whose data differ most (i.e., with the most pronounced feature shift). For the other levels of heterogeneity, although the min and max domain accuracies are similar between FedAvg and FedDAR-WA, the average accuracies improve thanks to the domain-wise personalized model. On the other hand, without alternating updates of the head and the representation, FedAvg + Multi-head overfits quickly. We do not include results for FedDAR-SA here, because a representation dimension k ≥ 64 causes numerical instability during head aggregation and failure to converge, while k ≤ 32 leads to lower accuracy.

Datasets. We use the same digits dataset with five data domains as (Li et al., 2021b). Specifically, we use SVHN (Netzer et al., 2011), USPS (Hull, 1994), SynthDigits (Ganin & Lempitsky, 2015), MNIST-M (Ganin & Lempitsky, 2015) and MNIST (LeCun et al., 1998) as the five data domains. Similarly to the experiments on the FairFace dataset, the training data is divided among n clients without duplication. Each client has a domain distribution π_i ∼ Dir(αp) sampled from a Dirichlet distribution.

Implementation Details. We adapt the codebase from (Li et al., 2021b). We use a 6-layer CNN with 3 convolutional layers and 3 fully connected layers, with the last layer serving as the domain-specific head. We use the SGD optimizer with a learning rate of $10^{-2}$ and the cross-entropy loss. The batch size is set to 32, and the total number of communication rounds is set to 100. For each method, we first train the model for 10 rounds with 1 local epoch using FedAvg as warmup. The reported accuracy is the average over the last ten communication rounds.
We repeat the experiment for each setting three times with different random seeds and report the averages.

B.3 FURTHER EXPERIMENTAL DETAILS

B.3.1 SYNTHETIC DATA

For the synthetic data experiments, we adapt the code from (Collins et al., 2021) and follow a similar protocol. The ground-truth matrices $W^*\in\mathbb R^{M\times k}$ and $B^*\in\mathbb R^{d\times k}$ are generated in the same way as in (Collins et al., 2021), by sampling each element i.i.d. from the standard normal distribution and taking the QR factorization. The same L samples are used for each client throughout training. Test samples are generated in the same way as the training samples, but without noise. For all methods, models are initialized with random Gaussian samples. We set α = 0.4 for the experiments in Figure 5.2.

B.3.2 REAL DATA WITH CONTROLLED DISTRIBUTION

Implementation details. We use an ImageNet (Deng et al., 2009) pre-trained ResNet-34 (He et al., 2016) for all experiments on this dataset. All methods are trained for T = 100 communication rounds, with 20 rounds of FedAvg as warmup. For FedDAR-WA and FedDAR-SA, 5 epochs of local updates are executed for both the heads and the representation at each round. For the baselines, 5 epochs of local updates are executed at each round for a fair comparison. We use the Adam optimizer with a learning rate of $1\times10^{-4}$ for the first 60 rounds and $1\times10^{-5}$ for the last 40 rounds. Images are resized to 224 × 224, with random horizontal flipping as the only augmentation. The learning rate and the number of local epochs are tuned by grid search with a fixed batch size of 64. We tuned the projection dimension k for FedDAR-SA among {4, 8, 16, 32, 64} with α = 1.0 and used k = 8 for all other values of α. Our evaluation metric is the classification accuracy on the whole FairFace validation set for each race group. We do not hold out an extra local validation set for each client, since we assume that the data distribution within each domain is consistent across clients.
The reported numbers are averaged over the final 10 communication rounds, following the standard practice in (Collins et al., 2021), and over three independent runs with different random seeds.

B.3.3 REAL DATA WITH REAL-WORLD DATA DISTRIBUTION

Dataset details. The detailed statistics of the partial EXAM dataset are summarized in Table 6. The "Other" category includes American Indian or Alaska Native, Native Hawaiian or Other Pacific Islander, and patients with more than one race or unknown race. HFO % is the percentage of cases with positive labels (receiving oxygen therapy at or above high-flow oxygen within 72 hours).

Table 6: Data summary of the partial EXAM dataset used in our study.

| Site | White | Black | Asian | Latino | Other | HFO % |
|---|---|---|---|---|---|---|
| Site-1 | 59.6% | 10.0% | 3.4% | 2.0% | 24.9% | 12.4% |
| Site-2 | 75.0% | 11.1% | 2.8% | 0.6% | 10.5% | 9.1% |
| Site-3 | 46.5% | 26.3% | 4.2% | 7.0% | 16.0% | 9.6% |
| Site-4 | 71.4% | 6.3% | 4.2% | 0.8% | 17.2% | 11.4% |
| Site-5 | 44.0% | 28.4% | 1.6% | 6.3% | 19.8% | 9.9% |
| Site-6 | 0.0% | 0.0% | 100.0% | 0.0% | 0.0% | 18.8% |

Implementation details. We apply 5-fold cross-validation. The model input is one chest X-ray image resized to 224 × 224, paired with 22-dimensional electronic health record (EHR) data; the representation dimension is 278 if it is not projected. All models are trained for T = 20 communication rounds with the Adam optimizer and a learning rate of $1\times10^{-4}$. In each round, we run 1 local epoch for all methods. For all methods, the models are initialized with the same pretrained model as in (Dayan et al., 2021), without any warmup. For FedDAR-SA and FedDAR-WA, we execute 5 epochs of head updates in each round, and set the representation dimension k = 16 for FedDAR-SA. Hyperparameters, including the learning rate, the number of epochs for head updates and the representation dimension, are tuned through grid search with a fixed batch size of 36. For FedRep, FedDAR and FedPer.
For LG-FedAvg, we treat the last fully connected layer as the global parameters and all other layers as the local representation. For FedMinMax, multiple local iterations are executed in each round instead of a single GD step, for a fair comparison. For FedProx, we tuned µ among {0.05, 0.1, 0.25, 0.5} and used µ = 0.1. For the fine-tuning methods, we only fine-tune the globally trained model locally with the Adam optimizer and a learning rate of $5\times10^{-5}$ for 1 epoch, since more epochs of fine-tuning lead to worse results. The models are evaluated by aggregating predictions on the local validation sets and then calculating the area under the curve (AUC) for each domain. The average AUCs on the clients' local validation sets are also reported. The reported AUC is first averaged over the last five communication rounds, and then averaged over five runs of 5-fold cross-validation.

C DISCUSSION ON COMMUNICATION AND PRIVACY

Communication. For FedDAR-WA, the only communication overhead comes from the extra parameters of the multiple domain heads, which only slightly increase the communication cost. For FedDAR-SA, each client needs to send a Hessian with $k^2N^2$ parameters to the server at each round. This might be costly when both the representation dimension k and the output dimension N are large. However, compared to sending the millions of parameters of a neural network, this extra communication cost is acceptable.

Privacy. For FedDAR-WA, no extra parameters are shared compared to FedAvg, so no additional privacy risk is introduced. Privacy techniques such as homomorphic encryption (Cheon et al., 2017) or differential privacy (McMahan et al., 2017b; Kairouz et al., 2021) that apply to FedAvg also work for FedDAR-WA. In fact, the multi-head design for different domains makes it harder to perform gradient-based attacks (Zhu et al., 2019) targeting our method.
This is because the attacker first needs to figure out which domain a sample comes from. For FedDAR-SA, the only extra parameters shared are the Hessian matrices, which are aggregated over all the local data. Recovering information about a specific sample from the Hessian is extremely difficult. In the worst case, what an attacker can recover from the Hessian is the label and the last-layer features, which hardly eases the difficulty of recovering the original input.
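To make the $k^2N^2$ communication count concrete, the sketch below shows one way second-order head aggregation can work. It assumes a squared loss on a linear head $V\in\mathbb R^{k\times N}$ applied to fixed features, which is a simplifying assumption: FedDAR-SA's exact aggregation rule is not reproduced here, and the function names are hypothetical. Each client sends the Hessian of its loss with respect to $\mathrm{vec}(V)$, a $(kN)\times(kN)$ matrix with $k^2N^2$ entries, plus a linear term; summing them lets the server solve the pooled normal equations exactly.

```python
import numpy as np

def local_stats(Phi, Y):
    """Per-client Hessian and linear term for the loss 0.5 * ||Phi @ V - Y||_F^2.

    Phi: L_i x k features, Y: L_i x N targets, V: k x N head.
    The Hessian w.r.t. vec(V) (column-major) is kron(I_N, Phi^T Phi), with (kN)^2 entries.
    """
    N = Y.shape[1]
    H = np.kron(np.eye(N), Phi.T @ Phi)
    g = (Phi.T @ Y).reshape(-1, order="F")   # vec(Phi^T Y), column-major to match kron layout
    return H, g

def aggregate_head(stats, k, N):
    """Server: solve (sum_i H_i) vec(V) = sum_i g_i, the pooled normal equations."""
    H = sum(h for h, _ in stats)
    g = sum(gi for _, gi in stats)
    return np.linalg.solve(H, g).reshape((k, N), order="F")

rng = np.random.default_rng(3)
k, N = 4, 3
clients = [(rng.standard_normal((50, k)), rng.standard_normal((50, N))) for _ in range(5)]
V = aggregate_head([local_stats(P, Y) for P, Y in clients], k, N)

# Reference: least squares on the pooled data of all clients.
Phi_all = np.vstack([P for P, _ in clients])
Y_all = np.vstack([Y for _, Y in clients])
V_pooled = np.linalg.lstsq(Phi_all, Y_all, rcond=None)[0]

H0, _ = local_stats(*clients[0])
print(np.allclose(V, V_pooled), H0.size)   # H0.size equals (k * N) ** 2
```

Under a squared loss the Hessian is block diagonal, so in principle only $\Phi_i^\top\Phi_i$ ($k^2$ entries) would need to be sent; for losses such as softmax cross-entropy, whose outputs are coupled, the Hessian with respect to the head is a full $(kN)\times(kN)$ matrix, which matches the $k^2N^2$ count above.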