Efficient Hybrid Long Sequence Modeling with State Space Augmented Transformers
We propose SPADE, an efficient SSM-attention hybrid framework for long sequence modeling.
Abstract
Reviews and Discussion
This work aims to reduce the quadratic computational cost of Transformer models by replacing the first layer with a hybrid layer that combines a state space model with a local attention module. In all other layers, the full attention module is replaced by a local attention module. Experiments are conducted on different tasks using different architectures and model sizes. Results show consistent improvements over different baselines.
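To make the layout described above concrete, here is a minimal PyTorch-style sketch (not the authors' implementation): the SSM is left as a generic stub where SPADE would use an S4-style layer; multi-head attention, normalization, and feed-forward blocks are omitted; and merging the two branches by summation is an assumption made purely for illustration.

```python
import torch
import torch.nn as nn


def chunked_local_attention(q, k, v, window):
    """Attend only within non-overlapping chunks of size `window` (O(L * window))."""
    b, l, d = q.shape
    assert l % window == 0, "pad so the length is a multiple of the window size"
    shape = (b, l // window, window, d)
    q, k, v = q.reshape(shape), k.reshape(shape), v.reshape(shape)
    scores = torch.einsum("bnqd,bnkd->bnqk", q, k) / d ** 0.5
    return torch.einsum("bnqk,bnkd->bnqd", scores.softmax(-1), v).reshape(b, l, d)


class LocalAttentionLayer(nn.Module):
    """Upper layers: local attention only (single head, no norm/FFN, for brevity)."""

    def __init__(self, d_model, window):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.window = window

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return x + self.out(chunked_local_attention(q, k, v, self.window))


class HybridBottomLayer(nn.Module):
    """Bottom layer: a global SSM branch combined with a local-attention branch."""

    def __init__(self, d_model, window, ssm):
        super().__init__()
        self.ssm = ssm  # any (B, L, D) -> (B, L, D) module; an S4 block in SPADE
        self.local = LocalAttentionLayer(d_model, window)

    def forward(self, x):
        return self.local(x) + self.ssm(x)  # merge-by-sum is illustrative only


class HybridEncoder(nn.Module):
    """One hybrid (global) layer at the bottom, local-attention layers above."""

    def __init__(self, d_model=256, window=128, n_layers=6):
        super().__init__()
        ssm_stub = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())  # placeholder SSM
        self.layers = nn.ModuleList(
            [HybridBottomLayer(d_model, window, ssm_stub)]
            + [LocalAttentionLayer(d_model, window) for _ in range(n_layers - 1)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x = torch.randn(2, 1024, 256)    # (batch, length, d_model); length % window == 0
print(HybridEncoder()(x).shape)  # torch.Size([2, 1024, 256])
```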
Reasons to Accept
- Simple and effective method
- Experimental results on different tasks and model architectures
Reasons to Reject
- Missing implementation details for the encoder-decoder architecture
- No larger scale experiments (but claiming to “demonstrate the scalability of the proposed method”)
- Related work in the appendix
Questions to Authors
- While I appreciate the different model architectures, it would have been interesting to see results with a larger decoder-only model and/or results with SPADE_{large,xl}++
- MEGA also works for the machine translation task; since you trained an encoder-decoder model with your approach, results on machine translation would be interesting, too; in particular, since Vardasbi et al. (EAMT 2023) reported mixed results when employing S4 on this task
- Please move the related work section into the main paper
Thank you for the valuable feedback. We address the questions as follows:
- The cross-attention is implemented in the conventional way, i.e., no local attention. This is because the output length is usually small for natural language generation tasks. Therefore, the full cross-attention does not introduce much computation latency.
- Encoder-decoder models are widely used in LLMs, such as T5, FlanT5, and ChatGLM. We leave the exploration of decoder-only models as future work.
- Thank you for the suggestion. We remark that S4+full-attention is not significantly better than S4+local-attention. This is because the SSM layer can already extract global information, so we only need fine-grained local information on top of the SSM layer for good performance.
| | ListOps | Text | Retrieval |
|---|---|---|---|
| S4+Full | 60.85 | 90.91 | 91.37 |
| S4+chunk | 60.50 | 90.69 | 91.17 |

- We believe that the current extensive experimental results are sufficient to demonstrate the effectiveness of the proposed method. We show results on long sequence modeling, language modeling, natural language understanding, and natural language generation in the experiments. In all tasks, the proposed model outperforms the respective baselines.
- We make this design choice because local information is an effective supplement to the global information extracted by the SSM. Put differently, SPADE is a local-attention Transformer model, except that we add global information to the bottom layer.
- We put the related work section in the appendix because of space constraints. We will move it to the main text.
Thanks for your rebuttal!
- Cross-Attention: Can you be more specific about "output length is usually small" and "does not introduce much computation latency"?
- "S4+full-attention is not significantly better than S4+local-attention": Please provide details of your applied significance test.
- An ablation of the positions and number of different attention modules is still needed to verify the design choice.
Thank you for the suggestions. We would like to address the following:
- The computation cost of the cross-attention layer depends on both the input (source) sequence length and the output (target) sequence length, i.e., the attention matrix is not square. Because the output sequence length is usually short compared with the input sequence length, the computation cost of the cross-attention is only a fraction of that of the input self-attention. For example, the input length can be up to 16k in our experiments, while the output sequence length is only 128. That is, for the input self-attention the attention matrix has size (16k * 16k), while for the cross-attention it is only (16k * 128). A small cost calculation is sketched after this list.
- Here are some results for S4+Full-attention. We can see that the performance of S4+Full-attention is only slightly better than that of S4+Local-attention, while the computation cost is much higher.
| | ListOps | Text | Retrieval |
|---|---|---|---|
| S4+Full | 60.75 | 90.81 | 91.20 |
| S4+chunk | 60.50 | 90.69 | 91.17 |

- The ablation experiments on the position and number of attention modules are in Section 6.3 (Figure 6). By design, we consider a fixed-size (local attention) Transformer model and add SSM layers to the model. In Section 6.3, we illustrate why SSM layers have to be added to the bottom of the model, and we also demonstrate that only one SSM layer suffices to extract useful global information.
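Both cost arguments above can be checked with a few lines of arithmetic. The sketch below only counts entries of the attention score matrices as a rough proxy for compute and memory; the 16k/128 sequence lengths come from the reply above, while the local window size of 128 is an illustrative assumption rather than the paper's exact setting.

```python
def score_matrix_entries(q_len, kv_len, window=None):
    """Entries in the attention score matrix; set `window` for chunked local attention."""
    if window is None:
        return q_len * kv_len                        # full attention: q_len x kv_len
    assert q_len == kv_len and q_len % window == 0   # chunked self-attention assumed
    return (q_len // window) * window * window       # = q_len * window

src_len, tgt_len, w = 16_384, 128, 128

full_self = score_matrix_entries(src_len, src_len)             # 16k x 16k self-attention
local_self = score_matrix_entries(src_len, src_len, window=w)  # chunked self-attention
cross = score_matrix_entries(tgt_len, src_len)                 # 128 x 16k cross-attention

print(f"full self-attention : {full_self:>12,}")   # 268,435,456
print(f"local self-attention: {local_self:>12,}")  #   2,097,152 (128x fewer entries)
print(f"full cross-attention: {cross:>12,}")       #   2,097,152 (~0.8% of full self-attention)
```

So even without chunking, the cross-attention accounts for a small fraction of the total attention cost, which matches the authors' argument for keeping it in its conventional form.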
Thanks for the clarification and for pointing me again to Section 6.3, I will raise my score to 6!
This paper focuses on improving the computational efficiency of Transformers. Designing an equally expressive model with much more efficient handling of long context can have great practical impact.
Specifically, the authors propose to employ a state-space model at the bottom layer of a Transformer to capture long-range and positional information, coupled with efficient local attention for the rest of the layers. The method is well motivated and clearly described; its main novelty lies in designing a hybrid model that addresses the limitations of individual Transformers and SSMs.
The evaluation is comprehensive, covering a diverse set of tasks including long-range tasks, language modeling, and language model pretraining followed by finetuning on NLU and NLG tasks. Comparisons against strong baselines such as S4 and MEGA demonstrate the effectiveness of the approach. One caveat is that the experimental comparison with the different baseline methods is not conducted under the exact same experimental setup and resources (i.e., not head-to-head).
Overall, the paper is clearly written and easy to follow. It is also positioned well with respect to prior work and provides sufficient details and artifacts for reproducibility that can enable future work.
Reasons to Accept
- It proposes a simple architecture that combines state-space models with local attention to build a hybrid model that addresses the limitations of individual Transformers and SSMs.
- The proposed method outperforms competitive baselines, including efficient architectures such as S4, MEGA, and Hyena, on a diverse set of tasks.
- Apart from the promising empirical results, it also provides insights and justifications regarding the design choices through analysis.
Reasons to Reject
- A direct head-to-head comparison with other efficient architectures is missing. This means running the baseline models under a comparable configuration instead of comparing against the scores reported in the corresponding papers.
- The positioning with respect to prior work needs more effort, especially relative to methods that combine global and local information.
Questions to Authors
- Is it possible to convert an existing Transformer model into a SPADE architecture instead of training it from scratch? That would be useful for applying the method to already-trained models.
Thank you for the valuable feedback. We address the questions as follows:
- All results in Table 2, Table 3, and Table 4 are based on our own implementation instead of directly citing results from other papers. For example, in Table 3 we cite the results from the original paper and also report the results of our re-implementation (which are better than the results in the original paper).
- There are several concurrent works [1, 2] that share a similar philosophy with SPADE. We will add the associated references and discuss them in the next version.
- It is an interesting direction to explore whether we can convert a trained model into a SPADE model. We leave this direction as future work.
[1] https://arxiv.org/abs/2402.18668
[2] https://arxiv.org/abs/2402.19427
Thank you for the replies and the additional effort. Overall, I am positive about the work, but I still have a concern regarding the comparison to other models.
Please explicitly describe the re-implementation and the effort involved in verifying its correctness, because this is not clear in the paper. It would also be useful to describe the model differences in detail and to provide information showing that the experimental setups are comparable in terms of hyper-parameter budget and model size.
Note that a "head-to-head" comparison refers to using the exact same architecture and varying only a few components, e.g., the global attention layer (this can be softmax attention, RNN attention, or any other efficient attention from the literature).
In this paper, the authors propose SPADE, which aims to combine SSMs with local attention to achieve a good trade-off between long-context understanding (capturing global information) and computational complexity. Comprehensive experiments are conducted to show the effectiveness of the method.
Reasons to Accept
- The paper is well motivated and easy to follow and understand.
- SPADE seems quite strong compared to other similar approaches on standard benchmarks, including understanding and generation tasks.
Reasons to Reject
- Can the authors say more about the philosophy behind this architecture design? For example, why not add further global layers after some local layers, e.g., global1 -> local1 -> global2 -> local2?
- Can the authors list the number of parameters and the latency for the models listed in Table 1?
Thank you for the valuable feedback. We address the questions as follows:
- There are many design choices available for combining local and global information. Our intuition is the following: local information extractors such as local attention are more efficient than global information extractors such as SSMs and full attention. Therefore, we want as few global information extractors as possible (a toy cost comparison is sketched after this list). In Section 6.3 and Figure 6, we justify that SPADE is an effective design. We leave the exploration of other design choices as future work.
- The model details are in Appendix C, and latency comparisons are shown in Figure 5. We will add the number of parameters and detailed latency numbers.
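To illustrate the intuition in the first answer, here is a toy cost model (ours, not the paper's): each local-attention layer produces roughly L * w score entries, while each global layer is counted as L * L, as for full attention; an actual SSM scales more favorably than L * L, but it remains the more expensive component in this accounting. The layer count, sequence length, and window size below are illustrative values.

```python
def stack_score_entries(n_layers, n_global, seq_len, window):
    """Rough cost proxy: local layers ~ L*w entries each, global layers ~ L*L each."""
    local_part = (n_layers - n_global) * seq_len * window
    global_part = n_global * seq_len * seq_len
    return local_part + global_part

L, w, N = 8_192, 128, 12
print(stack_score_entries(N, 1, L, w))   # ~7.9e7: SPADE-like, one global layer at the bottom
print(stack_score_entries(N, 2, L, w))   # ~1.4e8: global1 -> local -> global2 -> local -> ...
print(stack_score_entries(N, N, L, w))   # ~8.1e8: every layer global (vanilla Transformer)
```

Under this crude accounting, each additional global layer adds the dominant term, which is consistent with the authors' preference for keeping a single global (SSM) layer.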
Thanks for your reply. Overall, I think this is good work and will keep my score unchanged.
This paper proposes SPADE, a hybrid architecture that consists of a global SSM layer followed by local attention layers. A variety of experiments are conducted with the architecture on long range tasks as well as standard language modeling tasks with the goal of demonstrating better efficiency with similar or better performance as compared to either SSMs or Transformers alone.
Reasons to Accept
- The paper approaches the important problem of how to gain the benefits of both Transformers and SSMs, and the proposed approach of combining global SSM layers with local attention layers makes sense at a high level.
- The experiments generally show strong performance for the proposed method as compared to the baselines presented in the paper, especially for the long range tasks considered.
Reasons to Reject
- There is an issue with the discussion of recent related work. Both Based [1] and Griffin [2] appeared on arXiv with similar ideas before the COLM deadline. They are close enough that they may be counted as concurrent work, but they should definitely be addressed. There is also no reference to or comparison with more recent state space models from the last year, such as Mamba [3]; all the SSMs referenced are a few years old.
- I am also worried about the choice and strength of the baselines. First, there is no comparison to a FlashAttention-based Transformer baseline, which resolves many of the efficiency issues of Transformers, especially at the <10k sequence lengths considered in the paper. I am not totally convinced that the proposed architecture would give substantial efficiency gains at these lengths and model sizes. Second, especially for the language modeling tasks, there is a lack of recent baselines.
- I also would have liked to see experiments with decoder-only causal models. Can the authors elaborate on why they chose to focus on encoder-decoder models?
[1] https://arxiv.org/abs/2402.18668
[2] https://arxiv.org/abs/2402.19427
Thank you for pointing out the references; we will add them to our paper. Regarding the proposed baselines, we address the following:
- We would like to highlight that, as stated, [1] and [2] should be treated as concurrent work rather than baselines. We will acknowledge these papers and add the references to our paper. For [1], we note that, based on Table 1 in [1], the method performs worse than the standard Transformer by a significant margin. From Table 2 in our paper, SPADE has on-par or even better performance compared with the standard Transformer.
- For SSM baselines, we demonstrate in Table 2 that SPADE is a general framework for combining SSMs with attention. Existing SSMs benefit from the proposed framework. We will update the paper and add results with Mamba [3].
- We highlight that even with FlashAttention, full attention is still significantly slower than local attention, especially when the local window size is small (a rough timing sketch is included after the references below). We will update the paper and add results for an efficiency comparison with FlashAttention.
- Encoder-decoder models are widely used in LLMs, such as T5, FlanT5, and ChatGLM. We leave the exploration of decoder-only models as future work.
[1] https://arxiv.org/abs/2402.18668
[2] https://arxiv.org/abs/2402.19427
[3] https://arxiv.org/abs/2312.00752
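As a rough way to sanity-check the FlashAttention point above, the sketch below times PyTorch's F.scaled_dot_product_attention (which dispatches to a fused FlashAttention-style kernel when one is available) on the full sequence versus on non-overlapping chunks. Everything here is an illustrative assumption on our part: the sizes, the chunking scheme (which simply ignores cross-chunk tokens), and the absence of the SSM component; actual speedups depend on hardware and on which kernel gets selected.

```python
import time
import torch
import torch.nn.functional as F

B, H, L, D, W = 1, 8, 4096, 64, 128     # batch, heads, sequence length, head dim, window

def timed(fn, *args, n_runs=5):
    fn(*args)                           # warm-up (add torch.cuda.synchronize() on GPU)
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(*args)
    return (time.perf_counter() - start) / n_runs

def full_attention(q, k, v):
    # one long sequence: the score computation scales as O(L^2)
    return F.scaled_dot_product_attention(q, k, v)

def chunked_local_attention(q, k, v):
    # treat each non-overlapping chunk of W tokens as its own batch element,
    # so the kernel sees many short sequences instead of one long one: O(L * W)
    n = L // W
    def split(t):
        return t.reshape(B, H, n, W, D).permute(0, 2, 1, 3, 4).reshape(B * n, H, W, D)
    out = F.scaled_dot_product_attention(split(q), split(k), split(v))
    return out.reshape(B, n, H, W, D).permute(0, 2, 1, 3, 4).reshape(B, H, L, D)

q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
print("full attention :", timed(full_attention, q, k, v))
print("local attention:", timed(chunked_local_attention, q, k, v))
```

This only compares the attention kernels themselves; it says nothing about the cost of the SSM layer or about model quality.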
Thanks for the response and for adding the related work. I will raise my score to a 6 to reflect this change. Since the other comparisons have not been completed yet, I will not raise further for now.
The paper proposes a model architecture that consists of a global SSM-based layer followed by local layers using local attention (and MLPs) to model long sequences. The resulting architecture is evaluated on Long Range Arena (LRA), language modeling on WikiText-103, and an encoder-decoder model for natural language understanding and generation. The model shows better performance than Transformer and S4 on LRA, and slightly outperforms the standard Transformer on WikiText-103 and T5 on GLUE.
- The model architecture is simple and the presentation is clear so it is easy to understand.
- Strong performance on LRA.
- Thorough analysis and ablations to justify the model architecture design.
- The language modeling evaluation on WikiText-103 and GLUE uses relatively small datasets, where regularization may play a large role, compared to the large datasets currently used to train LLMs.

Overall, the paper is a valuable contribution to the design of hybrid models incorporating local attention and SSMs.