PaperHub
Rating: 6.0/10 · Poster · 4 reviewers
Min 5, Max 7, Std. Dev. 0.7
Individual ratings: 7, 6, 6, 5
Confidence: 3.8
COLM 2025

Multi-Token Attention

Submitted: 2025-03-21 · Updated: 2025-08-26
TL;DR

We present Multi-Token Attention: a new method that allows LLMs to condition their attention weights on multiple query and key vectors simultaneously.

Abstract

Keywords
Deep learning architectures, Large Language Model (LLM), Transformer, Attention

Reviews and Discussion

Official Review
Rating: 7

The paper introduces multi-token attention (MTA) as a replacement for the single-token attention of the original Transformer. The approach applies convolution operations over queries, keys, and heads. Across various evaluations, MTA outperforms the standard Transformer and the DIFF Transformer on language modeling (PPL), on popular benchmarks in zero-shot settings, and on tasks involving long-range dependencies. An ablation study highlights that group normalization is an important component of the approach.

Reasons to Accept

  1. The concept of replacing single-token attention with multi-token attention is compelling to me, and the idea is simple and novel.
  2. The evaluation covers various tasks, including language modeling (PPL), popular benchmarks in zero-shot settings, and tasks involving long-range dependencies.

Reasons to Reject

  1. The improvement over the Transformer baseline is negligible. For instance, in Table 2, the average PPL decreases from 11.25 to 11.09, representing an improvement of less than 2%. Similarly, the average score on standard benchmarks in Table 3 increases from 43.7 to 44.4, which is also an improvement of less than 2%.
  2. The experiments are conducted solely on 880M-parameter models, so it remains uncertain whether the method would be effective on larger models.
  3. There is no comparison of training and inference costs with the baselines.

Questions for Authors

Please refer to the reasons to reject.

Comment
  • The improvement over the Transformer baseline is negligible. For instance, in Table 2, the average PPL decreases from 11.25 to 11.09, representing an improvement of less than 2%. Similarly, the average score on standard benchmarks in Table 3 increases from 43.7 to 44.4, which is also an improvement of less than 2%.

We agree that these improvements are small in terms of PPL, but they are still larger than those of the other baselines. We also don’t expect a significant boost on standard benchmarks, as they are relatively simple and short-context, so they do not reflect the harder challenges of long-context or more demanding tasks. Our main improvements are on long-range dependency tasks, where the model has to find relevant information in the middle of a long context. There our architecture brings significant improvements (e.g., a 23% perplexity drop on Lambada standard, from 17.6 to 13.6, or a 110% improvement on the 4k-context 6-needle task, from 31.9 to 67.0).
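
The relative changes quoted in this exchange can be reproduced with a few lines of arithmetic. The sketch below only recomputes percentages from the numbers stated in the review and in the reply above (listed as baseline first, MTA second); it adds no new results.

```python
def relative_change(before: float, after: float) -> float:
    """Signed relative change from `before` to `after`, in percent."""
    return 100.0 * (after - before) / before


# Numbers as quoted in the review and the reply: (Transformer baseline, MTA).
quoted = {
    "average PPL, Table 2 (lower is better)": (11.25, 11.09),
    "average benchmark score, Table 3": (43.7, 44.4),
    "Lambada standard PPL (lower is better)": (17.6, 13.6),
    "4k-context 6-needle accuracy": (31.9, 67.0),
}

for name, (before, after) in quoted.items():
    print(f"{name}: {before} -> {after} ({relative_change(before, after):+.1f}%)")
```

For perplexity a negative change is the improvement, so the four lines come out at about -1.4%, +1.6%, -22.7% and +110.0%, consistent with the "less than 2%", "23%" and "110%" figures in the text.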

  • The experiments are conducted solely on 880M-size models, so it remains uncertain whether the methods would be effective on larger models.

To address this issue, we have now performed "scaling law" experiments on two smaller models (300M and 550M) and a larger 1.4B model. Perplexity evaluations are provided here: https://anonymous.4open.science/r/projects-862C/scaling_laws.pdf . We observe consistent performance improvements across all model sizes. Since we cannot afford training larger models from scratch, we have also included preliminary experiments on continued training of Llama models with 1B, 3B, and 8B parameters: https://anonymous.4open.science/r/projects-862C/finetuning.pdf

  • There is no comparison of training and inference costs with the baselines.

Training efficiency is reported in the Appendix: Table 9 gives the number of parameters, and Table 10 gives memory and runtime. For inference, we only add a single convolution operation costing O(seq_len x kernel_size) per token, which is smaller than the O(seq_len x head_size) cost of the attention operation itself. In addition, we showed that MTA can be added to only 1/4 of the layers, which significantly reduces the added computation.
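
To make the per-token inference argument concrete, the sketch below counts multiply-adds for one decoding step. All sizes (cached sequence length, head dimension, kernel widths c_q and c_k, layer counts) are hypothetical placeholders, not the paper's configuration, and only multiply-adds are counted; softmax, memory traffic, and kernel efficiency are ignored.

```python
from dataclasses import dataclass


@dataclass
class DecodeConfig:
    # Hypothetical sizes, chosen only to illustrate the scaling argument.
    seq_len: int = 4096   # tokens already in the KV cache
    head_dim: int = 128
    c_q: int = 6          # query-axis kernel width
    c_k: int = 11         # key-axis kernel width
    n_layers: int = 24
    mta_layers: int = 6   # MTA in 1/4 of the layers


def attention_madds_per_head(cfg: DecodeConfig) -> int:
    # New query against every cached key (q·k), then the weighted sum of values.
    return 2 * cfg.seq_len * cfg.head_dim


def mta_conv_madds_per_head(cfg: DecodeConfig) -> int:
    # Each logit in the new query row mixes a c_q x c_k window of logits.
    return cfg.seq_len * cfg.c_q * cfg.c_k


def extra_fraction(cfg: DecodeConfig) -> float:
    """Added multiply-adds relative to plain attention, summed over all layers."""
    base = cfg.n_layers * attention_madds_per_head(cfg)
    extra = cfg.mta_layers * mta_conv_madds_per_head(cfg)
    return extra / base


cfg = DecodeConfig()
print(f"per MTA layer: {mta_conv_madds_per_head(cfg) / attention_madds_per_head(cfg):.3f}")
print(f"whole model:   {extra_fraction(cfg):.3f}")
```

With these placeholder values the convolution adds about 26% of the attention multiply-adds in each MTA layer, and about 6% over the whole model when only 1/4 of the layers use MTA; the ratio does not depend on seq_len because both terms are linear in it. One cost the count leaves out: because the window also spans the c_q - 1 previous query rows, a decoder would need to cache (or recompute) those rows' logits.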

Comment

The supplemental experiments provided by the authors have alleviated my concerns, so I'm willing to raise my score from 6 to 7.

Official Review
Rating: 6

This paper proposes Multi-Token Attention (MTA), a modification to the standard transformer attention mechanism that allows attention weights to condition on multiple query and key vectors simultaneously rather than just single token pairs. The authors achieve this by applying convolution operations over queries, keys, and heads, allowing nearby elements to influence each other's attention weights. They test MTA on toy tasks, standard language modeling (pretraining an 880M-parameter model on SlimPajama), and long-context tasks like Needle-in-a-Haystack and BabiLong, where they find improvements over standard transformers. The key innovation is allowing the attention mechanism to leverage richer information beyond what a single query-key vector pair can provide.
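
As a rough illustration of the key-query convolution described above, the pure-Python sketch below computes attention weights for a single head in which each pre-softmax logit is replaced by a learned weighted sum over a c_q × c_k window of neighboring logits (c_q past query rows, c_k keys centered on the current key). It assumes a causal decoder, zeroes future logits before the convolution and masks them with -inf after it, and leaves out head mixing and group normalization. The function and argument names are hypothetical; this is not the authors' implementation.

```python
import math

NEG_INF = float("-inf")


def mta_key_query_attention(q, k, kernel):
    """Causal attention weights for one head with a pre-softmax key-query convolution.

    q, k:   T vectors each (lists of floats), one per position.
    kernel: c_q x c_k nested list. Row a reaches a steps back along the query axis;
            column b reaches key offset b - c_k // 2 (a window centered on the key).
    """
    T, d = len(q), len(q[0])
    c_q, c_k = len(kernel), len(kernel[0])
    # Scaled dot-product logits; future positions are set to 0 so that they add
    # nothing when the convolution window overlaps them.
    logits = [
        [sum(x * y for x, y in zip(q[i], k[j])) / math.sqrt(d) if j <= i else 0.0
         for j in range(T)]
        for i in range(T)
    ]
    weights = []
    for i in range(T):
        row = []
        for j in range(T):
            if j > i:
                row.append(NEG_INF)  # causal mask applied again after the convolution
                continue
            acc = 0.0
            for a in range(c_q):
                for b in range(c_k):
                    ii, jj = i - a, j - (b - c_k // 2)
                    if ii >= 0 and 0 <= jj < T:
                        acc += kernel[a][b] * logits[ii][jj]
            row.append(acc)
        # Numerically stable softmax over the row; exp(-inf) evaluates to 0.0.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
kernel = [[0.1, 0.2, 0.3], [0.0, 0.5, 0.0]]  # hypothetical c_q = 2, c_k = 3
for row in mta_key_query_attention(q, k, kernel):
    print([round(w, 3) for w in row])
```

With a kernel that has a single 1.0 at the current query row and key offset 0 (for example `[[0.0, 1.0, 0.0]]`), the function reduces to standard causal softmax attention. The nested loops cost O(T^2 · c_q · c_k) and are meant only to show the indexing.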

Reasons to Accept

  • The core idea is quite interesting and addresses a fundamental limitation in transformer attention. The authors clearly identify that standard attention can struggle when "looking up a sentence that mentions both 'Alice' and 'rabbit'" as it "requires the query vector to encode both tokens" (lines 33-35). Their convolution approach seems like an elegant way to add this capability.
  • The motivating toy task effectively demonstrates MTA's benefits, showing how it can easily solve a problem where standard transformers struggle. The task requires finding blocks containing specific letters - something standard attention has trouble with, but MTA solves with "near zero error rate" (lines 170-172).
  • The ablation studies on different MTA components provide good insights. Varying the number of layers with key-query convolution (Fig. 4) shows that even adding just a few MTA layers brings significant improvement. The visualizations of kernel patterns in the appendix are interesting.
  • The paper has real improvements on needle-in-haystack and BabiLong tasks that explicitly test long-range dependency capabilities, which is a clear application area where this approach makes sense.

Reasons to Reject

  • The model size used in experiments (880M) is quite small by current standards. Given that MTA is meant to add expressivity that might otherwise require "increased dimension, and for the model to use lots of its capacity" (lines 37-38), it would be more compelling to see if the benefits persist in larger models (e.g., 7B+) where capacity constraints might be less of an issue.
  • The paper doesn't fine-tune any existing pre-trained LLMs with MTA, making it hard to assess how easily this could be integrated into current workflows. As the authors themselves note in the limitations section, "Multi-Token Attention method is not currently compatible with the popular optimized attention kernels" (lines 428-429), which raises questions about practical adoption.
  • The computational overhead seems significant. Table 10 indicates MTA uses nearly 3.5x more memory than the baseline transformer and achieves only about 20% of the training throughput. The authors don't adequately address whether this tradeoff is worth it - is a 0.16 perplexity improvement (11.25 to 11.09 in Table 2) worth this much additional compute?
  • The paper lacks some important baselines, like sparse attention methods or sliding window approaches that might address similar problems with less overhead. While they compare to DIFF Transformer, more comparisons to other long-context solutions would strengthen the paper.

Questions for Authors

  • Could you clarify how MTA would be integrated into models already trained with standard attention? Would this require complete retraining, or could it be added through some form of adaptation?

Comment
  • The model size used in experiments (880M) is quite small by current standards...

Our pre-training requires an extremely large amount of compute even for an 880M model, because of the many ablations and comparison experiments we ran to better understand the method. To address the scalability question, we have performed a series of experiments on models of different sizes and confirmed that the perplexity improvements persist (https://anonymous.4open.science/r/projects-862C/scaling_laws.pdf ). We were not able to scale up further given time and compute constraints, but we believe that our scaling laws confirm the generalizability of our approach.

  • The paper doesn't fine-tune any existing pre-trained LLMs with MTA, making it hard to assess how easily this could be integrated into current workflows...

This is a great question! To address it, we have continued training our 1.4B models, as well as Llama 3 models with 1B, 3B, and 8B parameters, with and without MTA, on 10B and 5B tokens respectively. Since these experiments were performed in a short period of time, they are not optimized in terms of hyperparameters, but all were trained in the same setup (i.e. we would expect longer training to give more improvement for the Transformer + MTA architecture than for continued training of models without architectural changes). We observe that MTA can be integrated into existing architectures and improves performance on validation sets. These results will be included in the appendix (https://anonymous.4open.science/r/projects-862C/finetuning.pdf). For practical adoption of pre-training with MTA at even larger scales, a specialized kernel will be critical for speedup. However, implementing such a kernel is a complex engineering challenge that we leave for future work. We don’t foresee any fundamental reason why a specialized CUDA kernel cannot be implemented to speed up MTA. The convolution operation in MTA is highly parallelizable, and especially efficient given our small kernel sizes.

  • The computational overhead seems significant…

First, we would like to note that we are comparing a naive implementation of MTA with a kernel-optimized implementation of multi-head attention. This comes back to efficient kernels, which are not yet implemented for MTA. With efficient kernels specialized for MTA, this computational gap will shrink significantly, but we leave this to future work. The goal of this paper was to expose a critical limitation of the attention mechanism and show a simple solution for it, backed by experimental results. Second, perplexity alone is not sufficient to measure model performance: while the improvement in validation perplexity is small, gains on long-context retrieval tasks are quite significant (e.g. ~2x accuracy improvement on multi-needle retrieval tasks).

  • The paper lacks some important baselines, like sparse attention methods or sliding window approaches ...

The focus of our paper is the single-vector bottleneck of attention, which exists in standard attention as well as in sparse and sliding-window attention. Sparse attention mechanisms typically focus on reducing computational cost by skipping the computation of some attention weights. Similarly, sliding-window attention restricts the attention window for improved computational efficiency. Both methods work by masking out some attention weights, which does not address the fundamental issue of attention weights conditioning only on the similarity of single vectors. For this reason, we do not use them as baselines and instead focus on a similar method that augments attention to condition on richer information, namely the Diff Transformer.

  • Could you clarify how MTA would be integrated…

Great question! We have now performed additional experiments showing that MTA can be added as additional layers to already trained models, whose weights can then be updated with continued training. Since the main component of MTA is a convolution, we can initialize it to identity and insert it into an existing Transformer layer. Such a modification does not change the output of the layer, so when we start training the model, it should maintain its performance. However, as the added convolution starts deviating from the identity, it allows the Transformer to condition its attention on multiple keys and queries. With our continued-training experiments, we show that MTA can be integrated into models already trained with standard attention without complete retraining: https://anonymous.4open.science/r/projects-862C/finetuning.pdf. We observe consistent improvements in validation perplexity across all models. Since these experiments were performed for the rebuttal, there are further directions to explore, but we believe they show the feasibility of this approach and alleviate training-speed concerns. These experiments will be added to the appendix.
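
The identity-initialization argument can be checked directly: a key-query kernel that is zero everywhere except for a 1.0 at the current query row and key offset 0 passes every causal logit through unchanged, so inserting it leaves the layer's attention, and hence its output, as before. The sketch below assumes a c_q × c_k kernel with a centered key axis; the 6 × 11 kernel size in the usage line and the helper names are hypothetical, and the sketch covers only the convolution, not any other component added with MTA.

```python
def identity_kernel(c_q, c_k):
    """A c_q x c_k kernel that passes each logit through unchanged (a discrete delta)."""
    kernel = [[0.0] * c_k for _ in range(c_q)]
    kernel[0][c_k // 2] = 1.0  # current query row, key offset 0
    return kernel


def conv_logits(logits, kernel):
    """Key-query convolution over a causal T x T logit matrix.

    Only entries with j <= i are computed; future (masked) logits are never read.
    """
    T, c_q, c_k = len(logits), len(kernel), len(kernel[0])
    out = [[0.0] * T for _ in range(T)]
    for i in range(T):
        for j in range(i + 1):
            acc = 0.0
            for a in range(c_q):
                for b in range(c_k):
                    ii, jj = i - a, j - (b - c_k // 2)
                    if ii >= 0 and 0 <= jj <= ii:
                        acc += kernel[a][b] * logits[ii][jj]
            out[i][j] = acc
    return out


# 9.9 marks future positions that a causal layer never uses.
logits = [[0.3, 9.9, 9.9], [1.2, -0.4, 9.9], [0.7, 2.0, -1.1]]
print(conv_logits(logits, identity_kernel(6, 11)))
```

Once training moves any other kernel entry away from zero, the convolved logit starts mixing in neighboring query rows and keys, which is the deviation from identity the reply describes.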

Comment

Thank you for your response. It addressed most of my concerns, and I would like to keep my positive rating.

Official Review
Rating: 6

The paper proposes a new attention mechanism called Multi-Token Attention (MTA), aiming to address the limitations of traditional single-token attention mechanisms in precisely locating relevant information within context. The experimental results demonstrate that MTA outperforms baseline models across multiple benchmarks, particularly in long-context tasks, highlighting the effectiveness of the proposed method.

Reasons to Accept

  1. The method presented in this paper has a clear motivation and aligns well with the context of graph neural networks. Currently, attention scores are calculated based only on single-hop (directly connected) query-key pairs. However, by using convolution or other methods, multi-hop features can be incorporated.
  2. The experimental results highlight two significant advantages: lower perplexity in language modeling and better performance in long-context understanding.
  3. The additional parameters introduced are negligible.

Reasons to Reject

  1. While the motivation of this paper is clear, the technical innovation is somewhat limited. Beyond validation through ablation studies, there is a lack of substantial theoretical or analytical support. Although the effectiveness of multi-hop aggregation makes sense, the paper does not discuss whether CNN kernels and group norm are the only ways to achieve it, or whether other methods could be equally effective.

  2. I observed that in the language modeling tasks, MTA with group norm achieves better performance (Tables 2, 3, 4), while MTA without group norm performs comparably to the baseline. However, in the long-context retrieval tasks, MTA without group norm shows better performance (Table 5, Fig. 7), and MTA with group norm performs comparably to the baseline. Could the authors explain the reasons behind these results?

  3. As an improvement work based on the Transformer architecture, this paper lacks evaluation regarding scalability. The baseline Diff. Transformer has demonstrated scalability from 830M to 13B parameters, whereas this paper discusses only a very limited range of model sizes.

  4. Although the authors discuss efficiency issues in the Limitations section (memory and computational efficiency being only 1/5 of the baseline), attributing it to the absence of efficient operators, I would like the authors to discuss whether MTA is compatible with online softmax optimization algorithms. If not, this would represent a significant drawback of the method, implying that under fair computational efficiency comparisons, MTA should be compared to a baseline that is five times larger in size.

Questions for Authors

  1. In the toy task experiment presented in Table 1, does the Transformer architecture align with that used in other sections of the paper (i.e., is it a llama-like architecture rather than a naive Transformer-Decoder)?

  2. Why is the DIFF w/o group norm. baseline included in Table 2 but absent from other experiments?

  3. I suggest noting that this paper discusses the Transformer-Decoder model, and that the complete Transformer also includes the Encoder.

Comment
  • While the motivation of this paper is clear, the technical innovation is somewhat limited…

Our method is the first to apply a convolution operation on the attention maps to combine information from multiple queries and keys. This has never been done before and is technically novel. Providing a theoretical proof is very challenging for such complex non-linear architectures with complex training dynamics. Instead, we provide experimental evidence supporting our method, which is common practice for this type of research. In addition, we motivate our method by emphasizing the bottleneck of single-vector attention, and we provide experimental analysis through ablations. Using multiple key and query vectors to compute an attention weight has a clear advantage over using only a single vector, because the amount of information that can be compressed into a single vector is limited.

  • I observed that in the language modeling tasks, MTA with group norm achieves better performance…

Thank you for this question! We observed that a few key-query kernels are responsible for solving long-context retrieval tasks (Fig. 3), and these kernels are in the middle layers of the Transformer. We hypothesize that adding group normalization tones down the effect of these kernels by normalizing them together with other, less important kernels, which hurts performance. On short-context tasks, in contrast, which require multiple kernels to work together, normalization helps produce a more uniform signal. Motivated by this, we have now experimented with replacing depth scaling with a sigmoid gating mechanism that could serve as a "switching" mechanism for kernels. Our experiments confirmed that sigmoid gating gets the best of both worlds and shows better performance across all tasks. We have updated our paper accordingly and added these results to the ablations: https://anonymous.4open.science/r/projects-862C/ablations.pdf
The final average accuracy on the 2k-context Needle task after incorporating all changes increased from 73.63 to 89.70, and perplexity on Lambada decreased from 11.18 to 10.82.
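
The reply does not spell out the exact form of the sigmoid gate, so the sketch below is only one plausible reading, labeled as an assumption: each head's output is group-normalized (one group per head, no affine parameters), and a learnable per-head scalar `gate_logit` replaces the fixed depth-dependent scale, multiplying the result by sigmoid(gate_logit). It illustrates the "switching" behavior described above: a large positive gate keeps the normalized head almost unchanged, while a large negative gate nearly turns it off.

```python
import math


def group_norm(x, eps=1e-5):
    """Normalize one head's output vector to zero mean and unit variance.

    Assumes one group per head and no learnable affine parameters.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def gated_head_output(x, gate_logit):
    """Hypothetical scalar gate: sigmoid(gate_logit) times the normalized head output.

    In this reading, gate_logit is a learnable per-head scalar that replaces a
    fixed depth-dependent scale.
    """
    gate = 1.0 / (1.0 + math.exp(-gate_logit))
    return [gate * v for v in group_norm(x)]


head_out = [1.0, 2.0, 3.0, 4.0]
print(gated_head_output(head_out, 4.0))   # gate near 1: head passes through
print(gated_head_output(head_out, -4.0))  # gate near 0: head nearly switched off
```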

  • As an improvement work based on the Transformer architecture, this paper lacks evaluation regarding scalability…

Due to compute limitations, we couldn’t afford pre-training 13B models. We agree, however, that for new architectures it is important to study scaling laws, so we added two smaller models (300M and 550M) as well as a 1.4B model. Perplexity evaluations are provided here: https://anonymous.4open.science/r/projects-862C/scaling_laws.pdf . We observe consistent performance improvements across all model sizes. Since we cannot afford training larger models from scratch, we have also included preliminary experiments on continued training of Llama models with 1B, 3B, and 8B parameters: https://anonymous.4open.science/r/projects-862C/finetuning.pdf

  • Although the authors discuss efficiency issues in the Limitations section ...

Any major modification to the attention layer will be slower because it cannot take advantage of the specialized CUDA kernels that are hard-coded for standard attention. But this shouldn’t prevent us from exploring new attention architectures with stronger capabilities. We don’t foresee any fundamental reason why a similar specialized CUDA kernel cannot be implemented to speed up MTA. The convolution operation in MTA is highly parallelizable, and especially efficient given our small kernel sizes. The number of floating-point operations added by the MTA convolution is O(seq_len^2 x key_kernel x query_kernel), which is smaller than the O(seq_len^2 x head_size) of the attention computation. The group norm operations have only O(seq_len x head_dim) complexity, so they add minimal overhead.
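
The complexity comparison above can be made concrete with a multiply-add count for one causal head over a full sequence (training or prefill). The sizes in the usage line are hypothetical placeholders, the group-norm term assumes one normalization per head output vector, and softmax and memory traffic are not counted.

```python
def prefill_madds_per_head(seq_len: int, head_dim: int, c_q: int, c_k: int) -> dict:
    """Multiply-add counts for one causal attention head over a full sequence."""
    pairs = seq_len * (seq_len + 1) // 2    # causal (query, key) pairs with j <= i
    return {
        "qk": pairs * head_dim,             # Q K^T
        "av": pairs * head_dim,             # attention-weighted sum of V
        "kq_conv": pairs * c_q * c_k,       # key-query convolution on the logits
        "group_norm": seq_len * head_dim,   # per-head norm of the output, linear in seq_len
    }


counts = prefill_madds_per_head(seq_len=2048, head_dim=128, c_q=6, c_k=11)
attn = counts["qk"] + counts["av"]
print(counts["kq_conv"] / attn, counts["group_norm"] / attn)
```

With these placeholders the convolution adds c_q·c_k / (2·head_dim) ≈ 0.26 of the QK^T plus AV work, and the group-norm term is a 1/(seq_len + 1) fraction of it, i.e. negligible for long sequences. Whether the wall-clock overhead is similarly small depends on a fused kernel, which a count like this does not capture.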

  • In the toy task experiment presented in Table 1, does the Transformer architecture align with that used in other sections of the paper (i.e., is it a llama-like architecture rather than a naive Transformer-Decoder)?

Yes, we use a similar llama-like architecture that is smaller in size. We will make this clear in the paper.

  • Why is the DIFF w/o group norm. baseline included in Table 2 but absent from other experiments?

We show that group normalization is responsible for half of the performance gain of the DIFF Transformer. The original DIFF Transformer paper does not ablate normalization; our motivation was to show how important this component is in that architecture. We agree it might be confusing to have it in a single table, so we will keep it in the ablation section only and remove it from Table 2.

  • I suggest noting that this paper discusses the Transformer-Decoder model...

Yes, that is correct. We used the term “Transformer” instead of “Transformer-Decoder” for brevity and because it is common usage in the related literature. We will make this clear in the paper.

Comment

Thanks for the authors' response. Some issues have been resolved, but several of my questions remain unresolved. I would be happy to increase the score if the authors can address these.

This has never been done before and is technically novel.

The calculation of attention scores between queries and keys can be analogous to computing link prediction scores between two nodes in graph neural networks. Many studies[1][2], especially in graph attention networks (which bear a strong resemblance to Transformers), have enhanced node representations by aggregating local information such as subgraphs and neighbors. I would like to know the core technical differences between this paper and such works. If I understand correctly (regarding Eq 4), MTA can be seen as a feature enhancement for nodes within the <c_q, c_k> local window (i.e., a fully connected subgraph).

[1] Multi-hop Attention Graph Neural Network (IJCAI 2021)
[2] How powerful are k-hop message passing graph neural networks (Neurips 2022)

added these results to the ablations: https://anonymous.4open.science/r/projects-862C/ablations.pdf

I am unable to see the connection between these ablations and my questions. Could the author please provide a more specific explanation?

increased from 73.63 to 89.70, and perplexity on Lambada decreased from 11.18 to 10.82.

Could you list the results for Needle=2/4/6 separately? How about the results for 4K context in Table 5? I seem unable to locate the source of 11.18 in the paper.


regarding the response on scalability

I understand that the model size range in the current experiments is acceptable. However, it appears that the gains from MTA diminish in larger models.


efficiency kernel

The absence of a CUDA kernel implementation for MTA is acceptable to me. I would appreciate it if the author could describe the implementation process of MTA under the online softmax algorithm in pseudocode, similar to the "Implementation with FlashAttention" subsection in Diff. Transformer (page 15).

Comment

Thank you for your prompt response. Please find our replies below:

  1. Indeed, there are similarities between attention scores and the link prediction scores in graph networks, which share some architectural similarities with Transformers. In that sense, our approach similarly tries to aggregate local information from its neighbors. One primary difference lies in the way we define the local window for feature enhancement: we look at the local neighborhood in time, such as a sentence, which is hard to imagine in graph neural networks. Many graph-based methods rely on predefined graph structures, such as adjacency matrices, to define the local neighborhood. Our approach, on the other hand, takes advantage of the sequential nature of the input and tries to aggregate information from temporally closer tokens. Furthermore, our MTA method is designed to handle long-range interactions in sequential dependencies. In this regard, our approach differs from traditional graph-based methods, which often assume a fixed graph structure and may not be well-suited to handle such complexities.

In addition, the main goal of [1] seems to be that it “computes the attention values between pairs of nodes that are not directly connected by an edge”, which makes sense for GNNs but is not applicable to Transformers, because their attention graph is fully connected (apart from causal masking). Similarly, [2] focuses on K-hop message passing, i.e. between nodes that are K edges apart. Again, this is not easily transferable to Transformers, where the graph is fully connected and all nodes are directly connected.

  2. To answer your question “why MTA with group norm is better on regular tasks, while MTA without group norm was better on long-range retrieval”, we hypothesized that normalization helps produce a more uniform signal on regular tasks, while toning down the very few specialized kernels that help solve long-range dependency tasks. To get the best of both worlds, we replaced depth scaling with a sigmoid gating mechanism that could serve as a “switching” mechanism for kernels. The results of these experiments are added to the second block of the ablation table: in terms of validation perplexity, MTA w/o group norm 11.03, MTA 10.99, MTA with scalar gating 10.95. So it is possible to get the best of both worlds, which we think is exciting!
  3. MTA model with 2k context: 2 needles: 89.6, 4 needles: 89.3, 6 needles: 90.2. 4k context: 2 needles: 89.6, 4 needles: 93.7, 6 needles: 84.5 (89.3 average). As for perplexity, there was a slight typo in our last reply; sorry for the confusion. We repeat the whole statement here with an added explanation: “Specifically, average accuracy on the 2k-context Needle task after incorporating all changes has increased from 73.63 to 89.70, and perplexity on Lambada decreased from 11.15 to 10.82. The 11.15 is an average of 13.6 and 8.7 from Table 4.”
  4. The % improvement continues to be better than the Diff Transformer's gain over the baseline at all scales, and is relatively constant (between 2.5% and 3%; admittedly it goes up and down a little at each scale, so there is some variance).
  5. Specifically, the online softmax algorithm provides a way to accelerate softmax computation compared to a naive implementation, with fewer memory accesses [1]. Our algorithm does not interfere with the softmax operation, i.e. convolution operations are applied before and/or after softmax, so the pseudocode would be the same except for this change. That is, we can directly apply Algorithm 3 from [1], where <x_j> are the outputs of the MTA convolution operation.
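
Reviewer 3 asked for pseudocode of MTA under the online softmax algorithm. Following the reply that the convolution is applied before softmax, the sketch below shows the standard online-softmax recurrence for a single query row, consuming already-convolved logits x_j as if they were ordinary scores. It is a pure-Python illustration, not a kernel, and it rests on stated assumptions: the convolved logits for the row are already available (a tiled GPU kernel would also need the neighboring raw logits inside each c_q × c_k window, i.e. a halo around each tile), all logits passed in are finite (only the causal part of the row), and convolutions applied after softmax are not covered.

```python
import math


def online_softmax_attend(logits, values):
    """softmax(logits) @ values for one query row, in a single streaming pass.

    `logits` are the (already convolved) scores x_j for the causal part of the row,
    so all of them are finite. The loop keeps a running max m, a running normalizer s
    and a running unnormalized output acc, rescaling s and acc when a larger logit appears.
    """
    m, s = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for x, v in zip(logits, values):
        m_new = max(m, x)
        scale = math.exp(m - m_new)  # 0.0 on the first step, since m = -inf
        w = math.exp(x - m_new)
        s = s * scale + w
        acc = [a * scale + w * vi for a, vi in zip(acc, v)]
        m = m_new
    return [a / s for a in acc]


convolved_logits = [0.5, 2.0, -1.0, 3.0]   # stand-ins for MTA-convolved x_j
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
print(online_softmax_attend(convolved_logits, values))
```

The result matches a two-pass softmax followed by a weighted sum; the only MTA-specific change is where the x_j come from.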

Comment

Thank you for the authors' response. While it prompted my decision to raise the score from 4 to 6, I have limited the increase for the following reasons:

  1. The MTA exhibits significant conceptual overlap with node enhancement techniques in GNNs. GNNs routinely handle dynamic, weighted, fully-connected graphs, which substantially diminishes the perceived technical novelty.

  2. Although the authors introduced a novel normalization technique during the rebuttal, significantly mitigating MTA's weakness in long-context retrieval tasks compared to MTA w/o Norm, a performance deficit remains under specific configurations:

  • 2K Context: 2 needles: 89.6 vs. 95.3; 4 needles: 89.3 vs. 92.6

  • 4K Context: 2 needles: 89.6 vs. 94.4; 4 needles: 93.7 vs. 96.3

    As foundational LM architecture research, these gaps undermine community confidence in MTA's effectiveness.

  3. The description of the efficient CUDA implementation remains insufficiently detailed.

Official Review
Rating: 5

This paper proposes Multi-Token Attention (MTA), which conditions attention weights on multiple query and key vectors via convolutions across tokens and heads. It extends standard full attention by incorporating a pre-softmax convolution, head-wise mixing, and an additional group normalization layer. Experiments on both short-context and long-context benchmarks demonstrate consistent improvements with the proposed attention variant.

Reasons to Accept

  1. The paper is clearly written and easy to follow.

  2. The proposed technique is easy to plug into standard attention.

  3. The proposed attention variant consistently improves performance across various benchmarks.

Reasons to Reject

  1. The first concern is that the example used to motivate multi-token attention is not convincing enough. After the current token-mixing operation, the information from multiple tokens is naturally fused and can be processed by the subsequent FFN layers and the next attention layer. It is unclear why we need to explicitly operate on the attention map of the current layer. The example and explanation only hold when there is a single attention operator in the model.

  2. Adding convolution on top of Q, K, and V is common in previous language models, especially in linear attention models like GLA, DeltaNet, and Gated DeltaNet. The difference is that this work introduces convolution after the matmul operation in attention, yet no benchmark or explanation is provided to justify why this is a better solution.

  3. The effectiveness of head mixing is not ablated. Conceptually, the subsequent FFN layers can also learn to average across different heads. It is unclear why weighted averaging of two consecutive heads is beneficial.

  4. The efficiency of the proposed attention variant is not analyzed, especially considering that the introduced convolutions and group normalization may introduce additional overhead.

Questions for Authors

In addition to the attention map of one example, could you also show the difference in averaged attention maps?

Other questions have been included in the above weakness section.

Comment
  • The first concern is that the example used to motivate multi-token attention is not convincing enough. After the current token-mixing operation, the information from multiple tokens is naturally fused and can be processed by the subsequent FFN layers and the next attention layer. It is unclear why we need to explicitly operate on the attention map of the current layer. The example and explanation only hold when there is a single attention operator in the model.

Our motivation is that the attention operation is responsible for locating useful information in the context, but it has flaws. For example, suppose there is a critical piece of information “A is B” in the context. If attention cannot focus on this information and instead attends to irrelevant tokens, the output of the attention layer will not contain the “A is B” information. Since the FFN operates on each token separately, it can only process what the attention outputs and cannot retrieve “A is B” from the context. Attention has a hard time locating “A is B” because it relies on the similarity of a single vector against another. Since the amount of information that can be compressed into a single vector is limited, this becomes a bottleneck. Crucially, this bottleneck exists in every layer, so stacking multiple layers does not remove it. Our experiments clearly show that this is in fact a problem for Transformers when they need to locate a "needle task" sentence like “The magic number of San-Francisco is …”. By combining different attention maps through convolution operations, MTA allows the model to condition where to “attend” on multiple vectors, making attention much more fine-grained and based on richer information.
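
A toy version of the “Alice” and “rabbit” example from the paper makes the single-vector bottleneck concrete. The sketch below replaces learned q·k scores with 0/1 token matches and uses a fixed all-ones kernel that sums the score maps of two consecutive query tokens over a small key window; the sentence, the window width, and the function names are all hypothetical. Each single-token map gives the same score to every sentence containing its word, whereas the combined map peaks only inside the sentence that contains both.

```python
def match_scores(query_token, keys):
    """Toy single-token attention logits: 1.0 where the key equals the query token."""
    return [1.0 if key == query_token else 0.0 for key in keys]


def combine_two_queries(keys, tok_a, tok_b, c_k):
    """Sum two query tokens' match maps over a centered key window of width c_k.

    A fixed all-ones stand-in for a learned key-query kernel: the query axis covers
    the two query tokens, the key axis covers c_k neighboring keys.
    """
    a, b = match_scores(tok_a, keys), match_scores(tok_b, keys)
    half = c_k // 2
    out = []
    for j in range(len(keys)):
        lo, hi = max(0, j - half), min(len(keys), j + half + 1)
        out.append(sum(a[lo:hi]) + sum(b[lo:hi]))
    return out


keys = "Alice ate cake . Bob fed the rabbit some fresh hay . Alice met a rabbit .".split()
scores = combine_two_queries(keys, "Alice", "rabbit", c_k=5)
print([(tok, s) for tok, s in zip(keys, scores)])
```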

  • Adding convolution on top of Q, K, and V is common in previous language models, especially in linear attention models like GLA, DeltaNet, and Gated DeltaNet. The difference is that this work introduces convolution after the matmul operation in attention, yet no existing benchmark or explanation is provided to justify why this is a better solution.

While the same convolution operation is used, there is a critical difference between applying it to attention maps and applying it to Q, K, V vectors. Applying convolution to Q, K or V does not resolve the single-vector bottleneck, which is the main focus of our paper: the attention weights are still computed from the similarities of single vectors, even if we apply convolution to Q, K, V beforehand. This is why we propose to apply convolution to the attention maps instead, so that the final attention weights can condition on multiple Q and K vectors. We do not consider linear attention in our paper and instead focus on standard softmax attention, which is the mainstream method in the most powerful released or deployed LLMs.

  • The effectiveness of head mixing is not ablated. Conceptually, the subsequent FFN layers can also learn to average across different heads. It is unclear why weighted averaging of two consecutive heads is beneficial.

Thank you for pointing this out. We have now added an ablation study on the head convolution size, c_h, and found that increasing c_h leads to further improvements in performance. We will update the paper accordingly. Results for different numbers of heads within the convolution window (c_h, from 0 to all heads, in powers of 2) can be found here: https://anonymous.4open.science/r/projects-862C/head_ablation.pdf
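
A minimal NumPy sketch of head mixing, assuming non-overlapping groups of c_h heads and mixing applied to post-softmax attention maps (both are our assumptions; the response does not pin down these details, and `head_mix` is a hypothetical name). It also bears on the reviewer's FFN question: a mixed map for output head o is applied to head o's own values, computing (sum_h w_oh A_h) V_o, whereas averaging head outputs computes sum_h w_h (A_h V_h), which differs whenever the heads' values differ.

```python
import numpy as np


def head_mix(attn, mix):
    """Mix attention maps within non-overlapping groups of c_h heads.

    attn: (H, T, T) attention maps, one per head.
    mix:  (H // c_h, c_h, c_h) weights; mix[g, o, h] is the weight of
          input head h on output head o inside group g.
    """
    H, T, _ = attn.shape
    n_groups, c_h, _ = mix.shape
    assert n_groups * c_h == H, "heads must split evenly into groups of c_h"
    groups = attn.reshape(n_groups, c_h, T, T)
    mixed = np.einsum("goh,ghij->goij", mix, groups)  # weighted sum over input heads
    return mixed.reshape(H, T, T)


rng = np.random.default_rng(0)
attn = rng.random((4, 3, 3))             # 4 heads, sequence length 3
identity = np.stack([np.eye(2)] * 2)     # c_h = 2, no mixing
average = np.full((2, 2, 2), 0.5)        # c_h = 2, each output is its group's mean
```

With `identity` the maps are returned unchanged; with `average`, heads 0 and 1 both become the mean of the first pair of maps.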

  • The efficiency of the proposed attention variant is not analyzed, especially considering that the introduced convolutions and group normalization may introduce additional overhead.

Any major modification to the attention layer will be slower because it cannot take advantage of the specialized CUDA kernels that are hard-coded for standard attention. But this should not prevent us from exploring new attention architectures with stronger capabilities. We do not foresee any fundamental reason why a similar specialized CUDA kernel cannot be implemented to speed up MTA. The convolution operation in MTA is highly parallelizable and especially efficient given our small kernel sizes. The number of floating-point operations added by the MTA convolution is O(seq_len^2 x key_kernel x query_kernel), which is smaller than the O(seq_len^2 x head_size) cost of the attention computation as long as key_kernel x query_kernel is smaller than head_size. The group norm operations have only O(seq_len x head_dim) complexity, so they add minimal overhead.
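
The asymptotic comparison above can be made concrete with a back-of-the-envelope count of multiply-adds per head. This sketch counts only QK^T and the weighted sum over V for standard attention, and one c_q x c_k kernel application per score for MTA. It ignores memory traffic and kernel fusion, which dominate wall-clock time in practice, and the kernel sizes below are hypothetical.

```python
def mta_conv_overhead(seq_len, head_dim, c_q, c_k):
    """Ratio of extra multiply-adds from an MTA key-query kernel to the
    multiply-adds of standard attention, per head."""
    attn_macs = 2 * seq_len**2 * head_dim   # Q K^T plus (attention weights) @ V
    conv_macs = seq_len**2 * c_q * c_k      # one c_q x c_k kernel per score entry
    return conv_macs / attn_macs


# Hypothetical sizes: head_dim = 128 with a 4 x 8 kernel.
ratio = mta_conv_overhead(seq_len=2048, head_dim=128, c_q=4, c_k=8)  # 0.125
```

The sequence length cancels, so the ratio is c_q * c_k / (2 * head_dim): the extra arithmetic stays a modest fraction of attention only while the kernel area is well below the head dimension.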

  • In addition to the attention map of one example, could you also show the difference in averaged attention maps?

Averaged attention maps only show the general coverage of an attention head (e.g. short-term vs. long-term) and lack the detail needed to tell whether attention correctly located the useful information, which is the focus of our paper. We do not expect much difference in averaged attention with MTA, since the convolution operation does not change the general coverage of attention much.

Comment

I thank the authors for providing the response. However, I still have the following concerns:

  1. For the example "A is B," the analysis holds for a single attention layer but not necessarily for a multi-layer transformer model. Whether understanding "A is B" should be achieved within one layer or distributed across layers warrants more careful analysis. The authors are expected to provide further analysis to convince readers why this should be performed in a single layer.

  2. Regarding convolution on attention maps, the authors mentioned that "applying convolution to Q, K, or V is not going to resolve the single-vector bottleneck problem." I do not fully agree with this, as applying convolution to Q, K, and V can already integrate information, e.g., from "A is B", which can be viewed as forming a new token like "A-is-B." I do not see a critical difference here, and the authors need to present quantitative results to demonstrate the advantage of their approach.

  3. Efficiency is an indispensable consideration when proposing any fundamental operator. It is necessary to justify whether spending the added computation or overhead on MTA is better than simply stacking more traditional operators.

Based on these concerns, I will keep my original rating. Addressing the above issues could strengthen the next version of this work.

Comment

Thank you for raising your concerns. Below we provide our reply:

  1. Our paper is a novel algorithmic and empirical work, not a theoretical work, and we provide experiments across a variety of datasets with extensive ablations and comparisons of different kinds of attention mechanisms (including new ones in the rebuttal addressing reviewer concerns) showing clear effects, which is well within the scope of the call for papers for the conference. We therefore think theoretical analysis, while appealing, could easily be followup work perhaps from other authors who are experts in this domain.
  2. As you said, convolution on QK or attention of lower layers can encode “A is B” into a new token vector. But this way, the information “A is B” must be compressed into a single vector, so it can be used in computing attention weights. While such compression might be possible for simple information, it restricts how much information can be used for attention. This is what we are calling the single-vector-bottleneck. In contrast, MTA allows information of “A” and “B” to stay in separate vectors, so they can be compared separately to their corresponding query vectors before being merged into a single attention weight.
  3. We performed new experiments, described in the rebuttal, showing how MTA blocks can be incorporated into pre-trained Transformer models to improve (training) efficiency, and we would appreciate it if you took them into account with respect to your concerns. In any case, for new approaches, we also believe some leeway should be given in terms of efficiency if the direction looks promising for improving performance, as ours does. In this way, the community can make progress together on efficiency for this new approach, as has been done, for example, for standard attention.
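
The "A is B" argument in point 2 can be made concrete with a toy example. This is a hypothetical illustration with one-hot vectors, not an experiment from the paper. We assume the previous query looks for "A", the current query looks for "B", and the target is the position where "A" is immediately followed by "B". Single-token attention cannot single out that position, whether it uses one query or a query that merges both (q_prev + q_cur). A 2x2 diagonal key-query kernel compares each query with its own key before summing, so only the true "A B" position gets the top score.

```python
import numpy as np

vocab = {"A": 0, "B": 1, "C": 2}
context = ["A", "C", "B", "A", "B", "C"]
keys = np.eye(len(vocab))[[vocab[t] for t in context]]  # (T, 3) one-hot keys

q_prev = np.eye(len(vocab))[vocab["A"]]  # hypothetical previous query: looks for "A"
q_cur = np.eye(len(vocab))[vocab["B"]]   # hypothetical current query: looks for "B"

# Single-token attention: one query vector against one key vector.
single = keys @ q_cur                    # [0, 0, 1, 0, 1, 0]: tie between positions 2 and 4
both = keys @ (q_prev + q_cur)           # [1, 0, 1, 1, 1, 0]: merged query, still no unique peak

# Key-query convolution with theta[0, 0] = theta[1, 1] = 1:
# score_j = q_cur . k_j + q_prev . k_{j-1}, with k_{-1} treated as zero.
prev_keys = np.vstack([np.zeros(len(vocab)), keys[:-1]])
mta = keys @ q_cur + prev_keys @ q_prev  # [0, 1, 1, 0, 2, 0]: unique peak at position 4
```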
Comment

Thank you to the authors for providing further feedback. I agree that this work introduces some new knobs for LM architecture design, but I believe that for fundamental architectures, we should be more careful and require rigorous validation to demonstrate consistent and non-trivial improvements in either efficiency or accuracy. I will raise my score to 5 to acknowledge the authors' efforts and will discuss further with the other reviewers.

Comment

Dear Reviewers,

We would like to thank you for the time and consideration you dedicated to reviewing our submission. Your thoughtful feedback was invaluable to us, and we appreciate the effort you put into evaluating our work.

We have made several additions and improvements to our paper based on your feedback, in particular:

  • There were questions regarding efficiency. To understand if/how our approach can be incorporated in the fully trained models, we have now performed some additional experiments on continuous training of 1B, 3B, and 8B models, where we add MTA layers in the same setup as in main experiments after the model was already trained with a conventional architecture. We show that at this scale continuous training on small amounts of tokens yields improvements in performance: https://anonymous.4open.science/r/projects-862C/finetuning.pdf . Since these are experiments performed for the purposes of rebuttal, there are further directions that could be explored here, but we believe they show the feasibility of this approach and alleviate training speed concerns. These experiments will be added in the appendix.

  • We also trained additional models with 300M, 550M, and 1.4B parameters to address the scalability (scaling laws) question raised by AZNg, 3NEq and maCc (https://anonymous.4open.science/r/projects-862C/scaling_laws.pdf). We observe consistent perplexity improvements across model sizes. We cannot afford to train larger models from scratch, so we have also included preliminary experiments on continuous training of Llama models with 1B, 3B, and 8B parameters, as described above.

  • To address additional ablation points regarding head mixing raised by reviewers ryV8, we have also added ablation results on:

    In these ablations we found that increasing the convolution window in head mixing further improves performance of the MTA models. We have updated the paper to reflect these findings.

  • To answer the question regarding the group norm effect raised by AZNg, we further looked at different normalization mechanisms and added them to the ablations in Table 6. We found that replacing depth scaling with a gating mechanism allows the model to better activate task-specific kernels and improves performance across all tasks. Specifically, after incorporating all changes, average accuracy on the 2k-context Needle task increased from 73.63 to 89.70, and perplexity on Lambada decreased from 11.18 to 10.82.

Final Decision

This paper proposes a new attention mechanism to address the limitation of single token attention.

Initially, the scores were two 4s and two 6s, with some major concerns.

During the rebuttal, the authors successfully addressed the issues, and the scores became 7, 6, 6, and 5.

AC carefully read the paper, the reviewers' comments, and the authors' feedback.

The authors' response seems to sufficiently address most major concerns raised by a negative reviewer.

So, AC recommends accepting this paper, and asks the authors to reflect the comments in the final version.