# Byzantine Resilient and Fast Federated Few-Shot Learning

Ankit Pratap Singh¹, Namrata Vaswani¹

## Abstract

This work introduces a Byzantine resilient solution for learning a low-dimensional linear representation. Our main contribution is the development of a provably Byzantine-resilient AltGDmin algorithm for solving this problem in a federated setting. We argue that our solution is sample-efficient, fast, and communication-efficient. In solving this problem, we also introduce a novel secure solution to the federated subspace learning meta-problem that occurs in many different applications.

## 1. Introduction

Multi-task representation learning refers to the problem of jointly estimating the model parameters for a set of related tasks. This is typically done by learning a common representation for all of their source vectors (feature vectors). This learned representation can then be used for solving the meta-learning or learning-to-learn problem: learning model parameters in a data-scarce environment. This strategy is referred to as few-shot learning. In recent work (Du, Hu, Kakade, Lee, & Lei, 2020), an interesting low-dimensional linear representation was introduced and the corresponding low rank matrix learning optimization problem was defined. However, (Du et al., 2020) assumed that this optimization problem (see eq. (1)), which is non-convex, can be correctly solved. It is mentioned there that it should be possible to solve it via a nuclear norm based convex relaxation. However, there are no known guarantees ensuring that the solution to the relaxation is also a solution of the original problem. Moreover, convex relaxations are known to be very slow to solve compared with direct iterative solutions (Jain, Kar, et al., 2017; Netrapalli, Jain, & Sanghavi, 2013): they need order $1/\sqrt{\epsilon}$ iterations to obtain an $\epsilon$-accurate solution.
In follow-up work, (Tripuraneni, Jin, & Jordan, 2021) studied a special case in which all the source vectors for the different tasks are the same. It introduced a method of moments estimator that is faster but needs many more samples: its sample complexity grows as $1/\epsilon^2$. In interesting parallel works (Nayer & Vaswani, 2023, on arXiv since Feb. 2021; Collins, Hassani, Mokhtari, & Shakkottai, 2021), a fast and communication-efficient GD-based algorithm, referred to as Alternating GD and Minimization (AltGDmin) and FedRep respectively, was introduced for solving the problem in (1) when the number of available training samples per task is much smaller than the regression vector length. Follow-up work (Vaswani, 2024) improved the guarantees for AltGDmin while also simplifying the proof. The AltGDmin and FedRep algorithms are identical except for the initialization step. AltGDmin uses a better initialization and hence has a better sample complexity, by a factor of r. The latter (FedRep) paper referred to the problem in (1) as multi-task linear representation learning. The former (AltGDmin) paper used federated sketching and dynamic MRI (Babu, Lingala, & Vaswani, 2023) as motivating applications. It also solved the phaseless generalization of (1), called low rank phase retrieval (LRPR). In older work (Nayer & Vaswani, 2021; Nayer, Narayanamurthy, & Vaswani, 2020, 2019), an alternating minimization (AltMin) solution to this more general problem was developed and analyzed. Since (1) is a special case of this more general problem, this AltMin solution also solves (1). All these works study the centralized setting or the attack-free federated setting.

¹Department of Electrical and Computer Engineering, Iowa State University, Ames IA, USA. Correspondence to: Ankit Pratap Singh. Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).
Other somewhat related works include (Shen, Ye, Kang, Hassani, & Shokri, 2023; Tziotis, Shen, Pedarsani, Hassani, & Mokhtari, 2022). A longer treatment of the mathematical problem solved in this work (Byzantine resilient low rank column-wise compressive sensing) is given in (Singh & Vaswani, 2024).

### 1.1. Contributions

We adapt the AltGDmin algorithm described above to show how it can solve the multi-task linear representation learning and few-shot learning problems. Our main contribution is the development of a provably Byzantine-resilient AltGDmin-based solution for solving this problem in a federated setting. Our solution is communication-efficient as well as fast and sample-efficient. In this setting, resilience to adversarial attacks on some nodes is an important requirement. The most general attack is the Byzantine attack: the attacking nodes can collude, and all of them know the outputs of all the nodes, the algorithm being implemented by the center, and the algorithm parameters.

In solving the above problem, we introduce a novel solution approach, called Subspace Median, for combining subspace estimates from multiple federated nodes when some of them can be malicious. This approach and its guarantee (Lemma 3.1) are of independent interest for developing a secure solution to the federated subspace learning meta-problem that occurs in many applications, such as (online) PCA, subspace tracking, and the initialization of many sparse recovery, low rank matrix recovery, and phase retrieval problems.

### 1.2. Related Work

Few-shot learning is applied across various tasks such as image classification (Vinyals, Blundell, Lillicrap, Wierstra, et al., 2016), sentiment analysis from short texts (Yu et al., 2018), and object recognition (Fei-Fei, Fergus, & Perona, 2006), with much of the focus on practical experimentation rather than theoretical development (Snell, Swersky, & Zemel, 2017; Ravi & Larochelle, 2016; Sung et al., 2018; Boudiaf et al., 2020). Representation learning, a significant method within this field, has been highlighted in several studies (Sun, Shrivastava, Singh, & Gupta, 2017; Goyal, Mahajan, Gupta, & Misra, 2019), though these often fall short of providing algorithmic guarantees for provably solving the representation learning problem (Du et al., 2020; Baxter, 2000; Maurer, Pontil, & Romera-Paredes, 2016; Tripuraneni et al., 2021; Tripuraneni, Jordan, & Jin, 2020; Y. Li, Ildiz, Papailiopoulos, & Oymak, 2023). Recent work (Collins et al., 2021; Nayer & Vaswani, 2023, on arXiv since Feb. 2021; Vaswani, 2024) developed provable algorithms for the low-dimensional linear representation learning problem, although these do not consider Byzantine attacks. Other lines of work extend the low-dimensional linear representation learning problem. (Shen et al., 2023) focuses on differential privacy: the algorithm CENTAUR presented there follows the server and client procedures of (Collins et al., 2021), with the notable addition of the Gaussian mechanism. The work in (Tziotis et al., 2022) addresses the challenge of stragglers; to combat the straggler effect, it introduces a novel sampling mechanism that uses a doubling strategy. The geometric median is one of the aggregation methods used to handle Byzantine attacks.
(Chen, Su, & Xu, 2017) develops a non-asymptotic analysis of stochastic gradient descent that uses the geometric median of means, giving convergence guarantees under specific conditions. Follow-up work uses coordinate-wise median and trimmed-mean estimators (Yin, Chen, Kannan, & Bartlett, 2018), but under assumptions of bounded variance and coordinate-wise bounded skewness (or coordinate-wise sub-exponential tails) of the gradient distribution. (Alistarh, Allen-Zhu, & Li, 2018; Allen Zhu, Ebrahimian, Li, & Alistarh, 2020) provided non-asymptotic guarantees for Byzantine resilient stochastic gradient descent, assuming a fixed set of Byzantine nodes across iterations. Some studies have explored heterogeneous data distributions, establishing results under bounds on the heterogeneity (Pillutla, Kakade, & Harchaoui, 2019; Data & Diggavi, 2021; L. Li, Xu, Chen, Giannakis, & Ling, 2019; Ghosh, Hong, Yin, & Ramchandran, 2019). Other works (Regatti, Chen, & Gupta, 2022; Lu, Li, Chen, & Ma, 2022; Cao, Fang, Liu, & Gong, 2020; Cao & Lai, 2019; Xie, Koyejo, & Gupta, 2019) use detection methods to handle heterogeneous gradients, assuming a trusted dataset at the central server.

### 1.3. Problem Setup

First consider the centralized setting. Suppose that there are q source tasks, with each task $k \in [q]$ associated with a distribution over the input-output space $\mathcal{X} \times \mathcal{Y}$, where $\mathcal{X} \subseteq \Re^n$ and $\mathcal{Y} \subseteq \Re$. The aim is to learn prediction functions for all tasks simultaneously, leveraging a shared representation $\phi: \mathcal{X} \to \mathcal{Z}$ that maps inputs to a feature space $\mathcal{Z}$. We let the representation function class be low-dimensional linear representations, i.e., $\{x \mapsto U^\top x \mid U \in \Re^{n \times r}\}$ (Du et al., 2020). An example is the two-layer ReLU neural network. The goal is to find the optimal representation $\phi^*$, represented by $U^*$, and the true linear predictors $b_k^*$ for all tasks $k \in [q]$ that minimize the difference between the predicted and actual outputs.
Arranging the m input features for task k as rows of an $m \times n$ matrix $X_k$, and the outputs in an $m \times 1$ vector $y_k$, we have the model

$$Y = [y_1, y_2, \dots, y_q] := [X_1 U^* b_1^*, \dots, X_q U^* b_q^*] + V,$$

where V is the modeling error, assumed to be i.i.d. zero-mean Gaussian with variance $\sigma_v^2$. We have assumed an r-dimensional linear model for the regression coefficients, i.e., $\theta_k^* = U^* b_k^*$, with $r \ll \min(n, q)$. In other words, the $n \times q$ regression coefficients matrix $\Theta^* = U^* B^*$ has rank r. Our goal is to learn the column span of the $n \times r$ matrix $U^*$ (and in the process also learn $\Theta^*$) from the $m \times q$ matrix Y. We assume that all the feature vectors for all the tasks are i.i.d. standard Gaussian, i.e., all the $X_k$s are i.i.d. and have i.i.d. standard Gaussian entries. Solving this problem requires solving

$$\min_{U \in \Re^{n \times r},\; B \in \Re^{r \times q}} \sum_{k=1}^{q} \|y_k - X_k U b_k\|^2. \qquad (1)$$

In the federated setting, we assume that there are a total of L nodes. Each node observes a different, disjoint subset of $\tilde{m} = m/L$ rows of Y. Denoting the set of rows observed at node ℓ by $S_\ell$, this means that the $S_\ell$s are disjoint and $\cup_{\ell=1}^{L} S_\ell = [m]$. At most τL nodes can be Byzantine, with τ < 0.4. The nodes can only communicate with the center.

In this work, all U matrices are $n \times r$ and are used to denote the subspaces spanned by their columns. We use $\|\cdot\|$ to denote the (induced) $\ell_2$ norm and $\|\cdot\|_F$ for the Frobenius norm. For $U_1, U_2$ with orthonormal columns, we use $\mathrm{SD}_2(U_1, U_2) := \|(I - U_1 U_1^\top) U_2\|$ or $\mathrm{SD}_F(U_1, U_2) := \|(I - U_1 U_1^\top) U_2\|_F$ to quantify the Subspace Distance (SD). Clearly $\mathrm{SD}_F(U_1, U_2) \le \sqrt{r}\, \mathrm{SD}_2(U_1, U_2)$.

## 2. Centralized Multi-task Representation Learning and Few-Shot Learning

Below, we first give the AltGDmin algorithm from (Nayer & Vaswani, 2023, on arXiv since Feb. 2021) to learn $U^*$. It is similar to the FedRep algorithm of (Collins et al., 2021), the only difference being that the AltGDmin initialization is better (it has a better sample complexity). Next, we give details about few-shot learning.

### 2.1. Multi-task Linear Representation Learning

Recall that the goal is to minimize $f(U, B) := \sum_{k=1}^{q} \|y_k - X_k U b_k\|^2$, where $B = [b_1, \dots, b_q]$. AltGDmin (Nayer & Vaswani, 2023, on arXiv since Feb. 2021; Collins et al., 2021; Vaswani, 2024) proceeds as follows. We first initialize U as explained below; this is needed since our optimization problem is non-convex. After this, at each iteration, we alternately update U and B as follows:
(1) Keeping U fixed, update B by solving $\min_B f(U, B) = \min_B \sum_{k=1}^{q} \|y_k - X_k U b_k\|^2$. This decouples across tasks into r-dimensional least squares problems, $b_k = (X_k U)^\dagger y_k$.
(2) Keeping B fixed, update U by a GD step, followed by orthonormalizing its columns: $U^+ \leftarrow \mathrm{QR}(U - \eta \nabla_U f(U, B))$. Here $\nabla_U f(U, B) = \sum_{k \in [q]} X_k^\top (X_k U b_k - y_k) b_k^\top$, and η is the GD step size.

We initialize U as in (Nayer & Vaswani, 2023, on arXiv since Feb. 2021) by computing the top r left singular vectors of

$$\Theta_0 := \frac{1}{m} \sum_{k} X_k^\top y_{k,\mathrm{trunc}}\, e_k^\top, \qquad y_{\mathrm{trunc}} := y \circ \mathbb{1}_{|y| \le \sqrt{\alpha}}.$$

Here $\alpha := 9\kappa^2\mu^2 \sum_k \|y_k\|^2/(mq)$. Here and below, $y_{\mathrm{trunc}}$ refers to a truncated version of the vector y obtained by zeroing out the entries of y with magnitude larger than $\sqrt{\alpha}$ (the notation $\mathbb{1}_{|z| \le \sqrt{\alpha}}$ returns a 1-0 vector with 1 where $|z_j| \le \sqrt{\alpha}$ and zero everywhere else, and $z_1 \circ z_2$ is the Hadamard product, the .* operation in MATLAB). The algorithm is summarized in Algorithm 1. We can show the following.

Theorem 2.1 ((Vaswani, 2024)). Assume $\sigma_v^2 = 0$ and that $\max_k \|b_k^*\| \le \mu \sqrt{r/q}\, \sigma_1^*$, where $\sigma_1^* := \sigma_1(\Theta^*)$, for a constant $\mu \ge 1$ (incoherence of the right singular vectors of $\Theta^*$). Let κ denote the ratio of the first to the r-th singular value of $\Theta^*$. Consider Algorithm 1 with $\eta = 0.4/(m {\sigma_1^*}^2)$ and $T = C\kappa^2 \log(1/\epsilon)$. If $mq \ge C\kappa^6\mu^2 (n+q) r (\kappa^2 r + \log(1/\epsilon))$ and $m \ge C \max(\log n, \log q, r) \log(1/\epsilon)$, then, with probability (w.p.) at least $1 - n^{-10}$, $\mathrm{SD}_2(U, U^*) \le \epsilon$ and $\|\theta_k - \theta_k^*\| \le \epsilon \|\theta_k^*\|$ for all $k \in [q]$. The time cost is $mqnr \cdot T = C\kappa^2 mqnr \log(1/\epsilon)$. The communication cost is nr per node per iteration.
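To make the two alternating steps and the truncated spectral initialization concrete, here is a minimal NumPy sketch on synthetic noiseless data. It is an illustration under stated assumptions, not the authors' implementation: it omits the sample-splitting of Algorithm 1, replaces the unknown constant $9\kappa^2\mu^2$ in α by a user-chosen `c_alpha = 9.0`, and sets $\eta = 0.4/(m{\sigma_1^*}^2)$ using a $\sigma_1^*$ computed from the synthetic ground truth. The helper names (`sd_2`, `sd_f`, `truncated_spectral_init`, `alt_gdmin`) are hypothetical.

```python
import numpy as np


def sd_2(U1, U2):
    """SD_2(U1, U2) = ||(I - U1 U1^T) U2|| (spectral norm); orthonormal inputs assumed."""
    return np.linalg.norm(U2 - U1 @ (U1.T @ U2), 2)


def sd_f(U1, U2):
    """SD_F(U1, U2) = ||(I - U1 U1^T) U2||_F."""
    return np.linalg.norm(U2 - U1 @ (U1.T @ U2))


def truncated_spectral_init(Y, Xs, r, c_alpha=9.0):
    """U0 = top-r left singular vectors of Theta0 = (1/m) sum_k X_k^T y_{k,trunc} e_k^T.

    Y: (m, q) outputs, Xs: (q, m, n) features. c_alpha stands in for the
    unknown constant 9 kappa^2 mu^2 (an assumption of this sketch).
    """
    m, q = Y.shape
    alpha = c_alpha * np.sum(Y ** 2) / (m * q)
    Y_trunc = Y * (np.abs(Y) <= np.sqrt(alpha))          # zero out entries with |y| > sqrt(alpha)
    Theta0 = np.einsum('kmn,km->nk', Xs, Y_trunc.T) / m  # column k = X_k^T y_{k,trunc} / m
    U, _, _ = np.linalg.svd(Theta0, full_matrices=False)
    return U[:, :r]


def alt_gdmin(Y, Xs, r, eta, T):
    """Centralized AltGDmin, without the sample-splitting used in the analysis."""
    U = truncated_spectral_init(Y, Xs, r)
    Yt = Y.T                                             # (q, m): row k is y_k
    for _ in range(T):
        XU = Xs @ U                                      # (q, m, r): X_k U for all tasks
        # Minimization step: b_k = (X_k U)^+ y_k via batched r x r normal equations.
        G = np.swapaxes(XU, 1, 2) @ XU                   # (q, r, r)
        h = np.einsum('kmr,km->kr', XU, Yt)              # (q, r)
        Bt = np.linalg.solve(G, h[..., None])[..., 0]    # (q, r): row k is b_k
        # GD step on U with gradient sum_k X_k^T (X_k U b_k - y_k) b_k^T.
        resid = np.einsum('kmr,kr->km', XU, Bt) - Yt     # (q, m)
        grad = np.einsum('kmn,km,kr->nr', Xs, resid, Bt) # (n, r)
        U, _ = np.linalg.qr(U - eta * grad)              # re-orthonormalize the columns
    return U, Bt.T


# Synthetic, noiseless instance of the model (all sizes are illustrative).
rng = np.random.default_rng(0)
n, q, r, m = 20, 100, 2, 100
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
Theta_star = U_star @ rng.standard_normal((r, q))
Xs = rng.standard_normal((q, m, n))
Y = np.einsum('kmn,nk->mk', Xs, Theta_star)              # y_k = X_k theta_k^*
sigma1 = np.linalg.norm(Theta_star, 2)                   # known only because the data is synthetic
U_hat, B_hat = alt_gdmin(Y, Xs, r, eta=0.4 / (m * sigma1 ** 2), T=200)
print(sd_2(U_star, U_hat))
```

The minimization step is solved through batched $r \times r$ normal equations, which is cheap because $r \ll m$; on this well-conditioned instance the subspace distance should shrink geometrically with the iterations, in line with Theorem 2.1.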
This result shows that, as long as the number of samples per task, m, is roughly of order $nr^2/q$, the learning error decays exponentially with iterations, even with the normalized step size $\eta m {\sigma_1^*}^2$ being a numerical constant (fast decay). Thus, after $T = C\kappa^2 \log(1/\epsilon)$ iterations, $\mathrm{SD}(U, U^*) \le \epsilon$, i.e., the low-dimensional subspace is accurately learned. Treating κ and μ as numerical constants and assuming $n \approx q$, notice that the AltGDmin sample complexity is $mq \gtrsim nr \max(r, \log(1/\epsilon))$. On the other hand, FedRep (Collins et al., 2021) needs to assume $mq \gtrsim nr^2 \max(r, \log(1/\epsilon))$, which is worse by a factor of r. In fact, this complexity is comparable to that of the AltMin solution from (Nayer & Vaswani, 2021), which solved this problem and its LRPR generalization. The older result of (Nayer & Vaswani, 2023, on arXiv since Feb. 2021) for AltGDmin needed $mq \gtrsim nr^2 \log(1/\epsilon)$. This is worse by a factor of $\min(r, \log(1/\epsilon))$. The FedRep guarantee is worse because its initialization computes $U_0$ as the top r singular vectors of the matrix $\sum_{ki} y_{ki}^2 x_{ki} x_{ki}^\top \mathbb{1}(y_{ki}^2 \le 9\kappa^2\mu^2 \sum_{ki} y_{ki}^2/(mq))$, and its analysis of the GD step is not as tight as it could be (similar to that of (Nayer & Vaswani, 2023, on arXiv since Feb. 2021)). The advantages of the result of (Collins et al., 2021) were (i) a slightly better dependence on κ, and (ii) it studied the low rank column-wise sensing problem in the $\sigma_v^2 \neq 0$ setting, while the result of (Vaswani, 2024) assumes $\sigma_v^2 = 0$. As we explain in the remark given next, this result easily extends to the $\sigma_v^2 \neq 0$ setting as well, with no change to its sample complexity.

Remark 2.2 (Theorem 2.1 with $\sigma_v^2 \neq 0$). Assume everything from Theorem 2.1 except $\sigma_v^2 = 0$; instead, let $0 < \sigma_v^2 \le c\, \|\Theta^*\|_F^2/q$. Let $\epsilon_{\mathrm{noise}} := C q \kappa^2 \sigma_v^2/{\sigma_1^*}^2$. Then $\mathrm{SD}_2(U, U^*) \le \max(\epsilon, \epsilon_{\mathrm{noise}})$. In words, the error decays exponentially until it reaches the (normalized) noise level, and saturates after that.

### 2.2. Few-Shot Learning

Few-shot learning refers to learning in data-scarce environments (Du et al., 2020).
Once an estimate U of the true representation $U^*$ is obtained, the problem simplifies to learning a predictor function $b_k: \Re^r \to \Re$, defined on $\Re^r$ and specialized for each task k. Each source task can easily compute its local predictor $b_k$ using the available samples, since $r \ll m$. We want to bound the excess risk of the learned predictor on a new target task. We are given $m_{\mathrm{new}}$ input-output training pairs, arranged into an $m_{\mathrm{new}} \times n$ matrix $X_{\mathrm{new}}$ and an $m_{\mathrm{new}} \times 1$ vector $y_{\mathrm{new}}$, and we need to solve the regression problem. However, this is a data-scarce setting, i.e., $m_{\mathrm{new}} \ll n$; consequently, without the low-dimensional linear representation, it is impossible to solve the regression problem. Using the learned U, however, we can easily learn an r-dimensional vector of regression coefficients as long as $m_{\mathrm{new}} > r$. The excess risk of the learned predictor is defined through the prediction error $x_{\mathrm{new}}^\top \theta^*_{\mathrm{new}} + v_{\mathrm{new}} - x_{\mathrm{new}}^\top U b_{\mathrm{new}}$, where $\theta^*_{\mathrm{new}} = U^* b^*_{\mathrm{new}}$. We compute $b_{\mathrm{new}} = (X_{\mathrm{new}} U)^\dagger y_{\mathrm{new}}$, where U is the final estimate from the AltGDmin algorithm described above and $M^\dagger := (M^\top M)^{-1} M^\top$.

We have the following bound on the expected value of the excess risk (ER) for the few-shot learning task. Recall from (Du et al., 2020) that $\mathbb{E}[\mathrm{ER}(U, b_{\mathrm{new}})] = \mathbb{E}[(y - \hat{y})^2]$, where $y = {\theta^*_{\mathrm{new}}}^\top x + v$ and we predict it as $\hat{y} = b_{\mathrm{new}}^\top U^\top x$, with $b_{\mathrm{new}}$ given by the last step of Algorithm 1 and U the output of its representation-learning step.

Corollary 2.3. Let U be the final output of the learning steps of Algorithm 1. If $m_{\mathrm{new}} \ge C \max(r, \log q, \log n)$, then the excess risk satisfies $\mathbb{E}[\mathrm{ER}(U, b_{\mathrm{new}})] = \|\theta^*_{\mathrm{new}} - U b_{\mathrm{new}}\|^2 + \sigma_v^2 \le C \max(\sigma_v^2, \epsilon \|b^*_{\mathrm{new}}\|^2)$.

Notice that, with just order r samples, we are able to learn the regression coefficients for n-dimensional features.

## 3. Resilient Federated Multi-Task and Few-Shot Learning

Recall the federated setting from Sec. 1.3: there are a total of L federated nodes, and we assume that at most τL of them may be Byzantine, with τ < 0.4.
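The few-shot step reduces to an r-dimensional least squares problem. A short NumPy sketch, assuming for illustration that the learned basis equals $U^*$ and that the data are noiseless (`few_shot_fit` is a hypothetical helper): with $m_{\mathrm{new}} = 10 \ll n = 200$, regressing on $X_{\mathrm{new}} U$ recovers $\theta^*_{\mathrm{new}}$, while the minimum-norm solution of the raw n-dimensional problem does not.

```python
import numpy as np


def few_shot_fit(X_new, y_new, U):
    """b_new = (X_new U)^+ y_new; well posed as soon as m_new > r."""
    b_new, *_ = np.linalg.lstsq(X_new @ U, y_new, rcond=None)
    return b_new, U @ b_new                  # (b_new, theta_new)


rng = np.random.default_rng(1)
n, r, m_new = 200, 3, 10                     # m_new << n, but m_new > r
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
b_star = rng.standard_normal(r)
X_new = rng.standard_normal((m_new, n))
y_new = X_new @ U_star @ b_star              # noiseless, for illustration only

b_new, theta_new = few_shot_fit(X_new, y_new, U_star)
theta_naive, *_ = np.linalg.lstsq(X_new, y_new, rcond=None)   # ignores the representation
print(np.linalg.norm(theta_new - U_star @ b_star))    # essentially zero
print(np.linalg.norm(theta_naive - U_star @ b_star))  # close to ||theta*||: underdetermined
```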
Denote the set of good (non-Byzantine) nodes by $\mathcal{J}_{\mathrm{good}}$. Equivalently, $|\mathcal{J}_{\mathrm{good}}| \ge (1-\tau)L$. We develop an approach for making AltGDmin Byzantine resilient that relies on the geometric median (GM). The most challenging part is modifying the initialization step. For the rest of the algorithm, we can borrow ideas from the extensive existing literature on Byzantine resilient GD discussed earlier. One popular approach in this area is to replace the summation in the gradient computation step by a median for vector-valued quantities. A well-studied one is the geometric median (GM) (Minsker, 2015; Chen et al., 2017), which we will use. The minimization step for updating the columns of B can be done locally at the nodes. These columns are also used only in the local partial gradient computation and hence never need to be transmitted to the center. We should mention, though, that the analysis of the GD step is not a direct extension of existing ideas because of important differences between our problem and most standard problems.

Algorithm 1 Few-Shot Learning via AltGDmin. Let $M^\dagger := (M^\top M)^{-1} M^\top$.
1: Input: $y_k, X_k$, $k \in [q]$
2: Parameters: GD step size η; number of iterations T
3: Sample-split: partition the data into 2T + 1 equal-sized disjoint sets $y_k^{(\tau)}, X_k^{(\tau)}$, $\tau = 0, 1, \dots, 2T$.
Learning Representation:
4: Initialization:
5: set $\alpha \leftarrow 9\kappa^2\mu^2 \frac{1}{mq} \sum_k \|y_k\|^2$
6: using $y_k \equiv y_k^{(0)}$, $X_k \equiv X_k^{(0)}$,
7: set $y_{k,\mathrm{trunc}}(\alpha) \leftarrow y_k \circ \mathbb{1}_{|y_k| \le \sqrt{\alpha}}$
8: set $\Theta_0 \leftarrow \frac{1}{m} \sum_{k \in [q]} X_k^\top y_{k,\mathrm{trunc}}(\alpha)\, e_k^\top$
9: set $U_0 \leftarrow$ top-r-singular-vectors of $\Theta_0$
10: GDmin iterations:
11: for t = 1 to T do
12: let $U \leftarrow U_{t-1}$
13: using $y_k \equiv y_k^{(t)}$, $X_k \equiv X_k^{(t)}$,
14: set $b_k \leftarrow (X_k U)^\dagger y_k$, $\theta_k \leftarrow U b_k$ for all $k \in [q]$
15: using $y_k \equiv y_k^{(T+t)}$, $X_k \equiv X_k^{(T+t)}$, compute
16: set $\nabla_U f(U, B) = \sum_k X_k^\top (X_k U b_k - y_k) b_k^\top$
17: set $\hat{U}^+ \leftarrow U - (\eta/m) \nabla_U f(U, B)$
18: compute the QR decomposition $\hat{U}^+ = U^+ R^+$
19: set $U_t \leftarrow U^+$
20: end for
Few-shot Learning: prediction on a new task
21: $b_{\mathrm{new}} \leftarrow (X_{\mathrm{new}} U)^\dagger y_{\mathrm{new}}$
22: $\theta_{\mathrm{new}} \leftarrow U b_{\mathrm{new}}$
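Sample-splitting (line 3 of Algorithm 1) only partitions each task's rows into fresh, disjoint folds. A minimal sketch, assuming the data layout `Y` of shape (m, q) and `Xs` of shape (q, m, n); the function name `sample_split` is hypothetical.

```python
import numpy as np


def sample_split(Y, Xs, T):
    """Split the m samples of every task into 2T + 1 disjoint, equal-sized folds.

    Fold 0 feeds the initialization, folds 1..T the minimization steps, and
    folds T+1..2T the gradient steps. Leftover rows (m mod (2T + 1)) are dropped.
    """
    size = Y.shape[0] // (2 * T + 1)
    return [(Y[i * size:(i + 1) * size], Xs[:, i * size:(i + 1) * size])
            for i in range(2 * T + 1)]


rng = np.random.default_rng(0)
m, q, n, T = 36, 4, 5, 3
Y = rng.standard_normal((m, q))
Xs = rng.standard_normal((q, m, n))
folds = sample_split(Y, Xs, T)
print(len(folds), folds[0][0].shape, folds[0][1].shape)  # 7 (5, 4) (4, 5, 5)
```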
In our problem, the GD step is not a standard GD or projected GD step for a given cost function.

For L data vectors $z_1, z_2, \dots, z_L$, the geometric median (GM) is defined as $z_{gm} = \arg\min_z \sum_{\ell=1}^{L} \|z_\ell - z\|$. Here and below, $\|\cdot\|$ without a subscript denotes the $\ell_2$ norm. The GM cannot be computed in closed form, but various algorithms exist to approximate it accurately.

### 3.1. GM-based Resilient Spectral Initialization: Subspace Median and Subspace Median of Means

This consists of two steps. First, a resilient estimate of the truncation threshold $\alpha = \frac{C}{mq} \sum_{k,i} y_{ki}^2$ needs to be computed. For this, we use the scalar median of means of the partial estimates computed by each node. Next, we need to compute $U_0$, the matrix of top r left singular vectors of $\Theta_0$. Node ℓ has the data to compute the $n \times q$ matrix $(\Theta_0)_\ell$, defined as

$$(\Theta_0)_\ell := \frac{1}{\tilde{m}} \sum_{k=1}^{q} (X_k)_\ell^\top \big((y_k)_\ell\big)_{\mathrm{trunc}}\, e_k^\top. \qquad (2)$$

Observe that $\Theta_0 = \sum_\ell (\Theta_0)_\ell / L$. If all nodes were good, we would use this fact to implement the federated power method (PM): starting with a random initialization U, this involves iterating the following: compute $\tilde{V} := \sum_\ell V_\ell / L$ (where $V_\ell = (\Theta_0)_\ell^\top U$) in a federated fashion, then compute $\tilde{U}^+ = \sum_\ell (\Theta_0)_\ell \tilde{V} / L$ in a federated fashion, and finally obtain $U^+ = \mathrm{QR}(\tilde{U}^+)$ at the center. To deal with Byzantine attacks, the most obvious solution is to replace the averaging at the center by the GM. However, this works with high probability only if all the $(\Theta_0)_\ell$s are extremely accurate estimates of $\Theta^*$ (see footnote 1). This in turn implies that its required sample complexity is very large. We provide a detailed discussion of this fact for the simpler PCA problem in Sec. 4 and Table 1.

Subspace-Median. Since the GM is defined for quantities whose distance can be measured using the vector $\ell_2$ norm (equivalently, the matrix Frobenius norm), it cannot be directly used for subspaces (or their basis matrices): these do not lie in a Euclidean space (subspaces lie on the Grassmann manifold, and their orthonormal basis matrices on the Stiefel manifold).
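The GM has no closed form. The sketch below approximates it with Weiszfeld's classical fixed-point iteration; this is a swapped-in solver chosen for brevity, whereas the algorithms in this work use (Cohen, Lee, Miller, Pachocki, & Sidford, 2016, Algorithm 1). It also shows the scalar median of means used for the threshold α. The function names are hypothetical, and `median_of_means` assumes the number of values is a multiple of the number of batches.

```python
import numpy as np


def geometric_median(Z, iters=200, tol=1e-10, eps=1e-12):
    """Weiszfeld iterations for argmin_z sum_l ||z_l - z||_2 over the rows of Z."""
    z = Z.mean(axis=0)
    for _ in range(iters):
        d = np.maximum(np.linalg.norm(Z - z, axis=1), eps)  # guard against division by zero
        w = 1.0 / d
        z_new = (w[:, None] * Z).sum(axis=0) / w.sum()      # inverse-distance weighted average
        if np.linalg.norm(z_new - z) <= tol:
            return z_new
        z = z_new
    return z


def median_of_means(values, num_batches):
    """Scalar MoM: average within each of num_batches equal batches, then take the median."""
    batches = np.asarray(values, dtype=float).reshape(num_batches, -1)
    return float(np.median(batches.mean(axis=1)))


Z = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]])
print(geometric_median(Z))                                 # [1. 1.]
print(median_of_means([1, 1, 1, 1, 1, 1, 100, 100], 4))    # 1.0
```

Unlike the mean, both estimators in the example ignore a minority of wildly wrong inputs, which is the property the resilient initialization relies on.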
To see this simply, notice that U and −U specify the same subspace even though $\|U - (-U)\|_F = 2\sqrt{r} \neq 0$. Notice, though, that the Frobenius norm between the projection matrices of two subspaces is also a measure of subspace distance: $\|P_U - P_{U^*}\|_F = \sqrt{2}\, \mathrm{SD}_F(U, U^*)$ (Chen, Chi, Fan, Ma, et al., 2021, Lemma 2.5). Here $P_U := UU^\top$ is the projection matrix for subspace U (assuming U has orthonormal columns). We use this idea to develop a simple but useful approach called the Subspace Median. Node ℓ computes $\hat{U}_\ell$ as the top r singular vectors of the matrix $(\Theta_0)_\ell$ that it has data for, and sends it to the center. If node ℓ is good, then $\hat{U}_\ell$ already has orthonormal columns; if the node is Byzantine, it may not. The center first orthonormalizes the columns of all the received $\hat{U}_\ell$: $U_\ell = \mathrm{QR}(\hat{U}_\ell)$ for all $\ell \in [L]$. It then computes the projection matrices $P_{U_\ell} := U_\ell U_\ell^\top$, $\ell \in [L]$, vectorizes them, computes their GM, and converts the GM back into a matrix, denoted $P_{gm}$. Finally, the center finds the ℓ for which $P_{U_\ell}$ is closest to $P_{gm}$ in Frobenius norm and outputs the corresponding $U_\ell$.

Footnote 1: The reason for this is that the method computes the GM of the node outputs $V_\ell = (\Theta_0)_\ell^\top U$ at each iteration, including the first one. At the first iteration, U is a randomly generated matrix and thus, w.h.p., a bad approximation of the desired subspace span($U^*$). Consequently, unless the various $(\Theta_0)_\ell$s are very close approximations of $\Theta^*$, the different $V_\ell$s are likely to be bad approximations of span($B^{*\top}$). In particular, the estimates at the different nodes may be quite different even among the good nodes. As a result, their GM cannot distinguish between the good and the Byzantine ones, and there is a good chance that it approximates a Byzantine one. A similar argument applies to the $U_\ell$s, and so on. Thus, unless all the $(\Theta_0)_\ell$s are very close approximations of $\Theta^*$ (and hence very similar), there is a good chance that the subspace estimates do not improve over iterations.
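A NumPy sketch of the Subspace Median described above, again using Weiszfeld's iteration as a stand-in GM solver (not the solver used in this work). The setup is an assumption for illustration: seven good nodes hold small perturbations of $U^*$, and three colluding Byzantine nodes all send the same non-orthonormal matrix. The helper names are hypothetical.

```python
import numpy as np


def geometric_median(Z, iters=500, tol=1e-10, eps=1e-12):
    """Weiszfeld iterations for the geometric median of the rows of Z."""
    z = Z.mean(axis=0)
    for _ in range(iters):
        w = 1.0 / np.maximum(np.linalg.norm(Z - z, axis=1), eps)
        z_new = (w[:, None] * Z).sum(axis=0) / w.sum()
        if np.linalg.norm(z_new - z) <= tol:
            return z_new
        z = z_new
    return z


def subspace_median(U_hats):
    """QR each input, GM of the vectorized projections, return the closest U_l."""
    Us = [np.linalg.qr(U)[0] for U in U_hats]          # Byzantine inputs may not be orthonormal
    P = np.stack([(U @ U.T).ravel() for U in Us])      # vec(P_{U_l}) as rows
    P_gm = geometric_median(P)
    best = int(np.argmin(np.linalg.norm(P - P_gm, axis=1)))  # Frobenius distance to P_gm
    return Us[best], best


def sd_f(U1, U2):
    return np.linalg.norm(U2 - U1 @ (U1.T @ U2))


rng = np.random.default_rng(0)
n, r, L, n_byz = 20, 2, 10, 3
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
good = [U_star + 0.01 * rng.standard_normal((n, r)) for _ in range(L - n_byz)]
U_bad = 5.0 * rng.standard_normal((n, r))              # the same wrong matrix from every attacker
U_out, best = subspace_median(good + [U_bad] * n_byz)
print(best, sd_f(U_star, U_out))
```

Returning one of the received $U_\ell$ (rather than $P_{gm}$ itself) keeps the output an orthonormal basis of an r-dimensional subspace.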
Denote this $U_\ell$ by $U_{out}$. We can show the following for this estimator.

Lemma 3.1 (Subspace Median). Suppose that $|\mathcal{J}_{\mathrm{good}}| \ge (1-\tau)L$ for a τ < 0.4. If $\min_{\ell \in \mathcal{J}_{\mathrm{good}}} \Pr(\mathrm{SD}_F(U_\ell, U^*) \le \delta) \ge 1 - p$, then, with probability at least $1 - c_0 \exp(-L\psi(0.4 - \tau, p))$, $\mathrm{SD}_F(U_{out}, U^*) \le 23\delta$. Here $\psi(a, b) := (1-a)\log\frac{1-a}{1-b} + a\log\frac{a}{b}$ is the binary KL divergence.

Subspace Median of Means. A median-based estimator can be robust to almost 50% outliers (here, Byzantine nodes), but, as is well known, the use of the median also wastes samples. In our context, this means that the estimate of each node needs to be accurate enough. If the maximum number of Byzantine nodes is known to be much smaller than 50%, a better approach is to use the median of means (MoM) estimator. We explain how to develop this for our problem. For a parameter $\tilde{L} \le L$, we form $\tilde{L}$ mini-batches of $\rho = L/\tilde{L}$ nodes each; w.l.o.g., ρ is an integer. For the ℓ-th node in the ϑ-th mini-batch, we use the shorthand $(\vartheta, \ell) = (\vartheta - 1)\rho + \ell$, for $\ell \in [\rho]$. In our setting, combining samples means combining the rows of $(X_k)_\ell$ and $(y_k)_\ell$ for ρ nodes to obtain $(\Theta_0)_{(\vartheta)}$, whose k-th column is $\sum_{\ell=1}^{\rho} (X_k)_{(\vartheta,\ell)}^\top (y_{k,\mathrm{trunc}})_{(\vartheta,\ell)}/\rho$. To compute this in a communication-efficient and private fashion, we use a federated power method for each of the $\tilde{L}$ mini-batches. The output of each of these power methods is $U_{(\vartheta)}$, $\vartheta \in [\tilde{L}]$. We then apply the subspace median to $U_{(\vartheta)}$, $\vartheta \in [\tilde{L}]$, to obtain the final subspace estimate $U_{out}$. To explain the federation details simply, we explain them for ϑ = 1. The power method needs to federate

$$U \leftarrow \mathrm{QR}\big((\Theta_0)_{(1)} (\Theta_0)_{(1)}^\top U\big) = \mathrm{QR}\Big(\sum_{\ell'=1}^{\rho} (\Theta_0)_{\ell'} \Big(\sum_{\ell=1}^{\rho} (\Theta_0)_\ell^\top U\Big)\Big),$$

where the $1/\rho$ scalings are dropped since they do not change the QR output. This needs two steps of information exchange between the nodes and the center at each power method iteration. In the first step, we compute $V = \sum_{\ell \in [\rho]} (\Theta_0)_\ell^\top U$, and in the second one we compute $\tilde{U} = \sum_{\ell \in [\rho]} (\Theta_0)_\ell V$, followed by its QR decomposition. We summarize the complete initialization algorithm in Algorithm 2.

Guarantee. We can prove the following.
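The two-exchange federated power method for a single mini-batch can be sketched as below, assuming each node already holds its $n \times q$ matrix $(\Theta_0)_\ell$. The per-node matrices in the example are hypothetical stand-ins (a shared low-rank signal plus node-specific noise), and `federated_power_method` is a hypothetical name; the check compares the result with a centralized SVD of $\sum_\ell (\Theta_0)_\ell$.

```python
import numpy as np


def federated_power_method(Thetas, r, T_pow, rng):
    """Top-r left singular subspace of sum_l Theta_l, computed without pooling raw data.

    Per iteration: nodes send V_l = Theta_l^T U (q x r) and the center sums them;
    nodes then send Theta_l V (n x r), and the center sums them and applies QR.
    """
    n = Thetas[0].shape[0]
    U, _ = np.linalg.qr(rng.standard_normal((n, r)))      # random initialization
    for _ in range(T_pow):
        V = sum(Th.T @ U for Th in Thetas)                 # first exchange
        U, _ = np.linalg.qr(sum(Th @ V for Th in Thetas))  # second exchange, then QR
    return U


rng = np.random.default_rng(0)
n, q, r, rho = 30, 40, 3, 4
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
B_star = rng.standard_normal((r, q))
Thetas = [U_star @ B_star + 0.1 * rng.standard_normal((n, q)) for _ in range(rho)]
U_pm = federated_power_method(Thetas, r, T_pow=50, rng=rng)
U_svd = np.linalg.svd(sum(Thetas), full_matrices=False)[0][:, :r]
print(np.linalg.norm(U_svd - U_pm @ (U_pm.T @ U_svd)))    # ~0: same subspace as the SVD
```

Each node only ever transmits $q \times r$ and $n \times r$ matrices, which is where the nr-type communication cost per iteration comes from.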
It needs to assume that the same set of τL nodes are Byzantine for all the power method iterations needed for the initialization step (see footnote 2).

Theorem 3.2 (Initialization via SubsMoM). Assume $\sigma_v^2 = 0$ and that $\max_k \|b_k^*\| \le \mu\sqrt{r/q}\, \sigma_1^*$ for a constant $\mu \ge 1$. Consider Algorithm 2 with $T_{gm} = C\log(\tilde{L}r/\delta_0)$ and $T_{pow} = C\kappa^2\log(n/\delta_0)$. Assume that the set of Byzantine nodes remains fixed for all iterations in this algorithm and is of size at most τL, with $\tau < 0.4\tilde{L}/L$. If $\tilde{m}q \ge C(\tilde{L}/L)\kappa^6\mu^2(n+q)r^2/\delta_0^2$, then, w.p. at least $1 - c_0\exp\big(-\tilde{L}\psi(0.4 - \tau, n^{-10} + \exp(-c(n+q)))\big) - L\exp(-c\tilde{m}q\delta_0^2/(r^2\kappa^4))$,

$$\mathrm{SD}_F(U^*, U_{out}) \le \delta_0.$$

The communication cost per node is order $nr\log(n/\delta_0)$. The computational cost at any node is order $nqr\log(n/\delta_0)$, while that at the center is $n^2\tilde{L}\log^3(\tilde{L}r/\delta_0)$. The extension of the above result to the $\sigma_v^2 \neq 0$ case is straightforward and can be proved using the same ideas as those used for Remark 2.2.

Proof. This follows by using Lemma 3.1 along with the Davis-Kahan sin Θ theorem and concentration bounds from (Vershynin, 2018), applied to analyze the output of each node. We apply the latter two to $\Phi_{(\vartheta)} = \sum_{\ell=1}^{\rho} (\Theta_0)_{(\vartheta,\ell)} (\Theta_0)_{(\vartheta,\ell)}^\top/\rho$ and $\Phi^* = \mathbb{E}[(\Theta_0)_\ell \mid \alpha]\, \mathbb{E}[(\Theta_0)_\ell \mid \alpha]^\top$ for $\vartheta \in [\tilde{L}]$.

Footnote 2: This can be relaxed if we instead assume a much tighter bound on the number of bad nodes per iteration.

### 3.2. GM-based Resilient Federated GDmin Iterations

We can make the AltGDmin iterations resilient as follows. In the minimization step, each node computes its own estimate $(b_k)_\ell$ of $b_k^*$ as $(b_k)_\ell = ((X_k)_\ell U)^\dagger (y_k)_\ell$, $k \in [q]$. Each node then uses these to compute its estimate of the gradient w.r.t. U as $\nabla f_\ell = \sum_{k \in [q]} (X_k)_\ell^\top ((X_k)_\ell U (b_k)_\ell - (y_k)_\ell)(b_k)_\ell^\top$. The center receives the gradients from the different nodes, computes their GM, and uses it for the projected GD step. Since the gradient norms are not bounded, the GM computation needs to be preceded by a thresholding step.
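The node-side and center-side computations of one resilient GDmin iteration can be sketched in NumPy as follows. Assumptions: Weiszfeld's iteration stands in for the GM solver, the threshold `omega` is a user-chosen value (how to set it is not specified here), and the gradients in the center-side example are synthetic. All function names are hypothetical. With $\tilde{L} = L$ (so ρ = 1), the center step reduces to the plain thresholded GM described above.

```python
import numpy as np


def node_gradient(Xs_l, Y_l, U):
    """Node l: local b_k = ((X_k)_l U)^+ (y_k)_l, then the partial gradient (n x r).

    Xs_l has shape (q, m_tilde, n) and Y_l has shape (m_tilde, q). Only the
    gradient leaves the node; the local (b_k)_l stay there.
    """
    XU = Xs_l @ U                                           # (q, m_tilde, r)
    Bt = np.stack([np.linalg.lstsq(XU[k], Y_l[:, k], rcond=None)[0]
                   for k in range(XU.shape[0])])            # (q, r), row k = (b_k)_l
    resid = np.einsum('kmr,kr->km', XU, Bt) - Y_l.T         # (X_k)_l U (b_k)_l - (y_k)_l
    return np.einsum('kmn,km,kr->nr', Xs_l, resid, Bt)      # sum_k X^T resid b^T


def geometric_median(Z, iters=500, tol=1e-10, eps=1e-12):
    """Weiszfeld iterations (a stand-in for Cohen et al., 2016, Algorithm 1)."""
    z = Z.mean(axis=0)
    for _ in range(iters):
        w = 1.0 / np.maximum(np.linalg.norm(Z - z, axis=1), eps)
        z_new = (w[:, None] * Z).sum(axis=0) / w.sum()
        if np.linalg.norm(z_new - z) <= tol:
            return z_new
        z = z_new
    return z


def center_gmom_step(U, grads, L_tilde, omega, eta, m_tilde):
    """Center: sum the gradients within each batch of rho = L / L_tilde nodes, drop
    batch sums with norm > omega, take the GM of the rest, then
    U+ = QR(U - eta / (rho * m_tilde) * GM)."""
    rho = len(grads) // L_tilde
    sums = np.stack([sum(grads[v * rho:(v + 1) * rho]).ravel() for v in range(L_tilde)])
    kept = sums[np.linalg.norm(sums, axis=1) <= omega]     # thresholding step
    g_gm = geometric_median(kept).reshape(U.shape)
    U_new, _ = np.linalg.qr(U - (eta / (rho * m_tilde)) * g_gm)
    return U_new, g_gm


rng = np.random.default_rng(0)
n, q, r, m_tilde = 15, 8, 2, 6
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
B_star = rng.standard_normal((r, q))
Xs_l = rng.standard_normal((q, m_tilde, n))
Y_l = np.einsum('kmn,nk->mk', Xs_l, U_star @ B_star)
print(np.abs(node_gradient(Xs_l, Y_l, U_star)).max())     # ~0: U* is stationary (noiseless)

# Center step with synthetic gradients: L = 12 nodes, L_tilde = 6 batches of rho = 2.
G0 = rng.standard_normal((n, r))
grads = [G0 + 0.01 * rng.standard_normal((n, r)) for _ in range(12)]
grads[0] = -50.0 * G0    # Byzantine: huge, removed by the threshold
grads[2] = -3.0 * G0     # Byzantine: passes the threshold, outvoted by the GM
U_new, g_gm = center_gmom_step(U_star, grads, L_tilde=6,
                               omega=10 * np.linalg.norm(G0), eta=0.5, m_tilde=m_tilde)
print(np.linalg.norm(g_gm - 2 * G0))                       # small: near the honest batch sum
```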
To improve the sample complexity (at the cost of reduced Byzantine tolerance), we can replace the GM of the gradients by their geometric median of means (GMoM): form $\tilde{L}$ batches of size $\rho = L/\tilde{L}$ each, compute the mean gradient within each batch, and compute the GM of the $\tilde{L}$ mean gradients, with appropriate scaling. We summarize the GMoM algorithm in Algorithm 3. The GM case corresponds to $\tilde{L} = L$. Given a good enough initialization, a small enough fraction of Byzantine nodes, and enough samples $\tilde{m}q$ at each node at each iteration, we can prove the following for the GD iterations.

Lemma 3.3 (AltGDmin-SubsMoM: Error Decay). Consider Algorithm 3 with sample-splitting and with step size $\eta \le 0.5/{\sigma_1^*}^2$. If, at each iteration t, $\tilde{m}q \ge C_1\kappa^4\mu^2(n+r)r^2(\tilde{L}/L)$ and $\tilde{m} > C_2\max(\log q, \log n)$; if $\tau < 0.4\tilde{L}/L$; and if the initial estimate $U_0$ satisfies $\mathrm{SD}_F(U^*, U_0) \le \delta_0 = 0.1/\kappa^2$, then, w.p. at least $1 - c_0 t\big(Ln^{-10} + \exp(-\tilde{L}\psi(0.4 - \tau, n^{-10}))\big)$,

$$\mathrm{SD}_F(U^*, U_{t+1}) \le \delta_{t+1} := \Big(1 - (\eta{\sigma_1^*}^2)\frac{0.12}{\kappa^2}\Big)^{t+1}\delta_0.$$

We prove this lemma in the long version (Singh & Vaswani, 2024, Section V). The complete algorithm is obtained by using Algorithm 3 initialized using Algorithm 2, with sample-splitting. Combining Theorem 3.2 and Lemma 3.3, and setting $\eta = 0.5/{\sigma_1^*}^2$ and $\delta_0 = 0.1/\kappa^2$, we can show that, at iteration t + 1, $\mathrm{SD}_F(U^*, U_{t+1}) \le \delta_{t+1} = (1 - 0.06/\kappa^2)^{t+1}\, 0.1/\kappa^2$ w.h.p. Thus, for this to be at most ε, we need to set $T = C\kappa^2\log(1/\epsilon)$. Also, since we use fresh samples at each iteration (sample-splitting), our sample complexity needs to be multiplied by T. We have the following final result.

Algorithm 2 Byz-AltGDmin-Learn: Initialization step.
1: Input: Batch ϑ: $\{(X_k)_\ell, Y_\ell, k \in [q]\}$, $\ell \in [L]$
2: Parameters: $T_{pow}$, $T_{gm}$
3: Nodes ℓ = 1, ..., L
4: Compute $\alpha_\ell \leftarrow \frac{C}{\tilde{m}q}\sum_k \|(y_k)_\ell\|^2$, with $C = 9\kappa^2\mu^2$.
5: Central Server
6: $\alpha \leftarrow \mathrm{Median}\{\alpha_{(\vartheta)}\}_{\vartheta=1}^{\tilde{L}}$, where $\alpha_{(\vartheta)} = \sum_{\ell=1}^{\rho}\alpha_{(\vartheta,\ell)}/\rho$
7: Central Server
8: Let $U_0 = U_{rand}$, where $U_{rand}$ is an $n \times r$ matrix with i.i.d. standard Gaussian entries.
9: for $\tau \in [T_{pow}]$ do
10: Nodes ℓ = 1, ..., L
11: Compute $V_\ell \leftarrow (\Theta_0)_\ell^\top (U_{(\vartheta)})_{\tau-1}$ for $\ell \in (\vartheta-1)\rho + [\rho]$, $\vartheta \in [\tilde{L}]$. Push to center.
12: Central Server
13: Compute $V_{(\vartheta)} \leftarrow \sum_{\ell=1}^{\rho} V_{(\vartheta-1)\rho+\ell}$
14: Push $V_{(\vartheta)}$ to nodes $\ell \in (\vartheta-1)\rho + [\rho]$.
15: Nodes ℓ = 1, ..., L
16: Compute $U_\ell \leftarrow (\Theta_0)_\ell V_{(\vartheta)}$ for $\ell \in (\vartheta-1)\rho + [\rho]$, $\vartheta \in [\tilde{L}]$. Push to center.
17: Central Server
18: Compute $U_{(\vartheta)} \leftarrow \mathrm{QR}\big(\sum_{\ell=1}^{\rho} U_{(\vartheta-1)\rho+\ell}\big)$
19: Let $(U_{(\vartheta)})_\tau \leftarrow U_{(\vartheta)}$. Push to nodes $\ell \in (\vartheta-1)\rho + [\rho]$.
20: end for
21: Central Server (implements Subspace Median on $U_{(\vartheta)}$, $\vartheta \in [\tilde{L}]$)
22: Orthonormalize: $U_\vartheta \leftarrow \mathrm{QR}((U_{(\vartheta)})_{T_{pow}})$, $\vartheta \in [\tilde{L}]$
23: Compute $P_{U_\vartheta} \leftarrow U_\vartheta U_\vartheta^\top$, $\vartheta \in [\tilde{L}]$
24: Compute GM: $P_{gm} \leftarrow$ approx GM$\{P_{U_\vartheta}, \vartheta \in [\tilde{L}]\}$ (use (Cohen, Lee, Miller, Pachocki, & Sidford, 2016, Algorithm 1) with parameter $T_{gm}$).
25: Find $\vartheta_{best} = \arg\min_\vartheta \|P_{U_\vartheta} - P_{gm}\|_F$
26: Output $U_{out} = U_{\vartheta_{best}}$

Algorithm 3 Byz-AltGDmin-Learn: Complete algorithm
1: Obtain $U_0$ using Algorithm 2.
2: for t = 1 to T do
3: Nodes ℓ = 1, ..., L
4: Set $U \leftarrow U_{t-1}$
5: $(b_k)_\ell \leftarrow ((X_k)_\ell U)^\dagger (y_k)_\ell$, $k \in [q]$
6: $(\theta_k)_\ell \leftarrow U(b_k)_\ell$, $k \in [q]$
7: $(\nabla_U f_k)_\ell \leftarrow (X_k)_\ell^\top((X_k)_\ell U (b_k)_\ell - (y_k)_\ell)(b_k)_\ell^\top$, $k \in [q]$
8: Push $\nabla f_\ell \leftarrow \sum_{k \in [q]} (\nabla_U f_k)_\ell$
9: Central Server
10: Compute $\nabla f_{(\vartheta)} \leftarrow \sum_{\ell=1}^{\rho} \nabla f_{(\vartheta,\ell)}$
11: $\nabla f_{GM} \leftarrow$ approx GM$_{thresh}(\nabla f_{(\vartheta)}, \vartheta = 1, 2, \dots, \tilde{L})$ (use (Cohen et al., 2016, Algorithm 1) with $T_{gm}$ iterations on $\{\nabla f_{(\vartheta)}, \vartheta \in [\tilde{L}]\} \setminus \{\nabla f_{(\vartheta)} : \|\nabla f_{(\vartheta)}\| > \omega\}$)
12: Compute $U^+ \leftarrow \mathrm{QR}\big(U_{t-1} - \frac{\eta}{\rho\tilde{m}}\nabla f_{GM}\big)$
13: Set $U_t \leftarrow U^+$. Push $U_t$ to nodes.
14: end for

Theorem 3.4 (AltGDmin-SubsMoM: Complete guarantee). Assume $\sigma_v^2 = 0$ and that $\max_k \|b_k^*\| \le \mu\sqrt{r/q}\, \sigma_1^*$ for a constant $\mu \ge 1$. Consider Algorithm 3 in the setting of Theorem 3.2 and Lemma 3.3. Set $T = C\kappa^2\log(1/\epsilon)$. If $\tilde{m}q \ge C\kappa^4\mu^2(n+q)r^2\log(1/\epsilon)(\tilde{L}/L)$ and $\tilde{m} > C\kappa^2\max(\log q, \log n)\log(1/\epsilon)$, then, w.p. at least $1 - TLn^{-10}$, $\mathrm{SD}_F(U^*, U) \le \epsilon$ and $\|\theta_k - \theta_k^*\| \le \epsilon\|\theta_k^*\|$ for all $k \in [q]$. The communication cost per node is order $nr\log(n/\epsilon)$. The computational cost at any node is order $nqr\log(n/\epsilon)$, while that at the center is $n^2\tilde{L}\log^3(\tilde{L}r/\epsilon)$. The extension of the above result to the $\sigma_v^2 \neq 0$ case is straightforward and can be proved using the same ideas as those used for Remark 2.2.

### 3.3. Numerical Experiments

In Figure 1, we plot Error vs. Iteration, where Error $= \mathrm{SD}_F(U^*, U)/\sqrt{r}$. We report the mean $\mathrm{SD}_F$ over 100 Monte Carlo runs. We compare Byz-Fed-AltGDmin-Learn (GMoM) with the baseline algorithm, AltGDmin-Learn (Mean), in the no-attack setting. We also provide results for Byz-Fed-AltGDmin-Learn (GM) for both values of $L_{byz}$. All of these are compared in Figure 1. We also compare the initialization errors in the table in Figure 1. As can be seen, the Byz-Fed-AltGDmin-Learn (GMoM) initialization error is noticeably lower than that of Byz-Fed-AltGDmin-Learn (GM). The same is true for the GDmin iterations.

| Method | $L_{byz} = 1$ | $L_{byz} = 2$ |
|---|---|---|
| Byz-Fed-AltGDmin-Learn (GM) | 0.716 (0.665) | 0.717 (0.667) |
| Byz-Fed-AltGDmin-Learn (GMoM) | 0.477 (0.457) | 0.475 (0.459) |

[Figure 1 plot: Iteration vs. Error; x-axis: Iteration (0 to 400); curves: Mean (no attack), GMoM $L_{byz} = 1$, GMoM $L_{byz} = 2$.]

Figure 1: Table: initialization errors; we report max $\mathrm{SD}_F$ (mean $\mathrm{SD}_F$) in each column. Figure: Byz-Fed-AltGDmin-Learn (GMoM), AltGDmin-Learn (Mean), and Byz-Fed-AltGDmin-Learn (GM) for $L_{byz} = 1, 2$; L = 18.

## 4. Resilient Federated PCA

Given q data vectors $d_k \in \Re^n$ that are zero mean, mutually independent, sub-Gaussian, and have covariance matrices that share the same principal subspace, the goal is to find this subspace. We arrange the data vectors into an $n \times q$ matrix, $D := [d_1, d_2, \dots, d_q]$. The data is vertically federated: each node ℓ has $q_\ell = \tilde{q} = q/L$ of the $d_k$s. Denote the corresponding sub-matrix of D by $D_\ell$. Suppose that $d_k$ has covariance matrix $\Sigma_k^*$ of the form $\Sigma_k^* \overset{\mathrm{EVD}}{=} [U^*, U^*_{\perp,k}]\, S_k\, [U^*, U^*_{\perp,k}]^\top$: all the covariance matrices share the same principal subspace $U^*$, but the lower eigenvectors and all the eigenvalues can differ. We use K to denote the maximum sub-Gaussian norm (Vershynin, 2018, Chap. 2) of ${\Sigma_k^*}^{-1/2} d_k$ over $k \in [q]$. The goal is to obtain a resilient estimate of the r-dimensional subspace $U^*$ of $\Re^n$ in a federated setting.
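The covariance model above can be instantiated as follows. This is a sketch assuming Gaussian data and uniformly drawn eigenvalues with a clear gap between the top r and the rest (both choices are illustrative; `shared_subspace_covariance` and `sample_columns` are hypothetical names). Each $\Sigma_k^*$ gets its own random basis $U^*_{\perp,k}$ for the complement of span$(U^*)$, yet the top-r eigenspace of every $\Sigma_k^*$ is span$(U^*)$.

```python
import numpy as np


def shared_subspace_covariance(U_star, top_eigs, low_eigs, rng):
    """Sigma_k = [U*, U_perp_k] S_k [U*, U_perp_k]^T with a task-specific U_perp_k.

    Requires min(top_eigs) > max(low_eigs) so that span(U*) is the principal subspace.
    """
    n, r = U_star.shape
    G = rng.standard_normal((n, n - r))
    G -= U_star @ (U_star.T @ G)                 # remove the span(U*) component
    U_perp, _ = np.linalg.qr(G)                  # random basis of the complement
    W = np.hstack([U_star, U_perp])
    Sigma = W @ np.diag(np.concatenate([top_eigs, low_eigs])) @ W.T
    return (Sigma + Sigma.T) / 2                 # symmetrize away rounding errors


def sample_columns(Sigma, num, rng):
    """Draw num zero-mean Gaussian vectors with covariance Sigma, as columns."""
    return rng.multivariate_normal(np.zeros(Sigma.shape[0]), Sigma, size=num).T


rng = np.random.default_rng(0)
n, r = 10, 2
U_star, _ = np.linalg.qr(rng.standard_normal((n, r)))
Sigmas = [shared_subspace_covariance(U_star, rng.uniform(5, 10, r),
                                     rng.uniform(0.1, 1, n - r), rng) for _ in range(3)]
dists = []
for Sigma in Sigmas:
    U_top = np.linalg.eigh(Sigma)[1][:, -r:]     # eigh sorts eigenvalues in ascending order
    dists.append(np.linalg.norm(U_top - U_star @ (U_star.T @ U_top)))
D = np.hstack([sample_columns(S, 50, rng) for S in Sigmas])   # n x 150 data matrix
print(max(dists), D.shape)
```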
The subspace median idea developed for initializing the AltGDmin algorithm described earlier is in fact applicable to a much more general subspace learning meta-problem: given $L$ subspace estimates $U_\ell$ of an unknown subspace $U^*$, one can compute their subspace median using exactly the same idea as that given in Sec. 3.1. For PCA, the individual node subspace estimates $U_\ell$ are computed as the top $r$ left singular vectors of the data matrix $D_\ell$. Moreover, we can also develop and analyze a subspace median-of-means generalization of it. This requires some different ideas, described next, because for the current problem we are assuming vertical federation.

| | SVD-ResCovEst (Minsker, 2015, Cor 4.3) | ResPowMeth (modification of Minsker, 2015; Hardt & Price, 2014) | SubsMed (proposed) | PowMeth, no attack (baseline) |
|---|---|---|---|---|
| Sample comp. for PCA (lower bound on $q$) | $n^2L/\epsilon^2$ | $\max(n^2r^2, n)/\epsilon^2$ | $nrL/\epsilon^2$ | $n/\epsilon^2$ |
| Communication cost | $n^2$ | $nr\,T_{\mathrm{pow}}$ | $nr$ | $nr\,T_{\mathrm{pow}}$ |
| Compute cost, node | $n^2 q_\ell$ | $nq_\ell r\,T_{\mathrm{pow}}$ | $nq_\ell r\,T_{\mathrm{pow}}$ | $nq_\ell r\,T_{\mathrm{pow}}$ |
| Compute cost, center | $n^2L\log^3(Ln)$ | $nrL\,T_{\mathrm{pow}}\log^3(Ln)$ | $n^2L\log^3(Ln)$ | $nrL\,T_{\mathrm{pow}}$ |

Table 1: Comparisons for solving the resilient federated PCA problem (Sec. 4). We compare the proposed Subspace Median (SubsMed) algorithm with two obvious (but poor) solutions, SVD-Resilient Covariance Estimation (SVD-ResCovEst: SVD of the GM of the covariance matrices) and the Resilient Power Method (ResPowMeth: a GM-based modification of the power method), and with the baseline (the power method in a no-attack setting). $T_{\mathrm{pow}}$ denotes the number of power-method iterations (see Theorem 4.1). Observe that, among the resilient methods, SubsMed needs the smallest sample complexity, and it has the lowest communication cost overall.

Pick an integer $\tilde L \le L$. In order to implement the mean step, we need to combine samples from $\rho = L/\tilde L$ nodes, i.e., we need to find the $r$-SVD of the matrices $D^{(\vartheta)} = [D_{(\vartheta,1)}, D_{(\vartheta,2)}, \dots, D_{(\vartheta,\rho)}]$ for all $\vartheta \in [\tilde L]$; here we use the notation $(\vartheta, \ell) = (\vartheta - 1)\rho + \ell$.
This needs to be done without sharing the entire data matrix. We do this by running $\tilde L$ different federated power methods, each of which combines samples from a different minibatch of $\rho$ nodes. The output of this step is $\tilde L$ subspace estimates $U^{(\vartheta)}$, $\vartheta \in [\tilde L]$. These serve as inputs to a basic Subspace Median algorithm, which yields the final Subspace-MoM estimator. Setting $\tilde L = L$ gives its subspace median special case.

Theorem 4.1 (Resilient Federated PCA). Consider Subspace Median of Means, with output $U_{\mathrm{out}}$. For $a > 0$, assume that $\min_\ell\big((\sigma^*_r)_\ell - (\sigma^*_{r+1})_\ell\big) \ge \dots$. Here $\bar\Sigma^*_\ell := \frac{1}{\tilde q}\sum_{k \in S_\ell}\Sigma^*_k$. Assume that the set of Byzantine nodes remains fixed for all iterations of this algorithm and that its size is at most $\tau L$, with $\tau < 0.4\tilde L/L$. If $q \ge CK^4\sigma_1^{*2}\cdots$, then, w.p. at least $1 - c_0\exp\big(-\tilde L\,\psi(0.4 - \tau, 2\exp(-n))\big)$, $\mathrm{SD}_F(U_{\mathrm{out}}, U^*) \le \epsilon$. The communication cost is $T_{\mathrm{pow}}\,nr$ per node, where $T_{\mathrm{pow}}$, the number of power-method iterations, is of order $\log(n/\epsilon)$ times a factor that depends on $\sigma^*_r$ and the eigen-gap. The computational cost at the center is order $n^2\tilde L\log^3\frac{\tilde L r}{\epsilon}$, and that at any node is order $nq_\ell r\,T_{\mathrm{pow}}$.

Comparison with attack-free federated PCA. Observe that the total sample complexity (lower bound on $q$) needed by the above result to guarantee $\mathrm{SD}_F(U^*, U) \le \epsilon$ is order $nrL/\epsilon^2$. Here we quantify subspace distance using $\mathrm{SD}_F$. However, even if we use the more common distance measure $\mathrm{SD}_2(U^*, U) := \|(I - UU^\top)U^*\|$ and require only $\mathrm{SD}_2(U^*, U) \le \epsilon$, the required sample complexity is the same. The reason is that we need the Frobenius norm for the GM computation, and since $\mathrm{SD}_F \le \sqrt{r}\,\mathrm{SD}_2$, controlling $\mathrm{SD}_F$ costs an extra factor of $r$. On the other hand, standard attack-free PCA needs a sample complexity of only $n/\epsilon^2$ to guarantee $\mathrm{SD}_2(U^*, U) \le \epsilon$ (Vershynin, 2018, Remark 4.7.2). Our complexity also has an extra factor of $L$; this is because we compute the individual node estimates using $\tilde q = q/L$ data points, and we need each node estimate to be accurate (to ensure that their median is accurate). This extra factor is also needed in other work that uses the (geometric) median, e.g., (Chen et al., 2017).

Two more obvious solutions for Resilient PCA and why they fail.
Consider the symmetric matrix $\Phi_\ell := (\Theta_0)_\ell(\Theta_0)_\ell^\top$. In a centralized setting, the most obvious solution to the above problem would be to compute the GM of the vectorized matrices $\Phi_\ell$ and then obtain the principal subspace ($r$-SVD) of the GM matrix; this was studied in (Minsker, 2015). However, in a federated setting, this is communication-inefficient because it requires each node to share an $n \times n$ matrix. For the same reason, it is not private either. Moreover, this approach is extremely sample-inefficient; see Table 1. For a communication-efficient solution in the attack-free federated setting, one would use the distributed power method (Golub & Van Loan, 1989; Wu, Wai, Li, & Scaglione, 2018). A direct modification of this to deal with attacks is its GM-based version: at each iteration, instead of summing the $n \times r$ matrices $U_\ell := \Phi_\ell U$ received from the nodes, we compute the GM of their vectorized versions. We refer to this as the Resilient Power Method (ResPowMeth). However, this works w.h.p. only if all the $\Phi_\ell$'s are extremely accurate estimates of $\Phi = \Theta^*\Theta^{*\top}$ (Singh & Vaswani, 2024). We summarize this discussion in Table 1.

5. Conclusions and Future Work

We developed a Byzantine-resilient, sample-, time-, and communication-efficient solution, called Byz-AltGDmin, for few-shot learning. We also introduced a novel solution approach, called Subspace Median, for combining subspace estimates from multiple federated nodes when some of them can be malicious. This is likely to be of independent interest for developing secure initialization approaches for various federated low-rank matrix recovery, subspace learning, and subspace tracking problems. The few-shot learning problem is almost synonymous with the online subspace tracking problem studied in (Babu et al., 2023) for real-time dynamic MRI. The mini-batch subspace tracking ideas of that work can be useful for few-shot learning as well. We will explore real-data applications in future work.
Acknowledgements

The authors would like to thank Prof. Shana Moothedath of Iowa State University for sharing the linear representation learning paper (Du et al., 2020) with us.

Addressing reviewer comments

We revised parts of the Abstract, Introduction, Contributions, and Related Work sections to include the work of (Collins et al., 2021), which we had missed by mistake. We significantly shortened Section 2. This work does not have the space to add real-world data applications; however, we will do this as part of future work by modifying the approaches developed in (Babu et al., 2023).

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

Alistarh, D., Allen-Zhu, Z., & Li, J. (2018). Byzantine stochastic gradient descent. Advances in Neural Information Processing Systems, 31.
Allen-Zhu, Z., Ebrahimian, F., Li, J., & Alistarh, D. (2020). Byzantine-resilient non-convex stochastic gradient descent. arXiv preprint arXiv:2012.14368.
Babu, S., Lingala, S. G., & Vaswani, N. (2023). Fast low rank compressive sensing for accelerated dynamic MRI. IEEE Trans. Comput. Imag.
Baxter, J. (2000). A model of inductive bias learning. Journal of Artificial Intelligence Research, 12, 149-198.
Boudiaf, M., Ziko, I., Rony, J., Dolz, J., Piantanida, P., & Ben Ayed, I. (2020). Information maximization for few-shot learning. Advances in Neural Information Processing Systems, 33, 2445-2457.
Cao, X., Fang, M., Liu, J., & Gong, N. Z. (2020). FLTrust: Byzantine-robust federated learning via trust bootstrapping. arXiv preprint arXiv:2012.13995.
Cao, X., & Lai, L. (2019). Distributed gradient descent algorithm robust to an arbitrary number of Byzantine attackers. IEEE Transactions on Signal Processing, 67(22), 5850-5864.
Chen, Y., Chi, Y., Fan, J., Ma, C., et al. (2021). Spectral methods for data science: A statistical perspective. Foundations and Trends® in Machine Learning, 14(5), 566-806.
Chen, Y., Su, L., & Xu, J. (2017). Distributed statistical machine learning in adversarial settings: Byzantine gradient descent. Proceedings of the ACM on Measurement and Analysis of Computing Systems, 1(2), 1-25.
Cohen, M. B., Lee, Y. T., Miller, G., Pachocki, J., & Sidford, A. (2016). Geometric median in nearly linear time. In Proceedings of the Forty-Eighth Annual ACM Symposium on Theory of Computing (pp. 9-21).
Collins, L., Hassani, H., Mokhtari, A., & Shakkottai, S. (2021). Exploiting shared representations for personalized federated learning. In International Conference on Machine Learning (pp. 2089-2099).
Data, D., & Diggavi, S. (2021). Byzantine-resilient SGD in high dimensions on heterogeneous data. In 2021 IEEE International Symposium on Information Theory (ISIT) (pp. 2310-2315).
Du, S. S., Hu, W., Kakade, S. M., Lee, J. D., & Lei, Q. (2020). Few-shot learning via learning the representation, provably. arXiv preprint arXiv:2002.09434.
Fei-Fei, L., Fergus, R., & Perona, P. (2006). One-shot learning of object categories. IEEE Transactions on Pattern Analysis and Machine Intelligence, 28(4), 594-611.
Ghosh, A., Hong, J., Yin, D., & Ramchandran, K. (2019). Robust federated learning in a heterogeneous environment. arXiv preprint arXiv:1906.06629.
Golub, G. H., & Van Loan, C. F. (1989). Matrix computations. The Johns Hopkins University Press, Baltimore, USA.
Goyal, P., Mahajan, D., Gupta, A., & Misra, I. (2019). Scaling and benchmarking self-supervised visual representation learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 6391-6400).
Hardt, M., & Price, E. (2014). The noisy power method: A meta algorithm with applications. Advances in Neural Information Processing Systems, 27.
Jain, P., Kar, P., et al. (2017). Non-convex optimization for machine learning. Foundations and Trends in Machine Learning, 10(3-4), 142-363.
Li, L., Xu, W., Chen, T., Giannakis, G. B., & Ling, Q. (2019). RSA: Byzantine-robust stochastic aggregation methods for distributed learning from heterogeneous datasets. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 33, pp. 1544-1551).
Li, Y., Ildiz, M. E., Papailiopoulos, D., & Oymak, S. (2023). Transformers as algorithms: Generalization and stability in in-context learning. In International Conference on Machine Learning (pp. 19565-19594).
Lu, S., Li, R., Chen, X., & Ma, Y. (2022). Defense against local model poisoning attacks to Byzantine-robust federated learning. Frontiers of Computer Science, 16(6), 166337.
Maurer, A., Pontil, M., & Romera-Paredes, B. (2016). The benefit of multitask representation learning. Journal of Machine Learning Research, 17(81), 1-32.
Minsker, S. (2015). Geometric median and robust estimation in Banach spaces.
Nayer, S., Narayanamurthy, P., & Vaswani, N. (2019). Phaseless PCA: Low-rank matrix recovery from column-wise phaseless measurements. In Intnl. Conf. Machine Learning (ICML).
Nayer, S., Narayanamurthy, P., & Vaswani, N. (2020, March). Provable low rank phase retrieval. IEEE Trans. Info. Th.
Nayer, S., & Vaswani, N. (2021). Sample-efficient low rank phase retrieval. IEEE Trans. Info. Th.
Nayer, S., & Vaswani, N. (2023; on arXiv since Feb. 2021). Fast and sample-efficient federated low rank matrix recovery from column-wise linear and quadratic projections. IEEE Trans. Info. Th.
Netrapalli, P., Jain, P., & Sanghavi, S. (2013). Low-rank matrix completion using alternating minimization.
Pillutla, K., Kakade, S. M., & Harchaoui, Z. (2019). Robust aggregation for federated learning. arXiv preprint arXiv:1912.13445.
Ravi, S., & Larochelle, H. (2016). Optimization as a model for few-shot learning. In International Conference on Learning Representations.
Regatti, J., Chen, H., & Gupta, A. (2022). Byzantine resilience with reputation scores. In 2022 58th Annual Allerton Conference on Communication, Control, and Computing (Allerton) (pp. 1-8).
Shen, Z., Ye, J., Kang, A., Hassani, H., & Shokri, R. (2023). Share your representation only: Guaranteed improvement of the privacy-utility tradeoff in federated learning. arXiv preprint arXiv:2309.05505.
Singh, A. P., & Vaswani, N. (2024). Byzantine-resilient federated PCA and low rank column-wise sensing. arXiv preprint arXiv:2309.14512.
Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical networks for few-shot learning. Advances in Neural Information Processing Systems, 30.
Sun, C., Shrivastava, A., Singh, S., & Gupta, A. (2017). Revisiting unreasonable effectiveness of data in deep learning era. In Proceedings of the IEEE International Conference on Computer Vision (pp. 843-852).
Sung, F., Yang, Y., Zhang, L., Xiang, T., Torr, P. H., & Hospedales, T. M. (2018). Learning to compare: Relation network for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 1199-1208).
Tripuraneni, N., Jin, C., & Jordan, M. (2021). Provable meta-learning of linear representations. In International Conference on Machine Learning (pp. 10434-10443).
Tripuraneni, N., Jordan, M., & Jin, C. (2020). On the theory of transfer learning: The importance of task diversity. Advances in Neural Information Processing Systems, 33, 7852-7862.
Tziotis, I., Shen, Z., Pedarsani, R., Hassani, H., & Mokhtari, A. (2022). Straggler-resilient personalized federated learning. arXiv preprint arXiv:2206.02078.
Vaswani, N. (2024). Efficient federated low rank matrix recovery via alternating GD and minimization: A simple proof. IEEE Trans. Info. Th.
Vershynin, R. (2018). High-dimensional probability: An introduction with applications in data science (Vol. 47). Cambridge University Press.
Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al. (2016). Matching networks for one shot learning. Advances in Neural Information Processing Systems, 29.
Wu, S. X., Wai, H.-T., Li, L., & Scaglione, A. (2018). A review of distributed algorithms for principal component analysis. Proceedings of the IEEE, 106(8), 1321-1340.
Xie, C., Koyejo, S., & Gupta, I. (2019). Zeno: Distributed stochastic gradient descent with suspicion-based fault-tolerance. In International Conference on Machine Learning (pp. 6893-6901).
Yin, D., Chen, Y., Kannan, R., & Bartlett, P. (2018). Byzantine-robust distributed learning: Towards optimal statistical rates. In International Conference on Machine Learning (pp. 5650-5659).
Yu, M., Guo, X., Yi, J., Chang, S., Potdar, S., Cheng, Y., ... Zhou, B. (2018). Diverse few-shot text classification with multiple metrics. arXiv preprint arXiv:1805.07513.