Thinking Forward: Memory-Efficient Federated Finetuning of Language Models
Spry is a federated learning algorithm that enables finetuning LLMs using Forward-mode Auto-Differentiation to achieve a low memory footprint, high accuracy, and fast convergence.
Abstract
Reviews and Discussion
This paper introduces SPRY, a federated learning (FL) algorithm designed to finetune large language models (LLMs) on resource-constrained devices by addressing the excessive memory requirements of traditional backpropagation methods. SPRY tackles the challenge of high memory usage from intermediate activations by utilizing Forward-mode Auto-Differentiation (AD) and splitting trainable weights among participating clients, allowing each client to compute accurate gradients with reduced memory. Theoretical analysis shows SPRY's global gradients are unbiased for homogeneous data and provides a convergence rate dependent on FL rounds and data heterogeneity. Empirical results demonstrate SPRY's efficiency, reducing memory usage by 1.4-7.1× and achieving faster convergence and higher accuracy compared to existing methods. This makes feasible the finetuning of LLMs on mobile and edge devices, significantly impacting FL deployments.
Strengths
- S.1. The proposed SPRY algorithm tackles a difficult task and is backed up with both empirical and theoretical results.
- S.2. The paper is well written and the illustrations are helpful.
- S.3. The empirical experiments are conducted on multiple datasets, models, and hardware configurations, while showing that SPRY outperforms previous works.
- S.4. The paper provides an anonymous code repository of SPRY.
Weaknesses
- W.1. While SPRY shows promising results with a relatively small number of clients, the scalability to a larger number of clients, which is typical in FL scenarios, is not thoroughly investigated. The communication and computational overheads associated with increasing the number of participating clients, especially in terms of synchronization and gradient aggregation, are not discussed. This raises concerns about the practicality and efficiency of SPRY in large-scale deployments.
- W.2. The empirical evaluation, although thorough in certain aspects, is limited in terms of the variety of datasets and models used. The evaluation focuses primarily on specific language tasks and a narrow range of LLMs. A broader evaluation including more diverse datasets and model architectures would strengthen the generalizability of the findings. Additionally, the impact of SPRY on tasks beyond natural language processing, such as computer vision or other modalities, is not explored.
- W.3. While SPRY does outperform existing zero-order methods, it converges more slowly (time-wise) than traditional gradient-based methods such as FedAvg.
Questions
- Q.1. Will the SPRY algorithm work in a large scale setting of 1k+ GPUs?
Limitations
n/a
W1. Scalability to a larger number of clients
Please refer to our answer to reviewer KHem under “W2. Impact of the number of clients on performance”.
W1. Communication and computational overheads
For communication and computational overheads, please see “1. Communication overhead” and “2. Computation costs” in the global “Author Rebuttal”.
To summarize,
(a) As the participating client count increases, both the per-epoch and per-iteration modes of Spry have a lower communication cost than the classic backpropagation-based FL algorithm FedAvg and other finite-difference-based baselines. In FedAvg, the server-to-client communication cost scales linearly with the number of participating clients for the entire global model. In contrast, Spry's communication cost scales only with max(layer count, number of participating clients), and only for a subset of the global model.
(b) Spry is more computationally efficient than FedAvg on the server side, since Spry aggregates only a subset of model weights updated by each client, instead of aggregating all model weights from all the clients.
We next discuss in detail the server-side computation costs related to gradient aggregation and synchronization in the following two parts:
• Synchronization
A server in Spry needs to assign layers randomly to participating clients for each round, and it needs to keep that mapping of layers to client ids for aggregation. As shown in Algorithm 1 (Appendix E), this mapping is not a bottleneck for the server, since it is a simple loop that maps layer ids to client ids (see lines 14 to 22 of Algorithm 1).
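For illustration, a minimal Python sketch of such an assignment loop (illustrative names only, not the exact implementation of Algorithm 1):

```python
import random

def assign_layers_to_clients(layer_ids, client_ids, seed=0):
    """Randomly map trainable layers to the clients participating in a round.

    Returns a dict client_id -> list of layer_ids that client will train.
    When there are at least as many clients as layers, each client trains a
    single layer (several clients may share the same layer); otherwise each
    client trains roughly ceil(#layers / #clients) layers.
    """
    rng = random.Random(seed)
    shuffled_layers = layer_ids[:]
    rng.shuffle(shuffled_layers)

    mapping = {cid: [] for cid in client_ids}
    if len(client_ids) >= len(shuffled_layers):
        # One layer per client; layers are reused across clients.
        for i, cid in enumerate(client_ids):
            mapping[cid].append(shuffled_layers[i % len(shuffled_layers)])
    else:
        # Spread all layers across the available clients.
        for i, lid in enumerate(shuffled_layers):
            mapping[client_ids[i % len(client_ids)]].append(lid)
    return mapping

# Example: 4 trainable layers, 6 participating clients in this round.
print(assign_layers_to_clients([0, 1, 2, 3], list(range(6))))
```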
Moreover, for the per-iteration communication frequency, the server incurs an additional cost of $O(L \cdot s)$, where $L$ is the layer count and $s$ is the size of a layer (we have assumed the same layer size for each layer for ease of exposition). As shown in Table 2 of the PDF attached to the global "Author Rebuttal", this cost is also a factor for per-iteration finite-difference methods. For Spry, the additional cost comes from having to generate the random perturbations at the server side and multiply them with the jvp scalar values received from the clients.
• Gradient Aggregation
The only major difference between Spry and backpropagation-based FedAvg on the server side is that with Spry, instead of aggregating all layer weights from all clients, the server only needs to aggregate each layer's weights from the clients assigned to that layer. Hence Spry has a lower computation cost than FedAvg, since it does not have to aggregate all layer weights from all clients.
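A minimal illustrative sketch of this aggregation step (illustrative names; per-layer averaging over only the clients assigned to that layer):

```python
import torch

def aggregate_layer_updates(layer_to_clients, client_updates):
    """Average each layer's weights over only the clients assigned to it.

    layer_to_clients: dict layer_id -> list of client_ids assigned to that layer.
    client_updates:   dict client_id -> dict layer_id -> updated weight tensor.
    Returns a dict layer_id -> aggregated weight tensor.
    """
    aggregated = {}
    for lid, cids in layer_to_clients.items():
        updates = [client_updates[cid][lid] for cid in cids]
        aggregated[lid] = torch.stack(updates).mean(dim=0)
    return aggregated
```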
W2. Generalizability of Spry
Spry focuses on fine-tuning language models of sizes between 17.9M and 13B parameters, mainly because of their popularity and practicality on edge devices. This goal and the scope of the experiments align well with recent efforts published at NeurIPS and other top venues [1, 2, 3].
Nonetheless, we value the reviewer’s feedback and aim to do more evaluation on vision and a wider variety of language tasks in the future.
[1] MeZO: Fine-Tuning Language Models with Just Forward Passes (Malladi et al., NeurIPS 2024)
[2] Distributed inference and fine-tuning of large language models over the internet (Borzunov et al., NeurIPS 2023)
[3] Pockengine: Sparse and efficient fine-tuning in a pocket (Zhu et al., MICRO 2023)
W3. Slow convergence of Spry compared to FedAvg
We acknowledge that Spry generally converges more slowly than FedAvg. However, we would like to emphasize that Spry's much-reduced peak memory consumption makes language model finetuning feasible on devices where FedAvg is infeasible.
We must note that a faster gradient calculation method has no utility for fine-tuning or training large models on edge devices if its peak memory footprint exceeds the available memory of the device. This is the trade-off between runtime and peak memory consumption in Spry: it sacrifices time to convergence by 2.65x to gain a 27.90% to 86.26% memory reduction, while staying within 0.6% to 6.2% of the accuracy of the best-performing backpropagation-based FL method.
As a side note, we also reiterate that Spry gains a speedup of 1.14x over FedAvg for a medium-sized language model, RoBERTa Large. This speedup is attributed to the fact that clients in Spry only need to train a subset of layers, unlike FedAvg, where every client trains all the layers.
Q1. Large scale setting of 1k+ GPUs
It is currently computationally prohibitive for us to simulate federated training with 1k participating clients: it takes ~4h to simulate and execute 1k clients for one round, so running 500 such rounds would take ~83 days. However, given our experiments on 1k total clients (see Table 1 of Spry: rows of AG News, SNLI, MNLI, Yahoo, Yelp) and ablation studies ("Effects of participating client count" in Section 5.4 and Appendix F) on increasing the participating client ratio, we anticipate Spry achieving better performance as the participating client count scales up. Our expectation is based on the observation that having more clients train the same model weights decreases the forward gradient's approximation error, leading to better prediction performance.
Hence, the layer-to-client mapping is not a bottleneck at the server for Spry, and gradient aggregation and synchronization in Spry can be faster than in methods that let all the clients compute all gradients. We observe that Spry can be more scalable than backpropagation- and finite-difference-based methods, while being 1.46x to 28.57x faster than finite-difference methods and taking 27.90% to 86.26% less memory than backpropagation-based methods, making it a feasible and scalable way to train larger models on edge devices.
Thank you for the detailed answers.
The provided results and details address some of my concerns; however, I find the empirical results somewhat limited in their variety, and therefore I'll be keeping my score of slightly leaning towards accept.
We are glad to know that we were able to address the reviewer's concerns on the scalability and computation/communication overhead of our work. We will work towards incorporating all the discussed points into our manuscript.
This manuscript is focused on the memory-efficient federated finetuning of LLMs. The authors first use Forward-mode Auto-Differentiation to reduce memory. They then observe that merely substituting backpropagation with Forward-mode AD in FL scenarios often results in poor accuracy and computational inefficiency. To address this challenge, the authors recognize that Forward-mode AD operates more efficiently and yields better gradient estimations when the trainable weight count is minimized. Therefore, SPRY assigns each client the responsibility of computing gradients for only a subset of trainable weights. The experimental results show that the memory overhead can be reduced significantly. The topic is of interest and the presented numerical results seem, indeed, promising. However, there are still some questions/comments/suggestions for the current version of the paper; please refer to my comments under Questions.
Strengths
- The paper performs a number of experimental verifications as well as theoretical proofs.
- The experimental results show that the method can significantly reduce GPU memory consumption while maintaining competitive performance.
- Additionally, the appendix contains valuable results.
Weaknesses
- This manuscript does not analyze the computational load of Forward-mode AD.
- There are still some issues with the experimental setup, such as the need to consider the impact of the number of clients on performance.
Questions
- What is the computational overhead situation in each step (e.g., the training time for each gradient step)? A comparison of computational overheads is suggested in the experimental section.
- Is each client's trained layer the same from start to finish? Or are the layers reassigned each communication round?
- It seems that the layers for each client are fixed at the beginning. If we consider the dynamic distribution of layers, it might alleviate the issues caused by heterogeneous devices.
- Communication overhead is an important metric in federated systems, and it is suggested that the authors give the communication data size in each communication round.
Limitations
The authors give a limitation analysis in the checklist section.
W1. Computational load of Forward-mode AD and
Q1. Computational overhead
The compute cost of each client and the server is given in "2. Computation costs" under the global "Author Rebuttal". We also state the per-iteration time cost here, followed by an analysis of the results:
The computational overhead (in seconds) is measured on an Nvidia 1080 Ti for RoBERTa Large and on an RTX 8000 for Llama 2, with a fixed batch size of 8:
| Method | Time per iteration (in seconds) for RoBERTa Large with LoRA on AGNews | Time per iteration (in seconds) for Quantized Llama 2 7B with LoRA on MultiRC |
|---|---|---|
| Backpropagation [First-order] (used in FedAvg, FedYogi, FedSGD) | 0.1207 | 0.0167 |
| Zero-order Finite Difference [Zero-order] (used in FedMeZO) | 0.0843 | 0.0115 |
| Zero-order Finite Difference [Zero-order] (used in FwdLLM+) | 0.7593 | 0.2383 |
| Zero-order Finite Difference [Zero-order] (used in Baffle+) | 2.683 | 1.3175 |
| Forward-mode AD for all weights [First-order] (used in FGD) | 0.3215 | 0.1237 |
| Forward-mode AD for 1/3rd of the weights [First-order] (used in Spry) | 0.1036 | 0.0209 |
To summarize, while backpropagation is faster than forward-mode AD due to PyTorch's column-wise jvp computation, forward-mode AD proves more feasible for edge devices due to its superior reduction in memory footprint. Finite differences are more unstable, so finite-difference-based methods have to evaluate 20 to 100 perturbations per iteration, while still achieving sub-optimal performance compared to Spry due to numerically unstable gradients. The explanation is as follows:
Compared to backpropagation-based methods and FedMeZO, forward-mode AD is slower than backpropagation due to how the gradients are currently computed (discussed in Section 5.3 of our work, and further elaborated in Section 3.1 of [1]): forward-mode AD computes the jvp column-wise, while its counterpart vjp in backpropagation is computed row-wise. The column-wise computation incurs a time overhead when the number of trainable parameters is much larger than the output size (the loss is a scalar, hence the output size is 1), which is the case for neural networks.
Although Spry's forward-mode AD is slower than backpropagation and FedMeZO's finite difference, its chief advantages lie in its much-reduced peak memory consumption (shown in Section 5.2) with comparable time to convergence (shown in Section 5.3). Spry consumes 27.90% to 86.26% less peak memory than backpropagation because it does not need to store the activations of the entire forward pass, and it is 1.34-2.98x faster than FedMeZO to reach convergence due to better gradient approximations.
Compared to the zero-order finite-difference methods FwdLLM and Baffle, forward-mode AD is faster per iteration. FwdLLM uses a variance-control mechanism to pick the perturbations that have smaller variance with respect to the previous round's aggregated gradient. This adds an overhead to FwdLLM in each iteration to sample and pick the appropriate perturbations.
The zero-order finite-difference gradients derived in Baffle and FwdLLM are numerically unstable due to the accumulation of truncation and round-off errors, so these methods have to average the gradients derived across 20 to 100 finite-difference evaluations per iteration.
Compared to 20 to 100 finite-difference evaluations per iteration, forward-mode AD can obtain a better approximation of the true gradient with a single forward pass per iteration, given that it only has to perturb and train a fraction of the trainable weights of a large neural network.
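As a concrete illustration of a single forward-gradient step, a minimal sketch using `torch.func.jvp` on a toy linear model (illustrative shapes and learning rate; in Spry each client would do this only for its assigned layers, on LoRA adapters, which this sketch omits):

```python
import torch
from torch.func import functional_call, jvp

model = torch.nn.Linear(16, 1)
x, y = torch.randn(8, 16), torch.randn(8, 1)
params = dict(model.named_parameters())

def loss_fn(p):
    pred = functional_call(model, p, (x,))
    return torch.nn.functional.mse_loss(pred, y)

# Sample a random perturbation (tangent) for the trainable parameters.
tangents = {k: torch.randn_like(v) for k, v in params.items()}

# A single forward pass yields both the loss and the directional derivative
# (jvp value), without storing activations for a backward pass.
loss, jvp_value = jvp(loss_fn, (params,), (tangents,))

# Forward gradient: scale the perturbation by the scalar jvp value.
forward_grad = {k: jvp_value * v for k, v in tangents.items()}
with torch.no_grad():
    for k, p in params.items():
        p -= 1e-3 * forward_grad[k]
```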
[1] Automatic Differentiation in Machine Learning: a Survey (Baydin et al, JMLR 2018)
W2. Impact of the number of clients on performance
Results on scaling up participating client count are discussed in Section 5.4, under “Effects of participating client count” (more details available in Appendix F “Effects of the Number of Participating Clients per Round”).
The experiment involves raising the participating client count from 10 (participation ratio of 0.1) to 100 (participation ratio of 1.0), for a total of 100 clients on the SST2 dataset. We observe that as the participation ratio increases from 0.1 to 1.0, accuracy increases by 2.94%. Furthermore, to achieve an accuracy of ~85%, Spry needs only 150 FL rounds for participation ratio 1.0, compared to 500 FL rounds with participation ratio of 0.1.
This shows that as more clients train the same model parameters through Forward-mode AD, the gradient estimates are closer approximations of the true gradient, leading to a higher accuracy and faster convergence of Spry, which is corroborated in Theorem 4.1.
In fact, as the number of participating clients increases, not only would there be more clients training each layer (improving accuracy, as discussed in the paragraph above), but each client would also have to train fewer layers (reducing computation and communication costs for all clients).
Q2 and Q3. Layer assignment to the participating clients
Each client’s trained layer is not the same from start to finish. At the beginning of a communication round, trainable layers are randomly assigned to the participating clients of that particular round (Line 6 of Algo 1, Appendix E).
In our cross-device FL setting, we do not make any assumptions on which clients will be participating for a given round, and for our main experiments (in Table 1 of Spry), the participation ratio (#participating clients/#total clients) is 0.01 to 0.1. Therefore, it’s unlikely that a client will get randomly picked for subsequent rounds, and it’s also unlikely that the same client will be randomly assigned to the same layers for multiple rounds.
Hence, as the reviewer has suggested, we are indeed doing layer assignment dynamically.
Q4. Communication overhead
Please see our reply under “1. Communication overhead” in the global “Author Rebuttal”.
Thank you for your responses. I have increased the score from 5 to 7.
We are pleased that our responses have addressed the reviewer’s concerns. We will incorporate their valuable insights to enhance our manuscript further.
We greatly appreciate the reviewer’s feedback, which is instrumental in improving our work.
This paper introduces a forward-mode AD federated learning algorithm (SPRY). They use SPRY to finetune LLMs and demonstrate a low memory footprint compared to backpropagation-based federated learning algorithms. The authors also derive SPRY’s convergence rate and provide theory behind why SPRY’s global gradients are unbiased for homogeneous data distributions across clients. The authors empirically evaluate SPRY on 8 language tasks as well as perform ablation studies.
Strengths
- The empirical evaluation on eight language tasks while testing multiple different language models is a key strength of the paper.
- The paper is organized clearly with nice figures and a clear structure to the sections. The empirical evaluation section is particularly well structured.
- The ablation studies further strengthen the results of the paper.
- The section on peak memory consumption is clear and highlights the superior memory performance of both zero-order and first-order forward-mode AD compared to backpropagation.
- Splitting forward-mode across multiple layers seems a novel way to perform federated learning.
Weaknesses
- The baseline FWDLLM is described as a zero-order-based approach but upon reading this paper, it seems like it uses a first-order forward mode AD update rule in the federated setting (https://arxiv.org/pdf/2308.13894). Why does the paper categorize this approach as a zero-order-based method? Also, Figure 10 in that paper shows a better performance of all the baselines on the different data sets. Why is there such a significant discrepancy? E.g. RoBERTa-large gets a performance of much greater than 80 % for FwdLLM in the original paper for AG News and is over 60 % for Yahoo.
- Equation (4) is challenging to interpret. Why is the summation over classes? What do the subscripts in square brackets mean? Does the paper define data distribution only according to how class labels are distributed among the clients? This did not appear obvious in the paper.
- The implications of Theorem 4.2 are not clear. Is this bound related to the variance of the gradient estimator? Generally, the variance is related to the dimensionality of the parameter space.
Questions
- Are the authors aware of AsyncFGD (https://openreview.net/pdf?id=45RBLZBJid)? They also implement FGD in parallel by splitting weights across workers. They do not reach the model and experiment size of this paper, but there does appear to be some similarity in the approach. Would it be possible to identify these differences?
- A general limitation of FGD as a learning method is the variance of the gradient estimator with increasing dimension of the parameter space. How does SPRY overcome this challenge? Is there a theory that updating individual layers on different clients reduces this variance?
- For per-iteration communication, is there any advantage to having more clients, since the server needs to generate all the tangent vectors and then update all the weights of the model? Since there is no matrix multiplication, perhaps there is no need for a GPU on the server side? Also, generally, what would the communication and compute cost be for the server?
- What is the current state-of-the-art in federated learning?
Limitations
Authors have done a good job at describing limitations and memory usage of their approach.
W1. Comparison to FwdLLM
FwdLLM[1] shows the equation of the finite difference (which involves only function evaluations and no gradient computation) in their paper's Eq. 1, and their experiment scripts[2] refer to the use of finite differences as well. Hence, we categorized FwdLLM as a zero-order method.
[1]FwdLLM: Efficient FedLLM using Forward Gradient (ATC 24)
[2]FwdLLM Code: https://tinyurl.com/2cwuhjfu
The results differ from our Table 1 due to variations in participating and total client numbers in each FL round. FwdLLM[1] originally used 100 clients per round, but we used 10 clients due to limited compute resources. Previous studies[3, 4] have also used 10 clients per round. This results in an accuracy gap of approximately 3.06-8.17% for AGNews, Yelp, and Yahoo datasets.
[3]Ditto: Fair and Robust Federated Learning Through Personalization (ICML 21)
[4]Federated Optimization in Heterogeneous Networks (MLSys 20)
W2. Interpretation of Eq4
In Eq. 4 we quantify the gradient estimation bias caused by heterogeneous data across clients using the Dirichlet distribution, a popular way in FL to simulate heterogeneity in classification. It divides the data among clients using class proportions drawn from $\mathrm{Dir}(\alpha)$; a higher sampled proportion for class $c$ indicates more samples of that class on a client.
Summation over classes
The bias of the gradient estimation is defined as the difference of gradient expectations over the global data randomness. This difference can be computed per class and summed over all classes to get the total difference, as shown in Eqs. 17 and 18 of the proof of Theorem H.4 (Appendix H). That is why we sum over classes.
Subscripts in square brackets
The subscripts in square brackets under the model parameters and the perturbations indicate the parameters a client uses to perturb and fine-tune: a weight perturbation $v_{[i]}$ is used to fine-tune the corresponding weight subset $w_{[i]}$.
Defining data distributions
While our homogeneous-data experiments are not limited to classification, it is common in FL works to define heterogeneity by a Dirichlet distribution on classification tasks[5, 6]. The Dirichlet distribution simulates real-world data distributions, since a client usually has data distributed disproportionately across classes.
Hence, to express the bias of the gradient estimations for the theoretical analysis, we used the Dirichlet distribution. More details are given in the proof of Theorem H.4.
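For illustration, a minimal sketch of a Dirichlet-based label partition (an illustrative helper, not our exact preprocessing code; alpha controls the heterogeneity level):

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with per-class proportions ~ Dir(alpha).

    Smaller alpha -> more skewed class proportions per client (more heterogeneity).
    Returns a list of index arrays, one per client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx_c = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        splits = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for client_id, part in enumerate(np.split(idx_c, splits)):
            client_indices[client_id].extend(part.tolist())
    return [np.array(ci) for ci in client_indices]

# Example: 1000 samples, 4 classes, 10 clients, alpha = 0.5 (fairly skewed).
labels = np.random.default_rng(0).integers(0, 4, size=1000)
print([len(p) for p in dirichlet_partition(labels, num_clients=10, alpha=0.5)])
```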
[5]Federated Learning Based on Dynamic Regularization (ICLR 21)
[6]Personalized Federated Learning under Mixture of Distributions (ICML 23)
W3. Implication of Theorem 4.2
Thm 4.2 derives Spry's error bounds based on: a) total communication rounds, b) perturbation size for forward-mode AD, c) perturbation count per iteration, d) data heterogeneity, and e) the number of clients training a subset of weights. The standard form of the error bound in the case of a non-convex objective is the L2 norm of the gradient [7, 8].
We show that increasing the number of training rounds or the number of clients training the same parameters decreases the convergence error, while decreasing the trainable parameter count or the data heterogeneity also reduces the convergence error.
[7]SCAFFOLD: Stochastic Controlled Averaging for Federated Learning (ICML 20) [8]Adaptive Federated Optimization (ICLR 21)
Q1. Comparison to AsyncFGD
AsyncFGD[9] targets efficient resource utilization for gradient descent with gradients derived by forward-mode AD (also called "forward gradient descent") by parallelizing jvp computations across iterations on a single device. We point out the major differences as follows:
Spry aims to improve memory efficiency in fine-tuning large models through forward-mode AD, while AsyncFGD aims to enhance resource utilization by parallelizing jvp computations across multiple iterations.
Limited model sizes for AsyncFGD
AsyncFGD is designed for a single client and has limited experiments for models with 13M parameters. Applying AsyncFGD to larger models like RoBERTa Large (355M params) for a single client would fail, similar to the failures of FGD in preliminary experiments (Sec 5.4 and Adx F on “The Importance of Splitting Layers”).
AsyncFGD doesn't address high-dimensional parameter space issues, while Spry does by splitting the space across multiple clients. AsyncFGD's resource utilization can be applied to each client in Spry, making them orthogonal.
[9]Accelerated On-Device Forward Neural Network Training with Module-Wise Descending Asynchronism (NeurIPS 2023)
Q2. Variance of the gradient estimator
It is true that the variance of estimated gradients increases with the parameter dimension. This observation led us to the development of Spry, which lets each participating client train only a small set of parameters, reducing the impact of the gradient estimation variance across FL rounds.
Thm 4.2 shows that increasing the number of clients training the same layers and reducing the trainable parameter count both decrease the convergence error: the impact of the global gradient's bounded variance is reduced by a smaller perturbation size and by a larger number of clients training the same layer.
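As a toy numerical illustration of this effect (an illustrative sketch, not an experiment from the paper): averaging forward gradients over more independent tangents, analogous to more clients training the same weights, drives the estimate towards the true gradient.

```python
import torch

torch.manual_seed(0)
dim = 1000
true_grad = torch.randn(dim)

def forward_grad_estimate(num_tangents):
    """Average of num_tangents single-sample forward gradients (v^T g) v."""
    est = torch.zeros(dim)
    for _ in range(num_tangents):
        v = torch.randn(dim)
        est += (v @ true_grad) * v
    return est / num_tangents

for k in (1, 10, 100):
    err = (forward_grad_estimate(k) - true_grad).norm() / true_grad.norm()
    print(f"{k:3d} tangents -> relative error {err:.2f}")
```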
Q3a. Per-iteration communication advantages
Spry's per-iteration communication variant requires clients to derive jvp values from local data, while the server generates the tangent vectors as random tensors from a normal distribution. Once the clients provide the jvp values, the server multiplies their average with the tangent vector. The reviewer is right that there is no need for a GPU on the server side; Spry does not assume one.
Q3b. Server communication and compute cost
Please refer to the global “Author Rebuttal”.
Q4. Current state-of-the-art in FL
Adx A provides an overview of recent methods for fine-tuning large language models in federated settings, as this direction gains popularity. Due to space limits, we refer to recent surveys[10, 11] which cover the progress made on solving FL challenges.
[10] Recent advances on federated learning: A systematic survey (Neurocomp 24) [11] Federated Learning for Generalization, Robustness, Fairness: A Survey and Benchmark (TPAMI 24)
The rebuttal helped remove some of the weaknesses so I will raise my score.
We thank the reviewer for taking the time to provide helpful feedback and for the thoughtful reconsideration of their assessment. We will be sure to incorporate the feedback into the manuscript.
We thank the reviewers for their suggestion on adding information on communication and computation costs, and we will update the manuscript with detailed explanations of the following:
1. Communication overhead
Table 1 of the PDF attached to this response illustrates the communication costs of Spry and its backpropagation- and finite-difference-based baselines. A discussion of the communication modes of Spry is also given in Section 3.2, "Per-Epoch Communication" and "Per-Iteration Communication". Here we discuss the costs related to those communication modes:
• Per-epoch Communication
Spry's client-to-server communication cost does not scale linearly with the number of clients like its backpropagation and finite-difference counterparts, but instead decreases or stays constant as more clients participate. The server-to-client communication cost is lower in Spry because the server sends only one layer per client when the participating client count is at least the layer count ($|\mathcal{C}| \ge L$), or $\lceil L/|\mathcal{C}| \rceil$ layers per client otherwise. This result follows from the observations below:
Backpropagation-based and finite-difference-based methods have a communication cost of $O(|\mathcal{C}| \cdot |w|)$, where $|w|$ represents the global model size. Each client in $\mathcal{C}$ (the set of participating clients) receives all trainable parameters from the server, requiring the server to send a total of $|\mathcal{C}| \cdot |w|$ parameters each round.
Spry's communication cost per epoch is $O(\max(L, |\mathcal{C}|) \cdot s)$, where $L$ is the layer count and $s$ is the count of parameters in each layer. Each client sends a subset of trainable parameters, incurring a communication cost of $\lceil L/|\mathcal{C}| \rceil \cdot s$ parameters when $|\mathcal{C}| < L$, and $s$ when $|\mathcal{C}| \ge L$. When $|\mathcal{C}| \ge L$, each client gets 1 layer, hence the communication for each client is $s$.
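To make this concrete, a small back-of-the-envelope sketch with assumed, illustrative values (100 participating clients, 24 trainable layers of 0.5M parameters each; these numbers are illustrative, not from the paper):

```python
# Assumed, illustrative values -- not numbers from the paper.
num_clients = 100                      # |C|, participating clients per round
num_layers = 24                        # L, trainable layers
layer_size = 500_000                   # s, parameters per layer
model_size = num_layers * layer_size   # |w|

# FedAvg / finite-difference baselines: every client receives the full model.
fedavg_server_to_clients = num_clients * model_size

# Spry: each client receives ceil(L / |C|) layers (1 layer when |C| >= L).
layers_per_client = max(1, -(-num_layers // num_clients))  # ceiling division
spry_server_to_clients = num_clients * layers_per_client * layer_size

print(f"FedAvg: {fedavg_server_to_clients:,} parameters sent per round")
print(f"Spry:   {spry_server_to_clients:,} parameters sent per round")
```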
• Per-iteration Communication
Spry accrues lower communication cost than the finite difference and backpropagation counterparts due to the layer splitting strategy, and the server’s ability to compute gradients based on the jvp value. This is because:
The communication cost from a client to the server for forward-mode AD and finite differences is $O(1)$. This is because an FL round involves: 1. the server selecting a random seed, 2. the server sending it along with the trainable parameters to the clients, 3. the clients generating perturbations based on the seed, 4. the clients deriving and sending back a jvp scalar or finite-difference scalar to the server, and 5. the server computing gradients by regenerating the perturbations from the seed and multiplying them with the received scalars.
The server-to-client communication is $\lceil L/|\mathcal{C}| \rceil \cdot s + 1$ per client, where the "+1" is due to the randomness seed.
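A minimal illustrative sketch of this seed-based exchange for one layer (illustrative helper names; compute_jvp stands in for the client's local forward pass that produces the jvp scalar on its own data):

```python
import torch

def client_jvp_step(layer_params, seed, compute_jvp):
    """Client side: regenerate the perturbation from the shared seed and
    return only the resulting jvp scalar to the server."""
    gen = torch.Generator().manual_seed(seed)
    tangent = {k: torch.randn(v.shape, generator=gen) for k, v in layer_params.items()}
    return compute_jvp(layer_params, tangent)  # a single scalar

def server_apply_update(layer_params, seed, jvp_scalars, lr=1e-3):
    """Server side: regenerate the same perturbation from the seed and scale it
    by the average of the jvp scalars received from the assigned clients."""
    gen = torch.Generator().manual_seed(seed)
    tangent = {k: torch.randn(v.shape, generator=gen) for k, v in layer_params.items()}
    avg_jvp = sum(jvp_scalars) / len(jvp_scalars)
    return {k: v - lr * avg_jvp * tangent[k] for k, v in layer_params.items()}
```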
2. Computation costs
Table 2 of the PDF attached to this response shows the computation costs of Spry and its baselines, where the client-side cost is per iteration and the server-side cost is per round.
Briefly, Spry's higher client-side computation cost is offset by faster convergence to a higher accuracy through better gradient approximations compared to finite-difference-based methods. Furthermore, Spry is the least computationally expensive on the server side, since it needs to aggregate fewer parameters from the clients.
Table 2 assumes that a matrix multiplication costs $M$ for each layer, resulting in a forward pass cost of $L \cdot M$. The cost of backpropagation is $2L \cdot M$, because the computation of the current layer's weight gradient is $M$, and the cost of computing the previous layer's activation gradient is another $M$. The jvp computation in Spry takes an additional cost of $M$ for each layer. Moreover, since the jvp calculation happens through column-by-column vector multiplications (Section 3.1 of [1]), the related overhead is quantified by a factor $\tau$.
• Client-side per-iteration computation cost
Backpropagation needs 3 matrix multiplication operations per layer. Zero-order methods need 2 matrix multiplications per layer (incurred by the two forward passes) per perturbation within a training iteration, plus an additional overhead for perturbation generation. MeZO requires generating the perturbations three times from the same seed (Algorithm 1 in MeZO [2]).
Spry's per-client computation cost scales with the number of layers assigned to that client. Since Spry allocates at most $\lceil L/|\mathcal{C}| \rceil$ layers to each client, the computation cost scales only with $\lceil L/|\mathcal{C}| \rceil$, against its counterparts scaling with $L$. However, forward-mode AD computes the jvp column-wise, while its counterpart vjp in backpropagation is computed row-wise. This results in the time overhead $\tau$ when the number of trainable parameters exceeds the output size (1, since the loss is a scalar), which is the case for neural networks. Therefore, Spry's per-iteration computation cost is higher compared to other approaches.
Note that the per-iteration computation cost of Spry is not the whole picture. It takes fewer communication rounds to reach higher accuracy due to better gradient approximation of forward-mode AD than finite difference methods. This is why "Time to Convergence" (Section 5.3) discusses a fair comparison of Spry's runtime and prediction performance.
• Server-side per-round computation cost
On the server side, Spry is the least computationally demanding. Spry needs to aggregate a subset of layer weights from only the clients that were assigned to those layers, while its counterparts need to aggregate all layers from all clients.
The computation cost on the server side changes based on the communication frequency. Per-iteration communication incurs an additional overhead of $O(L \cdot s)$ for Spry and $O(\nu \cdot L \cdot s)$ for its zero-order counterparts (where $\nu$ is the number of perturbations per iteration), due to the generation of perturbations at the server side and the multiplication of those perturbations with the aggregate of the jvp values received from the clients.
[1] Automatic Differentiation in Machine Learning: a Survey (Baydin et al, JMLR 2018)
[2] MeZO: Fine-Tuning Language Models with Just Forward Passes (Malladi et al., NeurIPS 2024)
Fine-tuning large language models in federated learning settings has become increasingly important, as it allows resource-constrained devices to fine-tune a model using private data. However, fine-tuning LLMs using backpropagation requires excessive memory (especially for intermediate activations) on resource-constrained devices. While Forward-mode Auto-Differentiation can significantly reduce the memory footprint from activations, the authors observe that directly applying it to LLM fine-tuning results in slower convergence and lower accuracy. They introduce SPRY, which reduces the memory footprint and speeds up convergence, and combine it with LoRA, leading to a more scalable FL training algorithm.
Generally, the reviewers were impressed by the numerical results and found the convergence analysis interesting. Furthermore, the paper studies an important topic. I thus recommend accepting this paper.