# FedPFT: Federated Proxy Fine-Tuning of Foundation Models

Zhaopeng Peng1, Xiaoliang Fan1,*, Yufan Chen1, Zheng Wang1, Shirui Pan2, Chenglu Wen1, Ruisheng Zhang3, Cheng Wang1

1Fujian Key Laboratory of Sensing and Computing for Smart Cities, School of Informatics, Xiamen University
2School of Information and Communication Technology, Griffith University
3School of Information Science and Engineering, Lanzhou University

pengzhaopeng@stu.xmu.edu.cn, fanxiaoliang@xmu.edu.cn, {yufanchen, zwang}@stu.xmu.edu.cn, s.pan@griffith.edu.au, clwen@xmu.edu.cn, zhangrs@lzu.edu.cn, cwang@xmu.edu.cn

\*Corresponding author.

## Abstract

Adapting Foundation Models (FMs) to downstream tasks through Federated Learning (FL) emerges as a promising strategy for protecting data privacy and valuable FMs. Existing methods fine-tune an FM by allocating a sub-FM to clients in FL; however, this leads to suboptimal performance due to insufficient tuning and the inevitable accumulation of gradient errors. In this paper, we propose Federated Proxy Fine-Tuning (FedPFT), a novel method that enhances FM adaptation to downstream tasks through FL via two key modules. First, the sub-FM construction module employs a layer-wise compression approach, facilitating comprehensive FM fine-tuning across all layers by emphasizing crucial neurons. Second, the sub-FM alignment module conducts a two-step distillation, layer-level and neuron-level, before and during FL fine-tuning respectively, to reduce gradient error by accurately aligning the sub-FM with the FM under theoretical guarantees. Experimental results on seven commonly used datasets (four text and three vision) demonstrate the superiority of FedPFT. Our code is available at https://github.com/pzpdzd/FedPFT.

## 1 Introduction

In recent years, various transformer-based Foundation Models (FMs) [Bommasani et al., 2021] such as BERT [Kenton and Toutanova, 2019], GPT [Radford et al.], LLaMA [Touvron et al., 2023], and ViT [Dosovitskiy et al., 2020] have attained state-of-the-art performance across a diverse range of natural language processing (NLP) and computer vision (CV) tasks, yet they also raise both data-privacy and FM-copyright concerns. For instance, an FM trained on medical data might inadvertently memorize sensitive patient information, and companies that own closed-source FMs may choose not to share them with the public. Federated Learning (FL) [McMahan et al., 2017] offers a privacy-preserving approach for the collaborative fine-tuning of FMs among multiple participants. This approach is increasingly promising for FM fine-tuning applications, enabling adaptation to downstream tasks without directly sharing client data or the server FM.

Recent methods [Xiao et al., 2023; Marchisio et al., 2023] mainly aim to fine-tune FMs without using the full model: they leverage layer-drop techniques [Sajjad et al., 2023] to compress an FM into a sub-FM, enabling approximate fine-tuning of the original FM. However, these methods still face two significant challenges that reduce the performance of the fine-tuned FM. On one hand, they fail to fine-tune the FM sufficiently because they discard its intermediate layers, which degrades the performance of the fine-tuned FM. As shown in Fig. 1(a), layer-drop methods fail to update the intermediate layers of the FM during fine-tuning due to the mismatch between the FM and the constructed sub-FM.
On the other hand, gradient errors of the FM can accumulate due to the lack of alignment between the sub-FM and the FM during FL fine-tuning, leading to further performance degradation. Fig. 1(b) shows that, in the absence of alignment, existing methods can accumulate significant gradient-update errors between the FM and its constructed sub-FM as FL fine-tuning progresses.

*Figure 1: A motivating example of the two challenges in FM fine-tuning using a proxy sub-model. (a) Two types of proxy sub-FM construction: existing methods construct sub-FMs via layer-drop compression and discard intermediate layers of the FM, causing mismatched and insufficient fine-tuning, while FedPFT conducts layer-wise compression to ensure comprehensive fine-tuning of the FM. (b) The problem of accumulating gradient errors: as FL fine-tuning progresses, the discrepancy between the updates made by sub-FMs and FMs grows, deviating from the ideal update direction, while FedPFT mitigates this gap by accurately aligning sub-FMs and FMs.*

To address these two challenges, we propose Federated Proxy Fine-Tuning (FedPFT), a framework that enhances the adaptation of FMs to downstream tasks while sharing neither the server FM nor the client data. First, we design the sub-FM construction module, which performs layer-wise compression on the FM to obtain a sub-FM by measuring the saliency of neurons in the Feed-Forward Network (FFN) of each transformer layer, facilitating comprehensive fine-tuning of the FM by emphasizing crucial neurons. Second, we design the sub-FM alignment module, which conducts a two-step distillation, layer-level and neuron-level, before and during FL fine-tuning respectively, ensuring accurate alignment between the sub-FM and the FM with a theoretical guarantee. Extensive experiments on three FMs and seven commonly used datasets demonstrate that FedPFT outperforms existing baselines that fine-tune FMs without using the full model. Our contributions can be summarized as follows:

- We introduce FedPFT, a novel federated FM fine-tuning method that establishes a sub-FM as a local proxy. FedPFT effectively improves fine-tuning performance while maintaining the critical constraint that neither the server FM nor the client data is directly shared.
- We propose the first module, which constructs sub-FMs through layer-wise compression. This technique maintains layer correspondence between the sub-FM and the FM, ensuring comprehensive fine-tuning of all FM layers while also alleviating training overhead.
- We propose the second module, which aligns sub-FMs with FMs via a two-step distillation, layer-level and neuron-level, before and during FL fine-tuning respectively. Additionally, we offer theoretical insights into the significance of distillation for fine-tuning via a sub-model.
- We conduct extensive experiments on three FMs and seven commonly used datasets. Results demonstrate that FedPFT consistently outperforms existing baselines.

## 2 Related Works

### 2.1 FM Fine-Tuning Through FL

Traditional centralized fine-tuning faces privacy concerns due to data sharing. Recent works [Chen et al., 2023a; Yu et al., 2023; Zhuang et al., 2023] introduce the concept of Federated Foundation Models to alleviate these concerns. [Fan et al., 2023; Kuang et al., 2023] propose FedLLM platforms to support the federated training of LLMs.
[Xu et al., 2023] fine-tune FMs via FL on mobile devices. [Chen et al., 2023b] apply FMs to federated multi-task learning. [Chen et al., 2023c] save communication cost during FL training through block-level parameter dropout. [Wang et al., 2023a] reduce communication and computation cost by training different layers of BERT in each round. [Zhang et al., 2023] apply parameter-efficient fine-tuning (PEFT) methods to the federated fine-tuning of FMs for privacy defense. However, most of the aforementioned methods rely on sharing the server FM. This limitation may pose risks of FM copyright leakage and impose a substantial computational burden on clients.

### 2.2 FM Fine-Tuning Without Using the Full Model

Early PEFT methods, including LoRA [Hu et al., 2021], Adapter-Tuning [Houlsby et al., 2019], and Prefix-Tuning [Li and Liang, 2021], focus on the efficient fine-tuning of complete FMs by reducing the number of tunable parameters. Despite these efforts, the gradient computation for the tunable parameters still requires backpropagation through the entire FM [Sung et al., 2022]. Recently, Offsite-Tuning [Xiao et al., 2023] was proposed to achieve fine-tuning without the full model. In this approach, the FM owner sends clients a lightweight adapter and an emulator constructed through layer-drop and knowledge distillation [Hinton et al., 2015]. Clients then fine-tune the adapter on their data with support from the emulator. The refined adapter is subsequently returned and plugged into the full model. Similarly, mini-model adaptation [Marchisio et al., 2023] constructs a shallow mini-model from a fraction of the FM's parameters. However, these methods either discard a significant number of intermediate layers of the FM or suffer from gradient error accumulation, resulting in suboptimal fine-tuning performance. Different from these methods, we construct sub-FMs via layer-wise compression and mitigate gradient error accumulation through a two-step distillation.

## 3 Methodology

### 3.1 Preliminary

**Federated Learning.** Given $N$ parties $P_i\ (i=1,\dots,N)$, each party holds data $D_i$. Let $\mathcal{L}(\cdot,\cdot)$ be the loss function. FL aims to train a machine learning model $\Theta$ using the dataset $D=\bigcup_{i=1}^{N} D_i$ under the coordination of a server $S$, while the raw data of all parties are never directly shared. Formally,

$$\Theta^{*}=\arg\min_{\Theta}\sum_{i=1}^{N}\frac{|D_i|}{|D|}\,\mathcal{L}(\Theta, D_i). \tag{1}$$

**Foundation Model Fine-Tuning.** Given a foundation model $\Theta=\{W_1, W_2, \dots, W_n\}$ and a downstream task dataset $D$, fine-tuning aims to obtain a new model $\Theta'=\{W'_1, W'_2, \dots, W'_n\}$, i.e.,

$$\Theta'=\Theta+\Delta\Theta,\quad \Delta\Theta=\arg\min_{\Delta\Theta}\mathcal{L}(\Theta+\Delta\Theta,\, D). \tag{2}$$
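To make the objective in Eq. (1) concrete, below is a minimal sketch, not the authors' released implementation, of a FedAvg-style server aggregation step in which each client's update is weighted by its data share $|D_i|/|D|$; the names `fedavg_aggregate`, `client_states`, and `client_sizes` are illustrative.

```python
import torch

def fedavg_aggregate(client_states, client_sizes):
    """Weighted average of client model states, per the FL objective in Eq. (1).

    client_states: list of state_dicts (parameter name -> tensor), one per client
    client_sizes:  list of local dataset sizes |D_i|
    """
    total = float(sum(client_sizes))
    global_state = {}
    for name in client_states[0]:
        # Weight each client's parameters by |D_i| / |D| and sum.
        global_state[name] = sum(
            (size / total) * state[name].float()
            for state, size in zip(client_states, client_sizes)
        )
    return global_state
```

In FedPFT, such an aggregation would be applied only to the trainable sub-FM parameters (e.g., the MHA weights) rather than to a full model.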
### 3.2 Problem Definition

For FM fine-tuning using a proxy sub-model, we first construct a sub-model $\hat{\Theta}=\{W_1, \dots, W_k, \hat{W}_{k+1}, \dots, \hat{W}_n\}$ with fewer parameters than $\Theta$ to act as a proxy. Second, we fine-tune the proxy sub-model $\hat{\Theta}$ on the dataset $D$. Finally, we synchronize the gradient updates on $\hat{\Theta}$ back to $\Theta$. Specifically, we construct $\hat{\Theta}$ by compressing $\Theta$ while retaining a portion of its parameter matrices. This compression process is formally described as

$$\hat{\Theta}=\Theta_1\cup C(\Theta_2), \tag{3}$$

where $\Theta_1\cup\Theta_2=\Theta$ and $C(\cdot)$ denotes the compression method. During the fine-tuning of $\hat{\Theta}$, we update only $\Theta_1$ and, after fine-tuning, synchronize the updated gradients on $\Theta_1$ into $\Theta$ to obtain $\Theta^{*}$, an approximation of the ideal fine-tuned model $\Theta'$:

$$\Theta^{*}=(\Theta_1+\Delta\Theta'_1)\cup\Theta_2,\quad \Delta\Theta'_1=\arg\min_{\Delta\Theta'_1}\mathcal{L}\big((\Theta_1+\Delta\Theta'_1)\cup C(\Theta_2),\, D\big). \tag{4}$$

### 3.3 Method Overview

The overall framework of FedPFT is shown in Fig. 2. We first derive a proxy sub-FM from the server FM, then collaboratively fine-tune the sub-FM through FL, and finally synchronize the updates on the sub-FM to the FM by plugging them in. FedPFT enhances the downstream-task adaptation of FMs through FL via two key modules: (1) the Sub-FM Construction Module, which constructs sub-FMs by performing layer-wise compression on FMs based on neuron saliency; and (2) the Sub-FM Alignment Module, which reduces the difference between FMs and sub-FMs by layer-level and neuron-level knowledge distillation before and during FL fine-tuning, respectively. We introduce these two modules in detail below.

*Figure 2: The overall framework of FedPFT, which enhances FM adaptation to downstream tasks through FL via two key modules: (1) the Sub-FM Construction Module, which constructs the sub-FM by layer-wise compression to facilitate comprehensive FM fine-tuning; and (2) the Sub-FM Alignment Module, which aligns the sub-FM by two-step distillation to ensure accurate alignment between sub-FM and FM with a theoretical guarantee.*

### 3.4 Sub-FM Construction Module Based on Layer-Wise Compression

Transformer-based FMs typically consist of three parts: an embedding layer, a task head, and a sequence of transformer layers. Since the size of an FM is dominated by its transformer layers, we compress each transformer layer. Each transformer layer contains two sub-layers, Multi-Head Attention (MHA) and Feed-Forward Network (FFN), each of which applies a residual connection followed by layer normalization. The output of MHA is

$$\mathrm{MHA}(x)=\mathrm{Concat}(\mathrm{Attn}_0(x),\dots,\mathrm{Attn}_h(x))\,W^{O},\quad \mathrm{Attn}(x)=\mathrm{softmax}\!\left(\frac{x W^{Q}(x W^{K})^{T}}{\sqrt{d_k}}\right)x W^{V}, \tag{5}$$

where $W^{Q}\in\mathbb{R}^{d_{model}\times d_k}$, $W^{K}\in\mathbb{R}^{d_{model}\times d_k}$, $W^{V}\in\mathbb{R}^{d_{model}\times d_k}$, and $W^{O}\in\mathbb{R}^{d_{model}\times d_{model}}$ are the weight matrices of query, key, value, and output in MHA, respectively; $h$ is the number of attention heads; and $d_k$ and $d_{model}$ are the dimensions of the key and the FM, respectively, with $d_{model}=d_k\cdot h$. The number of parameters of MHA is about $4d_{model}^2$.

The output of FFN is

$$\mathrm{FFN}(x)=\mathrm{gelu}(x W_1+b_1)\,W_2+b_2, \tag{6}$$

where $W_1\in\mathbb{R}^{d_{model}\times d_{ff}}$ and $W_2\in\mathbb{R}^{d_{ff}\times d_{model}}$ are the weight matrices of the two linear layers in FFN, $b_1\in\mathbb{R}^{d_{ff}}$ and $b_2\in\mathbb{R}^{d_{model}}$ are the biases, and $d_{ff}$ is the dimension of FFN, usually set to $4d_{model}$. The number of parameters of FFN is about $8d_{model}^2$. Evidently, most of the parameters in a transformer layer are contained in the FFN. Hence, we opt to compress the FFN rather than the MHA of each layer for sub-FM construction. This minimizes the number of parameters of the sub-FM while ensuring a consistent set of trainable parameters (i.e., MHA) between the FM and its sub-FM at each layer.

We accomplish layer-wise compression by removing neurons with low saliency from the FFN of each layer at a fixed ratio. First, by rewriting (6), we can represent the output of the FFN as the sum of the outputs of $d_{ff}$ neurons:

$$\mathrm{FFN}(x)=\sum_{i=1}^{d_{ff}}\mathrm{gelu}(x u_i+b_{1i})\,w_i+b_2, \tag{7}$$

where $u_i\in\mathbb{R}^{d_{model}}$ is the $i$th column vector of $W_1$, $w_i\in\mathbb{R}^{d_{model}}$ is the $i$th row vector of $W_2$, and $b_{1i}$ is the $i$th entry of $b_1$. Second, based on (7) and the magnitude-based pruning method [Wen et al., 2016], we use the L2-norm of all connection weights of a neuron to measure its saliency:

$$\mathrm{Saliency}(i)=\sum_{j=1}^{d_{model}}\left(w_{ij}^2+u_{ij}^2\right), \tag{8}$$

where $i$ is the index of a neuron in the FFN.
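As an illustration of Eqs. (7) and (8), the following PyTorch sketch scores each FFN neuron by the squared L2-norm of its input and output connections and keeps a fixed fraction. It is a simplified rendering under the shape conventions above, not the released FedPFT code, and `keep_ratio` is an assumed hyperparameter.

```python
import torch

def compress_ffn(W1, b1, W2, keep_ratio=0.25):
    """Layer-wise FFN compression by neuron saliency, per Eqs. (7)-(8).

    W1: (d_model, d_ff) first linear layer; column i feeds neuron i
    b1: (d_ff,)          first-layer bias
    W2: (d_ff, d_model)  second linear layer; row i is neuron i's output weights
    """
    # Saliency(i) = sum_j (w_ij^2 + u_ij^2): squared L2-norm of all
    # connection weights of neuron i (Eq. 8).
    saliency = (W2 ** 2).sum(dim=1) + (W1 ** 2).sum(dim=0)

    k = max(1, int(keep_ratio * W1.shape[1]))
    keep = torch.topk(saliency, k).indices.sort().values  # keep neuron order stable

    # Remove low-saliency neurons: slice columns of W1 and rows of W2.
    return W1[:, keep], b1[keep], W2[keep, :]
```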
Finally, we construct the sub-FM, serving as a proxy for the FM, by eliminating the neurons with low saliency in each layer at a fixed ratio.

### 3.5 Sub-FM Alignment Module Based on Two-Step Knowledge Distillation

Per the description of FM fine-tuning using a proxy sub-model in Section 3.2, the FM fine-tuning is entirely contingent on the gradient descent of its sub-FM. This methodology prompts a fundamental question: how can we ensure the convergence of the FM to the optimal value with the assistance of its sub-FM?

**Theorem 1.** *Suppose both the function $f:\mathbb{R}^n\to\mathbb{R}$ and its approximation $f':\mathbb{R}^n\to\mathbb{R}$ are convex and differentiable, and their gradients are Lipschitz continuous with constants $L_1>0$ and $L_2>0$, respectively, i.e., $\|\nabla f(x)-\nabla f(y)\|_2\le L_1\|x-y\|_2$ and $\|\nabla f'(x)-\nabla f'(y)\|_2\le L_2\|x-y\|_2$ for any $x, y$. If we run gradient descent for $k$ iterations on $f'$ with a fixed step size $\eta\le\frac{1}{L_1}$ and synchronize the gradients to $f$, then, letting $\nabla f'-\nabla f=\delta$, when*

$$\|\delta\|_2^2\le\|\nabla f\|_2^2, \tag{9}$$

$$\sum_{i=1}^{k}\big\|\delta^{(i)}\big\|_2^2\le\sum_{i=1}^{k}\big\langle\delta^{(i)},\,x^{(i)}-x^{*}\big\rangle \tag{10}$$

*are satisfied, it yields a solution $x^{(k)}$ that satisfies*

$$f(x^{(k)})-f(x^{*})\le\frac{\|x^{(0)}-x^{*}\|_2^2}{2\eta k}, \tag{11}$$

*where $f(x^{*})$ is the optimal value.*

*Proof.* See Appendix A of the extended version (https://arxiv.org/abs/2404.11536).

Intuitively, Theorem 1 indicates that when (9) and (10) are satisfied, gradient descent of the FM with the help of its sub-FM is guaranteed to converge, at rate $O(\frac{1}{k})$. Both conditions (9) and (10) constrain the difference between the actual and ideal update gradients of the FM, so minimizing this gradient difference becomes the problem to solve next.

**Theorem 2.** *For a transformer layer, let the number of attention heads be 1 and ignore the nonlinear functions and residual connections; its output can then be expressed as $y=x W^{Q}(x W^{K})^{T}x W^{V}W^{O}W_1W_2$. Let $A=W^{Q}(W^{K})^{T}$, $B=W^{V}W^{O}$, and $C=W_1W_2$; then $y=x A x^{T}x B C$, and the output of the corresponding sub-layer after compressing the FFN is $y'=x A x^{T}x B C'$. Assuming the gradient of the loss function $\mathrm{loss}=f(y)$ is Lipschitz continuous with constant $L_3>0$, and that $\|C-C'\|_2^2\le\epsilon_1$ and $\|y-y'\|_2^2\le\epsilon_2$, there exist constants $K_1>0$ and $K_2>0$ such that*

$$\|\Delta\nabla A\|_2^2\le K_1\epsilon_1+K_2\epsilon_2,\qquad \|\Delta\nabla B\|_2^2\le K_1\epsilon_1+K_2\epsilon_2, \tag{12}$$

*where $\Delta\nabla A$ and $\Delta\nabla B$ denote the differences between the gradients of $A$ and $B$ computed with the full layer and with the compressed layer.*

*Proof.* See Appendix B of the extended version (https://arxiv.org/abs/2404.11536).

From Theorem 2, it is evident that the gradient error can be shrunk by narrowing the differences in outputs and weights between the sub-FM and the FM.

Based on the above analysis, we grasp the importance of narrowing the difference between the sub-FM and the FM via knowledge distillation to boost the performance of an FM fine-tuned through its sub-FM. We therefore align the sub-FM using layer-level and neuron-level distillation in two phases, before and during FL fine-tuning, respectively. The two distillation processes are illustrated in Fig. 3.

*Figure 3: An example of the two distillation processes.*

**Layer-Level Distillation Before FL Fine-Tuning.** Given that our sub-FMs are constructed via layer-wise compression, where each layer retains a set of tunable parameters (i.e., MHA), we leverage the outputs of all layers to compute the layer-level distillation loss. Furthermore, based on Theorem 2, we enhance this distillation loss with a regularization term that quantifies the disparity between the weights of each FFN and its sub-FFN, further facilitating thorough knowledge transfer during fine-tuning by refining the alignment. Thus, the final distillation loss is

$$\mathcal{L}_{KD}=\frac{1}{L\,M_{KD}}\sum_{i=1}^{L}\sum_{j=1}^{M_{KD}}\Big(\big\|O_j^{(i)}-O_j'^{(i)}\big\|_2^2+\mu\big(\big\|W_1^{(i)}-W_1'^{(i)}\big\|_2^2+\big\|W_2^{(i)}-W_2'^{(i)}\big\|_2^2\big)\Big), \tag{13}$$

where $L$ is the number of layers, $M_{KD}$ is the size of the distillation dataset, $O_j^{(i)}$ is the output of the $j$th sample at the $i$th layer, $W_1^{(i)}$ and $W_2^{(i)}$ are the two weight matrices of the $i$th FFN, primes denote the sub-FM counterparts, and $\mu$ is the regularization coefficient.
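Below is a compact sketch of the distillation objective in Eq. (13), assuming the per-layer outputs of both models have already been collected on the distillation set and that the FM's FFN weights are pre-sliced to the retained neurons so their shapes match the sub-FM; all names are illustrative, not the released FedPFT code.

```python
import torch

def layer_level_kd_loss(fm_outputs, sub_outputs, fm_ffn_weights,
                        sub_ffn_weights, mu=0.1):
    """Layer-level distillation loss of Eq. (13).

    fm_outputs / sub_outputs: lists of (batch, seq, d_model) hidden states,
        one tensor per transformer layer, from the FM and the sub-FM.
    fm_ffn_weights / sub_ffn_weights: lists of (W1, W2) tuples per layer;
        FM weights are assumed pre-sliced to the retained neurons so that
        shapes match the sub-FM.
    """
    num_layers = len(fm_outputs)
    loss = 0.0
    for o, o_s, (w1, w2), (w1_s, w2_s) in zip(
            fm_outputs, sub_outputs, fm_ffn_weights, sub_ffn_weights):
        # Output-matching term, averaged over the distillation batch (M_KD).
        out_term = (o - o_s).pow(2).sum(dim=(-1, -2)).mean()
        # Weight-regularization term motivated by Theorem 2.
        reg_term = (w1 - w1_s).pow(2).sum() + (w2 - w2_s).pow(2).sum()
        loss = loss + out_term + mu * reg_term
    return loss / num_layers  # average over the L layers
```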
Thus, 1https://arxiv.org/abs/2404.11536 Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI-24) Figure 3: An example of two distillation processes the final distillation loss is denoted as: LKD = 1 LMKD j=1 O(i) j O + µ W (i) 1 W (i) 2 W (i) 2 2 2), (13) where L is the number of layers, MKD is the size of distill dataset, O(i) j is the output of the jth sample in the ith layer, W (i) 1 and W (i) 2 are the two weight matrices of the ith FFN, µ is the regularization coefficient. Neuron-Level Distillation During FL Fine-Tuning In addition, the absence of alignment between FM and its constructed sub-FM during FL fine-tuning may cause the actual gradient update direction of the FM to gradually deviate from its ideal direction. This deviation can substantially reduce the performance of the fine-tuned FM. To mitigate the problem, we re-align the sub-FM with the FM after FL aggregation in certain rounds. However, since the datasets for distillation and the datasets for local fine-tuning are typically collected from different domains, excessively aligning sub-FMs through distillation may hinder the adaptation of sub-FMs to downstream tasks. Inspired by [Mallya and Lazebnik, 2018], during the alignment process in FL fine-tuning, we opt to update only a subset of neurons with low saliency in local fine-tuning to prevent the risk of sub-FM forgetting knowledge of local data. Moreover, since all FFNs of sub-FM are not updated during FL fine-tuning, the effectiveness of magnitude-based neuron saliency measurement methods diminishes. To address this, we opt to select neurons for updating during alignment based on the Average Percentage of Zero activations (APo Z) of outputs on the downstream task dataset [Hu et al., 2016]. The APo Z(i) k of the kth neuron in ith layer is defined as: APo Z(i) k = 1 MDT S l=1 I(O(i) jkl = 0), (14) where MDT is the size of the downstream task dataset, S is the sequence length of the jth sample, O(i) jkl is the output of the lth token of the jth sample at the kth neuron in ith layer, I( ) is the indicator function. We calculate the APo Z for each neuron on the client using the local dataset before each round that requires alignment, and subsequently select the neurons that need to be updated during alignment based on their APo Z values. 3.6 Cost Analysis We perform a theoretical analysis of the computational and communication cost of Fed PFT based on BERT, and other models such as Ro BERTa and Vi T are similar. Following the settings in [Wang et al., 2023a], we assume that all FL clients have the same training settings and exclude external differences such as local dataset size and hardware environment. Computational Cost Given a BERT model, let V be the vocabulary size, S be the sequence length, L be the number of layers, and cf,cb be the number of forward propagation and backward propagation respectively. Based on the analysis in 3.4, the computational cost of a BERT model is O(dmodel(V + S) + L(4Sd2 model + S2dmodel) + 2LSdmodeldff + LSdmodel) where the four terms denote the cost of embedding, MHA, FFN and Add&Norm, respectively. Based on the above information, the overall time complexity of the full model is computed as follows. First, the time complexity of embedding is O(dmodel(V + S)). Second, due to dff = 4dmodel, the forward propagation time complexity is about O(cf L(12Sd2 model + S2dmodel)). Identically, the backward propagation time complexity is O(cb L(12Sd2 model + S2dmodel)). 
### 3.6 Cost Analysis

We perform a theoretical analysis of the computational and communication costs of FedPFT based on BERT; other models such as RoBERTa and ViT are analogous. Following the settings in [Wang et al., 2023a], we assume that all FL clients use the same training settings, and we exclude external differences such as local dataset size and hardware environment.

**Computational Cost.** Given a BERT model, let $V$ be the vocabulary size, $S$ the sequence length, $L$ the number of layers, and $c_f, c_b$ the numbers of forward and backward propagations, respectively. Based on the analysis in Section 3.4, the computational cost of a BERT model is

$$O\big(d_{model}(V+S)+L(4Sd_{model}^2+S^2d_{model})+2LSd_{model}d_{ff}+LSd_{model}\big),$$

where the four terms denote the costs of the embedding, MHA, FFN, and Add&Norm, respectively. The overall time complexity of the full model is then computed as follows. First, the time complexity of the embedding is $O(d_{model}(V+S))$. Second, since $d_{ff}=4d_{model}$, the forward-propagation time complexity is about $O(c_f L(12Sd_{model}^2+S^2d_{model}))$; identically, the backward-propagation time complexity is $O(c_b L(12Sd_{model}^2+S^2d_{model}))$. Finally, the overall time cost is $O(d_{model}(V+S)+L(c_f+c_b)(12Sd_{model}^2+S^2d_{model}))$, and we have $d_{model}(V+S)\ll L(c_f+c_b)(12Sd_{model}^2+S^2d_{model})$ and $S$
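To make these magnitudes concrete, a quick back-of-the-envelope check under assumed BERT-base dimensions ($d_{model}=768$, $L=12$, $V=30522$, $S=128$) shows the embedding term is orders of magnitude smaller than a single pass through the transformer layers, consistent with $d_{model}(V+S)\ll L(c_f+c_b)(12Sd_{model}^2+S^2d_{model})$:

```python
# Back-of-the-envelope cost comparison for BERT-base (assumed dimensions).
d_model, L, V, S = 768, 12, 30522, 128

embedding_cost = d_model * (V + S)                          # embedding term
per_pass_cost = L * (12 * S * d_model**2 + S**2 * d_model)  # MHA + FFN term

print(f"embedding : {embedding_cost:.3e}")   # ~2.4e7
print(f"per pass  : {per_pass_cost:.3e}")    # ~1.1e10
print(f"ratio     : {per_pass_cost / embedding_cost:.0f}x")
```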