PaperHub
Average rating: 5.5 / 10 · Poster · 4 reviewers
Ratings: 7, 5, 5, 5 (min 5, max 7, std 0.9)
Confidence: 3.8 · Correctness: 2.3 · Contribution: 3.0 · Presentation: 2.5
NeurIPS 2024

Selective Attention: Enhancing Transformer through Principled Context Control

OpenReview · PDF
Submitted: 2024-05-14 · Updated: 2025-01-21

Abstract

The attention mechanism within the transformer architecture enables the model to weigh and combine tokens based on their relevance to the query. While self-attention has enjoyed major success, it notably treats all queries $q$ in the same way by applying the mapping $V^\top\text{softmax}(Kq)$, where $V,K$ are the value and key embeddings respectively. In this work, we argue that this uniform treatment hinders the ability to control contextual sparsity and relevance. As a solution, we introduce the Selective Self-Attention (SSA) layer that augments the softmax nonlinearity with a principled temperature scaling strategy. By controlling temperature, SSA adapts the contextual sparsity of the attention map to the query embedding and its position in the context window. Through theory and experiments, we demonstrate that this alleviates attention dilution, aids the optimization process, and enhances the model's ability to control softmax spikiness of individual queries. We also incorporate temperature scaling for value embeddings and show that it boosts the model's ability to suppress irrelevant/noisy tokens. Notably, SSA is a lightweight method which introduces less than 0.5% new parameters through a weight-sharing strategy and can be fine-tuned on existing LLMs. Extensive empirical evaluations demonstrate that SSA-equipped models achieve a noticeable and consistent accuracy improvement on language modeling benchmarks.
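As a rough illustration of the mechanism described in the abstract, below is a minimal, hypothetical PyTorch sketch of a single-head attention layer with token-dependent query and value temperatures. The module and parameter names (`SelectiveSelfAttention`, `tau_q`, `tau_v`) and the exact parameterization are assumptions for illustration only, and the position-aware component is omitted; this is not the paper's actual implementation.

```python
# Minimal single-head sketch of temperature-scaled ("selective") attention.
# Illustrative only: names and the exact temperature parameterization are assumptions.
import math
import torch
import torch.nn as nn

class SelectiveSelfAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.Wq = nn.Linear(dim, dim, bias=False)
        self.Wk = nn.Linear(dim, dim, bias=False)
        self.Wv = nn.Linear(dim, dim, bias=False)
        # Token-dependent temperatures for queries and values (one scalar per token).
        self.tau_q = nn.Linear(dim, 1)
        self.tau_v = nn.Linear(dim, 1)
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, seq, dim)
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        # Query temperature: rescales each query before softmax, controlling
        # how spiky that query's attention distribution is.
        q = q * torch.exp(self.tau_q(x))          # (batch, seq, 1) broadcast
        # Value temperature: down-weights values of irrelevant/noisy tokens.
        v = v * torch.sigmoid(self.tau_v(x))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.dim)
        mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))   # causal mask
        return torch.softmax(scores, dim=-1) @ v

x = torch.randn(2, 16, 64)
print(SelectiveSelfAttention(64)(x).shape)  # torch.Size([2, 16, 64])
```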
Keywords
attention mechanism, transformer, sparsity, language model, architecture

Reviews and Discussion

Review
Rating: 7

The authors introduce a modification to the widely-used self-attention mechanism, which modulates the queries and values in the attention computation with a per-token temperature parameter. They demonstrate that this allows the attention layer to model certain observations better for toy problems, and they also adapt several large language models to use this Selective Self-Attention (SSA), demonstrating that it moderately boosts overall perplexity & accuracy scores.

Strengths

The technical contribution represents an interesting and intuitive addition to standard attention. The addition of learnable query and value temperature clearly benefits the attention mechanism, both on illustrative toy examples and on real-world data. The mathematical exploration is comprehensive, leaving the reader with solid reasoning for why SSA solves several problems. The approach not only improves the performance of attention on standard modeling problems, but promises to improve on long-context tasks and positional encoding as well.

Weaknesses

My main concern is the claim that the performance improvement on the LM task is due to "strategic integration of SSA" rather than an increase in parameter count. The improvements in perplexity/accuracy look statistically significant, but I don't think it's obvious that the improvement couldn't come from the increase in parameter count. There at least needs to be more argumentation to support this claim.

Another weakness is that the mathematical sections are pretty dense at times. Terms and mathematical propositions were introduced quickly without a lot of background. I found it difficult to understand the motivation for the lemma and propositions. I also think the denoising task, "feature imbalance", "effective weights", and "specificity" needed more introduction.

Questions

Could we bolster the claim that SSA packs more benefit than just extra parameters by testing models that are scaled up with normal attention by equivalent amounts? Or can we instead argue this by looking at model scaling work that shows that a lot more parameter scaling is required to achieve similar drops in perplexity?

Limitations

Attention is used extremely broadly, and the only real-world task used to evaluate SSA was the language modeling task. I would mention this as an important limitation on the claims of the paper.

Author Response

We thank the reviewer for their time, effort, and positive assessment.

W1: Demonstrate that the performance improvement on the LM task is due to "strategic integration of SSA" rather than an increase in parameter count. Q1: Could we bolster the claim that SSA packs more benefit than just extra parameters by testing models that are scaled up with normal attention by equivalent amounts?

In the general response, under (1) Making SSA parameter-efficient, we examine more efficient parameterizations to push this overhead below 0.5%. In Table 1 of the common response, we tried several approaches utilizing weight sharing and feature-based temperature selection. These parameterizations also outperform the standard transformer. We also fine-tuned the model using LoRA, as illustrated in the table below. By controlling the dimension of LoRA, we made sure that LoRA and SSA contributed a comparable number of additional parameters. Although LoRA demonstrated performance similar to directly fine-tuned Pythia, SSA notably outperformed LoRA, creating a distinct gap.

| Model | Wikitext ppl ↓ | Lambada_std ppl ↓ | Lambada_openai ppl ↓ | Lambada_std acc ↑ | Lambada_openai acc ↑ |
| --- | --- | --- | --- | --- | --- |
| Finetune | | | | | |
| Pythia | 28.781 | 51.620 | 25.173 | 0.373 | 0.497 |
| +SSA base (+5.26%) | 26.514 | 47.945 | 23.956 | 0.388 | 0.513 |
| +LoRA (+5.63%) | 27.944 | 49.852 | 24.653 | 0.375 | 0.501 |
| +SSA shared $W_{k/q/v}$ (+0.47%) | 26.681 | 47.996 | 24.102 | 0.383 | 0.504 |
| +SSA feature-based (<0.01%) | 27.048 | 48.906 | 24.814 | 0.379 | 0.499 |

W2: The mathematical sections are pretty dense at times.

Thank you for this suggestion. We will make sure to add additional explanation and motivation in the text.

Comment

Thank you for the thorough responses. I have increased the confidence of my review based on how you've addressed the concerns of all reviewers. I believe this work constitutes a meaningful contribution and should be accepted.

Reviewer Lp6e referenced some existing works which use "hard" sparse attention in order to increase efficiency. It would be important to include a brief discussion of these works in the related work section of the paper, so as to better situate and distinguish the goals of the present work from that body of research.

Comment

We appreciate your positive feedback. We also agree regarding Reviewer Lp6e's comments. Indeed, rather than efficiency, our work aims to improve the accuracy and approximation capability of the model. We will incorporate a paragraph on sparse attention under the Related Work section to distinguish these. We will discuss earlier works, such as BigBird, LongFormer, Reformer, as well as the recent efficient sparse methods, such as [1,2] mentioned by Reviewer Lp6e.

Review
Rating: 5

This paper presents Selective Self-Attention (SSA), which introduces trainable temperature functions to adapt the contextual sparsity of attention weights. SSA uses a query temperature to adjust the sparsity for each query and its context position, and a value temperature to suppress irrelevant or noisy tokens. Experimental results show that SSA reduces perplexity in language modeling and improves optimization, with only a slight increase in parameters.

Strengths

  1. The paper is well-written and easy to follow.
  2. The motivation behind the method is clear and reasonable.

Weaknesses

  1. The evaluation is limited to perplexity and accuracy in language modeling; conducting experiments on downstream tasks would provide a more comprehensive evaluation.
  2. The perplexity improvement over the original models is modest.
  3. The paper lacks comparisons with other selective or sparse attention methods from the existing literature.
  4. The claim that SSA benefits long text scenarios lacks supporting evidence. There are no studies on performance relative to text length.
  5. There is no detailed description of the datasets used, such as average text length or the number of training instances, which could help explain performance differences.

Minor comment: In the abstract, it should be Q and K rather than V and K.

Questions

  1. How does SSA's performance vary with different parameter numbers? With more parameters, attention weight dilution might be alleviated, potentially weakening SSA's effect.
  2. There is no detailed description of the accuracy mentioned in Table 3. Could you clarify what this accuracy refers to?

Limitations

Please refer to the weaknesses.

Author Response

We thank the reviewer for their time and effort. We hope that our response below addresses their concerns and we would be grateful to respond to further questions during the discussion period.

W1, W2: Conducting experiments on downstream tasks; the perplexity improvement over the original models is modest.

As discussed in the Common Response (Evaluations on additional benchmarks), following the suggestions of Reviewers NV9t and qMbv, we ran new evaluations. Table 2 provides results on 5 new benchmarks: Piqa, Hellaswag, Winogrande, Arc-E, and Arc-C. We also ran additional experiments with Llama3-8B and Pythia-410M, as summarized in Table 4. For all settings, SSA exhibits clear and consistent improvements.

We also ran an experiment on the passkey retrieval task introduced in [Mohtashami and Jaggi NeurIPS’23]. The results are shown in Table 3. This is a synthetic task that measures a model’s ability to retrieve a simple passkey (i.e., a five-digit number) within a large amount of text. SSA leads to substantial improvement (from 56.9% to 74.4%). Intuitively, SSA is able to better solve this task by assigning different token-level temperatures to digits vs words.

W4: The claim that SSA benefits long-text scenarios lacks supporting evidence.

In the table below, our experiment was conducted on sequence lengths of 128 and 4096, where $P_{\max}$ denotes the average maximum attention probability and $H$ the average entropy. For the vanilla Pythia model, increasing the sequence length from 128 to 4096 decreases the maximum attention probability and increases the entropy, implying more dispersed attention. When incorporating SSA, however, the maximum probability at length 4096 is almost the same as at length 128, and the entropy is significantly smaller than for the vanilla model.

| Model | Metric | Seq. len 128 | Seq. len 4096 |
| --- | --- | --- | --- |
| Pythia | $P_{\max}$ | 0.38 | 0.24 |
| Pythia +SSA | $P_{\max}$ | 0.41 | 0.36 |
| Pythia | $H$ | 2.83 | 4.86 |
| Pythia +SSA | $H$ | 2.76 | 3.02 |
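For reference, here is a hypothetical sketch of how such dispersion statistics can be computed from an attention map: the average maximum probability ($P_{\max}$) and the average row entropy ($H$). The function name and the random example are illustrative, not the authors' evaluation code.

```python
# Illustrative sketch: measure attention dispersion via average max probability
# and average row entropy of the softmax map.
import torch

def attention_dispersion(attn: torch.Tensor, eps: float = 1e-9):
    """attn: (..., seq, seq); each row is a softmax distribution over keys."""
    p_max = attn.max(dim=-1).values.mean()                     # average spikiness
    entropy = -(attn * (attn + eps).log()).sum(dim=-1).mean()  # average entropy
    return p_max.item(), entropy.item()

# Example: a random causal attention map for a length-128 sequence.
scores = torch.randn(1, 8, 128, 128)
causal = torch.triu(torch.ones(128, 128, dtype=torch.bool), 1)
attn = torch.softmax(scores.masked_fill(causal, float("-inf")), dim=-1)
print(attention_dispersion(attn))
```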

W5: There is no detailed description of the datasets used.

We have provided all implementation details in Appendix A. We will add a remark in the revised version.

Q1: How does SSA's performance vary with different parameter numbers?

We have conducted further experiments using larger models, specifically Llama3-8B and Pythia-410M. The details are shown in the Common Response (Evaluations on larger LLMs). In short, the approach continues to be beneficial for these models. We agree with the reviewer that more parameters could alleviate the attention map's dilution. However, for this to happen, the embedding dimension of K/Q/V should be large (ideally with square weight matrices). In practice, while token embeddings grow with model size, the number of attention heads also increases, resulting in slower growth of the K/Q/V embedding dimension and a decaying aspect ratio for the K/Q/V weight matrices. For instance, Llama3-8B has 32 attention heads with 4096-dimensional token embeddings, resulting in only 128-dimensional K/Q/V embeddings, whereas the smaller GPT2 has a 768-dimensional token embedding, 12 heads, and 64-dimensional K/Q/V embeddings.
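A quick check of the head-dimension arithmetic above (per-head K/Q/V dimension = token embedding dimension divided by number of heads):

```python
# Per-head K/Q/V dimension = token embedding dim // number of attention heads.
configs = {"Llama3-8B": (4096, 32), "GPT2": (768, 12)}
for name, (d_model, n_heads) in configs.items():
    print(f"{name}: head dim = {d_model // n_heads}")  # 128 and 64
```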

Q2: There is no detailed description of the accuracy mentioned in Table 3. Could you clarify what this accuracy refers to?

LAMBADA introduces a word prediction task where the target word is challenging to guess (for English speakers) when only the sentence containing it is provided, but becomes easier with broader context. Here is an example from [1]:

Context: “Yes, I thought I was going to lose the baby.” “I was scared too,” he stated, sincerity flooding his eyes. “You were?” “Yes, of course. Why do you even ask?” “This baby wasn’t exactly planned for.”

Target sentence: “Do you honestly think that I would want you to have a ____ ?”

Target word: miscarriage.

Accuracy is evaluated based on the model's ability to correctly predict the target word.

[1] Paperno, Denis, et al. "The LAMBADA dataset: Word prediction requiring a broad discourse context." arXiv preprint arXiv:1606.06031 (2016).

Comment

Thanks again for your valuable suggestions. Before the discussion period is over, we wanted to kindly remind you that we have conducted comprehensive evaluations which, we believe, address your concerns. Specifically, our evaluations demonstrate that SSA facilitates noticeable improvements on multiple benchmarks (addressing W1, W2, W4), and we have provided two simple methods to reduce the computational/parameter cost of SSA (only <0.5% or <0.01% extra parameters) (addressing Q1). We appreciate your time.

Comment

The new results address my main concerns. I have raised my score to 5.

Review
Rating: 5

The paper proposes token-aware and position-aware gating for query, key and value vectors in self-attention layers.

Strengths

  1. The position-aware gating is interesting, and, in theory, can alleviate dispersed attention and thus may lead to better length extrapolation.
  2. The inductive bias introduced in the paper can lead to faster convergence when finetuning a small language model with less than 200M parameters.

Weaknesses

  1. The terminology and the positioning of the paper are misleading. The whole paper claims to be doing q/k/v selection, but there is actually no sparse attention involved. The authors should look at papers [1,2] focusing on real selection of the input for self-attention. Also, the paper claims to be doing temperature scaling, but the temperatures can actually be negative, so it is basically gating the query, key, and value vectors, given its element-wise product nature. The authors also use the terminology "soft sparsity", which is self-contradictory and unclear. I suggest better terminology such as "spikiness of attention" instead of "sparsity of softmax".
  2. No actual wall-time efficiency benefits of the proposed method as a result of it essentially being a gating mechanism.
  3. The scale of the experiments is rather small, with LMs of fewer than 200M parameters.
  4. Lack of experimental support for the claim that position-aware gating can alleviate dispersed attention.

[1] Fast Attention Over Long Sequences With Dynamic Sparse Flash Attention. https://openreview.net/forum?id=UINHuKeWUa

[2] Sparse Modular Activation for Efficient Sequence Modeling. https://openreview.net/forum?id=TfbzX6I14i

Questions

  1. For token-aware "temperature scaling", why not use the simpler formulation tanh() instead of 2*sigmoid()-1? Do you have empirical results supporting this formulation?
  2. Do you have some theoretical explanations for why Key vector gating is not used in practice?

Limitations

Yes.

Author Response

W1: The terminology and the positioning of the paper is misleading.

  • Sparsity vs. Spikiness: We appreciate the feedback. We will clarify that our method is not about sparse approximation of the attention map and instead aims to control the "spikiness of attention", as the reviewer suggests. We will also clarify that "spikiness of attention" can be viewed as an "effective sparsity", which can be quantified through the $L_\infty$ norm, the $L_1/L_2$ ratio, or the inverse entropy of the softmax map. This discussion will also better clarify what is meant by "contextual sparsity" throughout the paper and distinguish it from the (hard) sparsity targeted in [1,2].
  • Gating vs. Temperature: As we discuss in our related work section, temperature scaling (TS) is indeed a special case of gating. However, it is also a simple and powerful method in its own right (with use cases in uncertainty quantification, long-tailed data, etc.). Through TS, we are able to improve the performance of the model by scaling all entries of q/k/v using a single scalar (which is why we refer to it as a temperature), and we provide a clear explanation of how this provably improves the model's expressivity. In practice, gating mechanisms are often more complex, as each weight can be gated with a distinct learnable scalar. Based on the reviewer's feedback, we have also run additional experiments to see whether such more complex gating can improve performance. Specifically, we make $\tau$ a learnable $d$-dimensional vector and multiply it elementwise with k/q/v (as described in our Def. 1). As shown in the table below, using a scalar temperature parameter is as good as learning individual scalars for each coordinate (called "SSA gating"). This also suggests that the improvements indeed arise from controlling "contextual sparsity/spikiness" rather than from a more sophisticated gating mechanism.
| Model | Wikitext ppl ↓ | Lambada_std ppl ↓ | Lambada_openai ppl ↓ | Lambada_std acc ↑ | Lambada_openai acc ↑ |
| --- | --- | --- | --- | --- | --- |
| Finetune | | | | | |
| Pythia | 28.781 | 51.620 | 25.173 | 0.373 | 0.497 |
| +SSA (original) | 26.514 | 47.945 | 23.956 | 0.388 | 0.513 |
| +SSA (gating) | 26.509 | 48.067 | 24.460 | 0.383 | 0.515 |
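For concreteness, a small illustrative sketch of the two parameterizations compared above: a single scalar temperature per token versus a learnable $d$-dimensional elementwise gate. The names and the exponential link are assumptions for illustration, not the paper's exact code.

```python
# Sketch contrasting a scalar per-token temperature with d-dimensional gating.
import torch
import torch.nn as nn

d = 64
x = torch.randn(10, d)                       # 10 token embeddings

# Scalar temperature ("SSA original"): one scalar per token scales the whole query.
tau_scalar = nn.Linear(d, 1)
q_scalar = x * torch.exp(tau_scalar(x))      # (10, d), each row scaled uniformly

# Elementwise gating ("SSA gating"): a d-dimensional gate per token.
tau_vector = nn.Linear(d, d)
q_gated = x * torch.exp(tau_vector(x))       # (10, d), each coordinate scaled separately

print(q_scalar.shape, q_gated.shape)
```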

W2: No actual wall-time efficiency benefits.

Our work aims to improve expressivity rather than wall-time efficiency. That said, the benefits of token- and position-aware temperature strongly indicate that token- and position-aware sparsification should similarly help, which can unlock wall-time benefits through sparsity. For instance, Appendix D provides a theoretical connection between position-aware temperature scaling and sparsification, establishing a one-to-one map between the temperature choice and the sparsity level needed to maintain an identical spikiness of the attention map.

W3: The scale of the experiments are rather small.

We appreciate your feedback. Our submission had GPT2, Pythia-160M, and Llama2-7B. In addition to these, we have conducted further experiments using larger and state-of-the-art models, namely, Llama3-8B and Pythia-410M. The details are provided in the Common Response: Evaluations on larger LLMs. SSA noticeably improves the perplexity and accuracy for both of these larger models, suggesting that its benefit persists across different model scales.

W4: Lack of experimental support for the claim that position-aware gating can alleviate dispersed attention.

In the table below, our experiment was conducted on sequence lengths of 128 and 4096, where $P_{\max}$ denotes the average maximum attention probability and $H$ the average entropy. For the vanilla Pythia model, increasing the sequence length from 128 to 4096 decreases the maximum attention probability and increases the entropy, implying more dispersed attention. When incorporating SSA, however, the maximum probability at length 4096 is almost the same as at length 128, and the entropy is significantly smaller than for the vanilla model.

| Model | Metric | Seq. len 128 | Seq. len 4096 |
| --- | --- | --- | --- |
| Pythia | $P_{\max}$ (spikiness) | 0.38 | 0.24 |
| Pythia +SSA | $P_{\max}$ (spikiness) | 0.41 | 0.36 |
| Pythia | $H$ (entropy) | 2.83 | 4.86 |
| Pythia +SSA | $H$ (entropy) | 2.76 | 3.02 |

Q1: Why not use tanh() instead of 2*sigmoid()-1?

Note that $\tanh(x) = 2 \cdot \mathrm{sigmoid}(2x) - 1$, which is essentially the same function we use, differing only in scale. Since $x$ is trainable, the results should be essentially the same.
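A quick numeric check of this identity (a standalone NumPy snippet, added for illustration):

```python
# Verify tanh(x) = 2*sigmoid(2x) - 1, so the two forms differ only by input rescaling.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.linspace(-5, 5, 101)
assert np.allclose(np.tanh(x), 2.0 * sigmoid(2.0 * x) - 1.0)
print("max abs difference:", np.abs(np.tanh(x) - (2.0 * sigmoid(2.0 * x) - 1.0)).max())
```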

Q2: Why key-vector gating is not used.

The intuition from the word embeddings of [Mikolov et al.'13] suggests that the similarity between a (key, query) pair should align with their cosine similarity. That is, $\cos(key_1, query) > \cos(key_2, query)$ should ideally imply that the query attends more to $key_1$ than to $key_2$. Assigning a temperature/gate that scales the query vector does not change this order. However, if we assign distinct scalings to $key_1$ and $key_2$, we can end up with scenarios where the attention scores are flipped, i.e., $\tau_1 \cdot key_1^\top query < \tau_2 \cdot key_2^\top query$. In other words, our intuition is that gating the keys ends up influencing their relative semantic similarities to queries (which could perhaps be better achieved via the attention weights). This is in contrast to query scaling, which helps decouple semantic similarity from contextual sparsity, with the associated theoretical benefits (Sec 4.1.1 and Proposition 1).
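A toy numeric example of this intuition, with made-up vectors: a positive scalar on the query rescales both scores together and preserves which key wins, whereas distinct positive scalings on the keys can flip the order.

```python
# Toy check: query scaling preserves score ranking; distinct key scalings can flip it.
import numpy as np

query = np.array([1.0, 0.0])
key1  = np.array([0.9, 0.1])   # more aligned with the query
key2  = np.array([0.5, 0.5])

s1, s2 = key1 @ query, key2 @ query
print(s1 > s2)                              # True: query attends more to key1

# Query temperature: both scores scale together, order unchanged.
tau_q = 0.3
print((tau_q * s1) > (tau_q * s2))          # still True

# Distinct key temperatures: the order can flip.
tau_1, tau_2 = 0.4, 1.5
print((tau_1 * s1) > (tau_2 * s2))          # False: key2 now wins
```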

We have also provided ablation experiments in Table 4 of the appendix of the submission. These show that SSA on queries and values independently yields clear benefits (in line with our theory), whereas SSA on keys results in performance similar to the baseline vanilla attention layer (in line with the above intuition). We will revise and clarify the discussion of key-gating in Section 3.

We greatly appreciate the reviewer’s detailed and constructive feedback. We hope that this response has addressed your concerns. We would be happy to engage further during the discussion week.

Comment

Dear Reviewer: Thanks again for your comprehensive feedback which will really help with better situating our manuscript. As the discussion period comes to an end, we would be happy to clarify if you have any further questions, especially regarding our response and supporting experiments.

Comment

Thanks for the clarification and the promising additional experimental results. I still think the usage of "selection" is misleading, since there is already a popular work called selective attention [1], which leverages sparse attention for efficient contextualized selection.

[1] Selective Attention for Context-aware Neural Machine Translation (NAACL 2019)

Comment

We acknowledge the reviewer's concern. To convey the contribution and method more clearly, we will revise the name of the method, potentially to "Temperature-scaled Attention". We would be happy to hear if the reviewer has other suggestions.

Comment

I have raised my score from 4 to 5, given that the authors will update the paper as discussed. Since attention is already temperature-scaled, I would suggest other names such as scalar-gated self-attention to emphasize its input-dependent nature.

Comment

Thanks for the suggestion, we agree that scalar-gated would capture the basic idea well.

It's worth noting that there is relatively little theoretical understanding of how gating aids sequence modeling. We believe our theoretical results, well supported empirically, also represent a significant contribution on that front.

Review
Rating: 5

This work introduces Selective Self-Attention (SSA), an additional trainable module for standard transformer self-attention that applies temperature scaling to its three components: queries, keys, and values. By performing both token-based and position-based scaling to modulate the influence of different tokens and positions dynamically, SSA aims to improve control over contextual sparsity and relevance in self-attention mechanisms. The authors empirically demonstrate that SSA can be integrated into existing transformer architectures such as Pythia and LLaMA, adding approximately 5% more trainable parameters. In terms of training efficiency, SSA shows potential for accelerating training, achieving similar performance with fewer tokens.

Strengths

  1. The paper focuses on an interesting area of understanding attention patterns and developing attention mechanisms to enhance contextual sparsity given a particular query.

  2. The paper is well-written and the majority of sections are easy to follow.

Weaknesses

  1. Computational overhead: the proposed selective self-attention modules lead to an approximate 5% increase in model parameters. The authors should address the additional computational cost at inference time, which could be substantial for large-scale deployments. In the case of larger models with billions of parameters, even a 5% increase could significantly impact both training and inference costs.

  2. Baseline model comparisons: The vanilla models evaluated in this work are relatively small in terms of parameter count, potentially making the comparison unfair.

  3. Need for pre-training: The SSA-based LLMs studied in the work are pre-trained from scratch. Can the authors provide a justification for this approach?

  4. Evaluation tasks: The evaluations are mainly on language modeling tasks (Wikitext and Lambada). Given the research motivation, I believe it is important to demonstrate SSA's effectiveness on other NLP tasks where the learned contextual sparsity can address practical issues such as reducing inference latency over longer sequences, KV cache reduction for improving prefill latency, etc.

Questions

Please see weaknesses section.

Limitations

Yes

Author Response

We thank the reviewer for their time and helpful feedback.

W1: Computational overhead

In our Common Response, we describe two strategies that reduce the parameter overhead to below 0.5% while maintaining the benefits of SSA: a weight-sharing approach (0.47% overhead) and a feature-based approach (<0.01% overhead). The former (re)uses the attention weights within the temperature module, whereas the latter uses simple token-level statistics, such as token frequencies in the training corpus. The former stores only 3 vectors (not matrices) per attention head, whereas the latter adds only a constant number of parameters per head. Both approaches also have negligible inference/latency overhead because they require no additional matrix multiplications, only vector dot-products (at the output layer of the temperature module) and elementwise scaling of matrices.
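As a rough, hypothetical sketch of the weight-sharing idea described above (names such as `SharedTemperatureHead` and `u_q` are ours, and the exact functional form is an assumption, not the authors' implementation):

```python
# Sketch: the temperature module reuses the existing query projection as its input
# layer, so only one output vector (u_q) per head is added as new parameters.
import torch
import torch.nn as nn

class SharedTemperatureHead(nn.Module):
    def __init__(self, Wq: nn.Linear, head_dim: int):
        super().__init__()
        self.Wq = Wq                                     # existing attention weights, reused
        self.u_q = nn.Parameter(torch.zeros(head_dim))   # the only new parameters

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.Wq(x)                        # already computed by the attention layer
        tau = torch.exp(q @ self.u_q)         # vector dot-product, no new weight matrix
        return q * tau.unsqueeze(-1)          # elementwise rescaling of the queries

Wq = nn.Linear(64, 64, bias=False)
x = torch.randn(2, 16, 64)
print(SharedTemperatureHead(Wq, 64)(x).shape)   # torch.Size([2, 16, 64])
```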

W2: Baseline model comparisons

Beyond the discussion in our general response about employing multiple strategies to reduce the number of parameters while still achieving significant improvements, we also fine-tuned the model using LoRA, as illustrated in the table below. By controlling the dimension of LoRA, we made sure that LoRA and SSA contributed a comparable number of additional parameters. Although LoRA demonstrated performance similar to directly fine-tuned Pythia, SSA notably outperformed LoRA, creating a distinct gap.

| Model | Wikitext ppl ↓ | Lambada_std ppl ↓ | Lambada_openai ppl ↓ | Lambada_std acc ↑ | Lambada_openai acc ↑ |
| --- | --- | --- | --- | --- | --- |
| Finetune | | | | | |
| Pythia | 28.781 | 51.620 | 25.173 | 0.373 | 0.497 |
| +SSA base (+5.26%) | 26.514 | 47.945 | 23.956 | 0.388 | 0.513 |
| +LoRA (+5.63%) | 27.944 | 49.852 | 24.653 | 0.375 | 0.501 |
| +SSA shared $W_{k/q/v}$ (+0.47%) | 26.681 | 47.996 | 24.102 | 0.383 | 0.504 |
| +SSA feature-based (<0.01%) | 27.048 | 48.906 | 24.814 | 0.379 | 0.499 |

W3: Need for pre-training

Our methodology includes both pre-training and fine-tuning evaluations; SSA outperforms the baseline in both settings. For the fine-tuning evaluation, we start by loading the official pre-trained model and then fine-tune it on the downstream tasks. For the pre-training evaluation, we train the model from scratch on the SlimPajama dataset and subsequently evaluate it on various downstream zero-shot tasks. This approach is widely used for measuring the performance and generalization capabilities of pretrained large language models across diverse tasks [2,4,15].

[2] Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, and Christopher Ré. Simple linear attention language models balance the recall-throughput tradeoff. arXiv preprint arXiv:2402.18668, 2024.

[4] Stella Biderman, Hailey Schoelkopf, Quentin Gregory Anthony, Herbie Bradley, Kyle O’Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, et al. Pythia: A suite for analyzing large language models across training and scaling. In International Conference on Machine Learning, pages 2397–2430. PMLR, 2023.

[15] Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.

W4: Evaluation tasks

Table 3 in the common response also examines the accuracy on the passkey retrieval task as defined in [Mohtashami and Jaggi NeurIPS'23]. This synthetic task measures a model's ability to retrieve a simple passkey (a five-digit number) from a large amount of text. The details are shown in the general response. Our evaluation tasks primarily focus on language modeling benchmarks such as Wikitext and Lambada. However, potential future applications of SSA, such as enhancing linear attention by adjusting its 'spikiness' and improving the interpretability of attention maps, are indeed discussed in the paper. It is important to note that our main focus was to propose SSA and understand its theoretical effectiveness. While envisaging its wider scope in NLP was not central to this work, we certainly acknowledge it as an intriguing direction for future studies.

Comment

Thank you for addressing my concerns. I have increased my score to 5.

Comment

Dear Reviewer,

Thanks for getting back and adjusting your rating. Let us know if you have any outstanding concerns.

We particularly appreciated your "computational overhead" feedback which led us to investigate efficient SSA. In the revised manuscript and GitHub release, we plan to make weight-shared SSA the default method as it retains expressivity while having essentially no architectural overhead.

Thanks

Authors

Comment

Thanks again for your valuable suggestions. Before the discussion period is over, we wanted to kindly remind you that we have conducted comprehensive evaluations which, we believe, address your concerns. For instance, we provided two simple methods to reduce the computational cost of SSA (only <0.5% or <0.01% extra parameters). We appreciate your time.

Author Response

Common Response

We thank all reviewers for their detailed and constructive feedback. To recap, our method Selective Self-Attention (SSA) enhances the approximation capability of attention through a learnable temperature scaling (TS) parameter. We theoretically establish the benefit of choosing temperature in a position- and token-aware fashion and evaluations also corroborate the value of SSA. Most reviewers found the method to be intuitive and interesting. They also raised excellent questions on the practical viability. Here, we address their shared concerns and demonstrate how SSA can indeed benefit state-of-the-art language models with negligible compute and parameter overhead.

Our new experiments focus on three aspects:

  • (1) Improving parameter efficiency of SSA

  • (2) Demonstrating the benefit of SSA on additional benchmarks

  • (3) Demonstrating the benefit of SSA on larger LLMs such as Llama3-8B.

The full tables are in the pdf attachment. We also add partial tables in the text.

(1) Making SSA parameter-efficient: Our base approach adds ~5% parameters to the model. Reviewers NV9t, TP8j, and qMbv motivated us to examine more efficient parameterization to push this overhead below 0.5%. In Table 1, we tried several approaches by utilizing weight sharing and feature-based temperature selection.

| Model | Wikitext ppl ↓ | Lambada_std ppl ↓ | Lambada_openai ppl ↓ | Lambada_std acc ↑ | Lambada_openai acc ↑ |
| --- | --- | --- | --- | --- | --- |
| Fine-tune | | | | | |
| Pythia | 28.781 | 51.620 | 25.173 | 0.373 | 0.497 |
| +SSA base (+5.26%) | 26.514 | 47.945 | 23.956 | 0.388 | 0.513 |
| +SSA shared $W_{k/q/v}$ (+0.47%) | 26.681 | 47.996 | 24.102 | 0.383 | 0.504 |
| +SSA feature-based (<0.01%) | 27.048 | 48.906 | 24.814 | 0.379 | 0.499 |
| Pretrain | | | | | |
| Pythia | 27.943 | 75.487 | 34.406 | 0.279 | 0.351 |
| +SSA base (+5.26%) | 26.912 | 72.891 | 33.126 | 0.294 | 0.360 |
| +SSA shared $W_{k/q/v}$ (+0.47%) | 27.046 | 73.071 | 33.814 | 0.291 | 0.360 |
| +SSA feature-based (<0.01%) | 27.281 | 73.614 | 33.794 | 0.287 | 0.359 |

The parameter-efficient approaches above are:

  • Shared $W_{k/q/v}$: We use the attention weights $W_{k/q/v}$ as the input MLP layer of the temperature module. This reduces the parameter count to below 0.5% (10x fewer) while maintaining on-par performance with full SSA. Here, SSA only adds the output layer of the MLP, a vector with few parameters. Note that this also makes the inference overhead negligible.

  • Feature-based: Rather than learning an MLP, we tune a single scalar based on token-level statistics. Specifically, we set the temperature as a function of the occurrence frequency of the token within the corpus. This is inspired by the logit adjustment strategy of [Menon et al. ICLR'21], which sets the cross-entropy temperature as a function of class frequencies. This approach adds only a constant number of parameters (<0.01% overhead). A hypothetical code sketch of this variant follows below.
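The sketch below illustrates the feature-based variant, where the temperature is a simple learned function of a token's corpus frequency; the functional form, names, and placeholder frequencies are assumptions for illustration, not the actual implementation.

```python
# Sketch: temperature driven by a token-level statistic (corpus frequency),
# with only a constant number of learned scalars.
import torch
import torch.nn as nn

class FrequencyTemperature(nn.Module):
    def __init__(self, token_freq: torch.Tensor):
        super().__init__()
        # token_freq: precomputed occurrence frequency of each vocabulary item.
        self.register_buffer("log_freq", torch.log(token_freq + 1e-12))
        self.a = nn.Parameter(torch.zeros(()))   # constant number of new parameters
        self.b = nn.Parameter(torch.zeros(()))

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # One scalar temperature per token, a function of its (log) corpus frequency.
        return torch.exp(self.a * self.log_freq[token_ids] + self.b)

freq = torch.rand(1000)                  # placeholder corpus frequencies
temp = FrequencyTemperature(freq)
print(temp(torch.randint(0, 1000, (2, 16))).shape)   # torch.Size([2, 16])
```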

Finally, all other parameterizations we tried (in the pdf) also outperform the standard transformer, underscoring the value of selectivity irrespective of its precise implementation.

(2) Evaluations on additional benchmarks: Following the suggestions of Reviewers NV9t and qMbv, we ran new evaluations in Tables 2 and 3. Table 2 provides results on 5 new benchmarks. We find that adding SSA uniformly improves accuracy and that the parameter-efficient SSA methods perform on par with baseline SSA. The improvement is particularly substantial on the Piqa, Hellaswag, Arc-E, and Arc-C benchmarks. Notably, feature-based SSA, which has essentially no overhead, also works well across the board.

| Model | Piqa acc ↑ | Hellaswag acc_norm ↑ | Winogrande acc ↑ | Arc-E acc ↑ | Arc-C acc_norm ↑ |
| --- | --- | --- | --- | --- | --- |
| Pythia | 0.630 | 0.318 | 0.498 | 0.401 | 0.219 |
| +SSA base (+5.26%) | 0.661 | 0.359 | 0.508 | 0.426 | 0.230 |
| +SSA feature-based (<0.01%) | 0.656 | 0.350 | 0.500 | 0.417 | 0.224 |

Table 3 examines the accuracy on the passkey retrieval task as defined in [Mohtashami and Jaggi NeurIPS’23]. This is a synthetic task to measure a model’s ability to retrieve a simple passkey (i.e., a five-digit number) within a large amount of text. SSA leads to substantial improvement (from 56.9% to 74.4%). Intuitively, SSA could better solve this task by assigning different token-level temperatures to digits vs words.

| Pythia | +SSA base (+5.26%) | +SSA feature-based (<0.01%) |
| --- | --- | --- |
| 56.89 | 74.41 | 72.78 |
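For illustration, here is a hypothetical sketch (not the original benchmark code from [Mohtashami and Jaggi NeurIPS'23]) of how a passkey-retrieval prompt can be constructed: a five-digit passkey is buried inside filler text and the model is asked to repeat it.

```python
# Illustrative passkey-retrieval prompt construction (filler text and wording are made up).
import random

def make_passkey_prompt(n_filler: int = 200) -> tuple[str, str]:
    passkey = f"{random.randint(0, 99999):05d}"
    filler = "The grass is green. The sky is blue. The sun is yellow. " * n_filler
    insert_at = random.randint(0, len(filler))
    prompt = (
        filler[:insert_at]
        + f" The pass key is {passkey}. Remember it. "
        + filler[insert_at:]
        + "\nWhat is the pass key? The pass key is"
    )
    return prompt, passkey

prompt, passkey = make_passkey_prompt()
print(len(prompt.split()), "words; target passkey:", passkey)
```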

(3) Evaluations on larger LLMs: Following Reviewers NV9t and Lp6e, we ran additional experiments with Llama3-8B and Pythia 410M as summarized in Table 4. These demonstrate that SSA improves the accuracy of larger and more recent models as well (e.g. compared to our main evals on GPT2, Pythia-160M, Llama2).

ModelWikitext ppl↓Lambada_std ppl↓Lambada_openai ppl↓Lambada_std acc↑Lambada_openai acc↑
Fine tune
Llama3-8b12.41624.00213.9540.4810.684
Llama3-8b+SSA10.98222.67111.0520.4890.690
Pretrain
Pythia-410m22.51669.81432.7810.3210.371
Pythia-410m+SSA21.98069.04131.4580.3310.384

Summary: We hope that these new evaluations address the concerns of the reviewers and further corroborate the value of SSA. We look forward to engaging with the reviewers during the discussion period and thank them again for raising excellent points.

Final Decision

This paper applies temperature scaling to K/Q/V, thereby improving control over sparsity and relevance and yielding efficiency gains. Reviewers found the work interesting and clear, and the results promising and potentially useful. Concerns regarding the experimental setup, the overhead, and some claims not being backed up by experiments have largely been addressed by the authors.