PaperHub
6.5 / 10
Poster · 4 reviewers
Ratings: 6, 6, 8, 6 (min 6, max 8, std 0.9)
Confidence: 4.3
Correctness: 3.3
Contribution: 2.8
Presentation: 3.0
ICLR 2025

Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study

OpenReview · PDF
Submitted: 2024-09-28 · Updated: 2025-03-02
TL;DR

Using the stick-breaking process formulation as a replacement for softmax attention.

Abstract

The self-attention mechanism traditionally relies on the softmax operator, necessitating positional embeddings like RoPE, or position biases to account for token order. But current methods using these still face length generalisation challenges. We investigate an alternative attention mechanism based on the stick-breaking process in larger scale settings. The method works as follows: For each token before the current, we determine a break point, which represents the proportion of the stick, the weight of the attention, to allocate to the current token. We repeat this on the remaining stick, until all tokens are allocated a weight, resulting in a sequence of attention weights. This process naturally incorporates recency bias, which has linguistic motivations for grammar parsing (Shen et al., 2017). We study the implications of replacing the conventional softmax-based attention mechanism with stick-breaking attention. We then discuss implementation of numerically stable stick-breaking attention and adapt Flash Attention to accommodate this mechanism. When used as a drop-in replacement for current softmax+RoPE attention systems, we find that stick-breaking attention performs competitively with current methods on length generalisation and downstream tasks. Stick-breaking also performs well at length generalisation, allowing a model trained with a $2^{11}$ context window to perform well at $2^{14}$ with perplexity improvements.
Keywords
transformer, attention, stick-breaking, softmax, length extrapolation

Reviews and Discussion

Review
Rating: 6

This paper presents a new type of attention mechanism called "stick-breaking attention". It is meant to have a bias for attending to recent positions. Instead of using dot-products between keys and queries as logits and renormalizing them using a softmax to form the attention weights, the dot-products are each passed through the logistic function and then used as the probabilities in a stick-breaking process. In other words, the model first decides how much to attend to the current token, then decides how much of the remainder to allocate to attending to the previous token, and so on. It is a generalization of geometric attention proposed in prior work, where the probabilities at each timestep can be different. The authors discuss details for implementing it efficiently. They test it against standard scaled dot-product attention on a simple synthetic task that advantages recency bias, and on a variety of natural language benchmarks. Stick-breaking attention gets better scores on most tasks.
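To make the mechanism concrete, here is a minimal NumPy sketch of the weighting described above (illustrative reference code, not the paper's Triton kernel; the weights need not sum to one):

```python
import numpy as np

def stick_breaking_weights(logits):
    """Stick-breaking attention weights for a single query.

    `logits` holds the query-key dot products z_{i,j} for the preceding
    positions i, ordered oldest to most recent. Any leftover stick is the
    unallocated remainder, so the weights may sum to less than 1.
    """
    betas = 1.0 / (1.0 + np.exp(-logits))    # break proportions, sigma(z)
    weights = np.zeros_like(betas)
    remaining = 1.0                          # length of stick still unbroken
    for i in reversed(range(len(betas))):    # start from the most recent token
        weights[i] = betas[i] * remaining
        remaining *= 1.0 - betas[i]
    return weights

# With equal logits, more recent positions always get larger weights (recency bias).
print(stick_breaking_weights(np.zeros(4)))   # [0.0625 0.125  0.25   0.5  ]
```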

Strengths

  1. The experiments include a variety of natural language benchmarks.
  2. The authors include a thorough comparison to similar prior work.
  3. The results as presented show that stick-breaking attention outperforms standard attention on a variety of benchmarks.
  4. Their new attention mechanism is reasonably fast, and further optimizations are possible.

Weaknesses

  1. Their implementation relies on a very rough approximation of the exponential function to avoid overflow (Eq 5), so their method is implementing a function that is quite different from what is proposed. I think this can be avoided (see Questions), and I am curious to see if using an exact solution affects the results.
  2. To improve the readability of the paper, a more intuitive explanation of stick-breaking would be useful, for those who are not familiar with the term. For starters, why is it called stick-breaking?
  3. Although the results are encouraging in Section 5.3, it is not clear to me that there was much hyperparameter tuning on the baselines. I would like to see more discussion of this.
  4. I would like to see more discussion of why recency bias helps on the natural language benchmarks tested (see Questions).

Questions

  1. The keys and queries do not have positional encodings, but the values still do, right? Does the lack of PEs on queries and keys reduce the model's expressivity (i.e., are there certain functions that it can no longer implement because of the loss of PEs)? 474: Stick breaking attention doesn't completely get rid of PEs, right?
  2. Although I intuitively understand the motivation for using stick-breaking, can you make a more rigorous case as to why stick-breaking has a recency bias? Is it not true that both softmax and stick-breaking attention can implement arbitrary attention patterns?
  3. Related to the above, did you verify that the model actually attends to more recent tokens? Does this happen in wikitext?
  4. 096: I'm not sure what this discussion is trying to say. Can you explain this more?
  5. 094: I think $z_{i,j} = z_{i,j}$ is a typo.
  6. Instead of using Eq 5, which is an inexact approximation, why not use the identity $\log(1 + \exp x) = c + \log(\exp(-c) + \exp(x - c))$, where $c = \max(0, x)$? This would solve the overflow problem while being an exact solution.
  7. 208: Doesn't this make your method much less parallelizable than standard attention? I think you can actually get the same parallel time complexity using a variant of parallel prefix sum -- have you considered using that?
  8. 304: How much of each sequence is devoted to the initial set of pairs, and repeated pairs?
  9. Table 2, Figure 6: Why do you suppose stick-breaking attention helps on these tasks? It even gets slightly lower perplexity on wikitext, which is not a task that is specifically advantageous for recency-biased models. Can you spend more time discussing what kind of capability each natural language benchmark represents? Why do we expect recency bias to help with them? Do the gains come from recency bias, or is it more about length generalization thanks to the use of relative PEs?
  10. Why does stick-breaking attention do better at length generalization on RULER if the retrieved data is not necessarily recent? I don't see why stick-breaking would be advantageous here.
  11. Why does stick-breaking attention eventually fail on RULER at longer lengths? What is the failure mode?
Comment

Thank you for your interest and extensive list of questions!

Their implementation relies on a very rough approximation of the exponential function...

Instead of using Eq 5, which is an inexact approximation, why not use the identity $\log(1 + \exp(x)) = c + \log(\exp(-c) + \exp(x - c))$, where $c = \max(0, x)$? This would solve the overflow problem while being an exact solution.

To answer both this question and the question of the approximation: in bfloat16, the result of using the approximation is actually equivalent to what a direct $\log(1 + \exp(x))$ implementation would produce. Softplus is simply the identity in the positive regime (> 15) and 0 in the negative regime, and is often viewed as a 'soft' or 'relaxed' version of $\max(0, x)$. So, despite being inexact analytically, the function is exactly the same in practice.

Specifically, consider when $x$ is large and positive: $c = x$, and $c + \log(\exp(-c) + \exp(x - c)) = x + \log(0 + 1) = x$.

With respect to your proposal: the goal of the kernel is speed. The $\max(0, x)$ op followed by two exponentials inside the log would slow things down significantly, as this operation is performed repeatedly inside the kernel for every tile.
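For illustration, here is a minimal NumPy sketch contrasting the reviewer's exact identity with a thresholded softplus in the spirit of the discussion above (the paper's Eq. 5 is not reproduced here, so the thresholded variant is only an assumption about its behaviour):

```python
import numpy as np

def softplus_exact(x):
    # Reviewer's identity: log(1 + exp(x)) = c + log(exp(-c) + exp(x - c)), c = max(0, x).
    c = np.maximum(0.0, x)
    return c + np.log(np.exp(-c) + np.exp(x - c))

def softplus_thresholded(x, threshold=15.0):
    # Hypothetical thresholded variant: identity for large x, plain log1p(exp(x)) otherwise.
    return np.where(x > threshold, x, np.log1p(np.exp(np.minimum(x, threshold))))

x = np.linspace(-30.0, 30.0, 7)
print(np.max(np.abs(softplus_exact(x) - softplus_thresholded(x))))
# ~2e-9 in float64, well below bfloat16 resolution at these magnitudes.
```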

To improve the readability of the paper ... For starters, why is it called stick-breaking?

We’ve referenced the papers in related work on the background of the stick-breaking process, and elected not to go too deeply into the reasons behind the name. We’ll add the following short description to the camera-ready for completeness: "Consider breaking a stick of unit length at some point $\beta_1$, keeping the segment of length $\beta_1$ and leaving the remaining segment of length $1 - \beta_1$. We repeat this process on the segment of length $1 - \beta_1$, breaking it at the proportion $\beta_2$, and so on. The lengths of the segments produced by this process are given by the stick-breaking formula."

Although the results are encouraging in Section 5.3, it is not clear to me that there was much hyperparameter tuning on the baselines. I would like to see more discussion of this.

The hyperparameters for the baseline were tuned with the same process as in the Power Scheduler method (reference), on a 1b parameter model. The stick-breaking model did not go through the same extensive hyperparameter tuning, but was rather a drop-in replacement and trained with the same hyperparameters. We will include a short write-up about the process with a reference to the power scheduler paper.

Comment

The keys and queries do not have positional encodings, but the values still do, right?

RoPE position embeddings are not applied on the values, not just in our setting and baselines, but generally. For stick-breaking, we remove position embeddings from the query and keys, so stick-breaking does not use position embeddings at all.

As for expressivity: Empirically, we are not seeing degradations in the model performance across benchmarks. We are less familiar with the theoretical expressivity, but Reviewer Myhj may have some references (https://arxiv.org/abs/2310.13897) that could hint at the limitations of stick-breaking.

... can you make a more rigorous case as to why stick-breaking has a recency bias?... Related to the above, did you verify that the model actually attends to more recent tokens.... 096: I'm not sure what this discussion is trying to say. Can you explain this more?

To be clear, the claim we are making about recency is in a setting where, assuming all logits are equal, stick-breaking will give higher attention weight to the more recent logit, rather than assigning them the same weight (Figure 1.), and formally in Line 96, which is a formalised discussion of the recency bias of stick-breaking.

To verify 'attending to more recent tokens' in a realistic language modelling setting like wikitext is not straightforward. The previous token would be the most recent, and it would certainly result in poor language modelling capability if only the previous token were attended to by all heads. However, we verified this via a controlled toy setting as in the MQRAR toy problem, where the attention maps are much easier to analyse. There, we do find that in some cases softmax has difficulty attending to the most recent occurrence. However, we believe this capability will improve with higher head dimensions, with the caveat that longer sequences will require yet higher head dimensions.

208: Doesn't this make your method much less parallelizable than standard attention? ...

Optimisation and workload assignment techniques that were used in flash attention could be adapted for stick-breaking in similar ways. Prefix-sum could be used in a CUDA implementation setting, which we are currently looking at. Specifically, you may be concerned that the form of the gradient may mean serialised computation of the cumulative sum. There also could be a block-wise serial version of this that can be parallelised across threads, and then the resulting gradients accumulated at the end accordingly. While we've tried to optimise the kernel as much as possible in this initial version, there are many more avenues for improvement in future versions of the algorithm.
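As a minimal illustration (a NumPy sketch, not the Triton kernel), the log-space weights for a single query can be arranged so that the only sequential dependency is a cumulative sum, which is exactly the structure parallel prefix-sum (scan) algorithms handle:

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -softplus(-x)
    return -(np.maximum(0.0, -x) + np.log1p(np.exp(-np.abs(x))))

def stick_breaking_weights_logspace(logits):
    """Weights for one query, keys ordered oldest to most recent:
    A_i = exp( logsigmoid(z_i) + sum_{k > i} logsigmoid(-z_k) )."""
    log_beta = log_sigmoid(logits)              # log beta_i
    log_one_minus = log_sigmoid(-logits)        # log (1 - beta_i)
    # suffix[i] = sum of log(1 - beta_k) over k > i  (a reversed cumulative sum)
    suffix = np.cumsum(log_one_minus[::-1])[::-1]
    suffix = np.concatenate([suffix[1:], [0.0]])
    return np.exp(log_beta + suffix)

print(stick_breaking_weights_logspace(np.zeros(4)))   # [0.0625 0.125  0.25   0.5  ]
```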

304: How much of each sequence is devoted to the initial set of pairs, and repeated pairs?

The number of initial sets of pairs is exactly the number of kv pairs for that experiment. However, after that, these assignments can repeat and new assignments can be made. The number of repeated pairs after the initial set of kv pairs is random, and so can vary from instance to instance.

Comment

Table 2, Figure 6: Why do you suppose stick-breaking attention helps on these tasks? ..

For language modelling, our hypothesis is also our motivation for the stick-breaking process for attention. The observation from Physics of Language Models (as mentioned in Section 1), in a synthetic context-free grammar setting (a synthetic surrogate for natural language): "attention mechanisms attend to the 'most adjacent non-terminal'" (line 47).

Relatedly, language models without further fine-tuning tend to exhibit failure to extrapolate to longer sequences, suggesting that even though RoPE is a relative position embedding, there are still failures to extrapolate to different relative positions to perform the above task. Baking the ‘most adjacent’ behaviour directly into the attention seemed like a reasonable thing to try to overcome this issue. In a causal left-to-right decoder setting, most adjacent would be the most recent match.

In the case of tasks like MMLU where few-shot examples are given, we hypothesise that in the standard softmax attention, the prior in-context learning examples could become distractors when answering the final question, resulting in lower scores.

At the larger scale, it is difficult to make sense of the attention maps produced by the attention layers, not to mention our model has 40 layers in total. However, we do agree it is an interesting phenomenon and will find ways to investigate this behaviour further.

Why does stick-breaking attention do better at length generalization on RULER if the retrieved data is not necessarily recent? I don't see why stick-breaking would be advantageous here.

Recency is a relative notion. Our results suggest that after the pre-training process, our model has learned to attend to the most recent sequence matching the needle lookup query. However, we hypothesise that the benefit of stick-breaking here may be the combination of the fact that it is attending on the most recent occurrence of the 'needle' being searched for, and that there was no overfitting on position embeddings seen in training.

Why does stick-breaking attention eventually fail on RULER at longer lengths? What is the failure mode?

This is a great question and something we are still investigating. Our current hypothesis is that given a sufficiently long sequence, keys that 'capture' the attention weight can occur randomly, despite not being a match for the needle lookup query. This may be a phenomenon that exists in softmax attention as well, and is what results in failure to extrapolate.

We hope we've sufficiently addressed your questions and concerns, and that you will reconsider your score. Thank you!

Comment

Thank you for your responses.

After reading the other reviewers' comments, I am also concerned about the novelty of the proposed stick-breaking attention mechanism. Based on your response, my understanding now is that it was proposed as part of broader architectural changes in previous work but not studied in isolation. Can you discuss the delta between previous work and the version of stick-breaking attention you've tested in this work in more detail? I think your proposed title and abstract changes are a good step, but the relationship with past work should have been made clearer, and that is a definite minus in the current draft.

  1. Since there are no positional encodings, is it the case that stick-breaking attention transformers can't distinguish the same token type at different positions at all?

  2. Does your baseline softmax attention transformer not use any positional encodings? I think that baselines both with and without positional encodings should be compared against. How do you know the reason length generalization improves is not the removal of the positional encodings?

  3. As for analyzing attention on natural language, I think you could do the following: Use the attention weights to take the weighted average of the number of positions prior to the current position, aggregate this over all positions, heads, and layers for all sequences, and compare these two quantities for stick-breaking and softmax attention. Stick-breaking should be lower. Is that possible?

For language modelling, our hypothesis is also our motivation for the stick-breaking process for attention. The observation from Physics of Language Models (as mentioned in Section 1), in a synthetic context-free grammar setting (a synthetic surrogate for natural language): "attention mechanisms attend to the 'most adjacent non-terminal'" (line 47).

I don't see how the comment from Allen-Zhu & Li (2023) suggests that a recency bias would help for CFGs or natural language. They seem to be simply observing a particular attention pattern for CFGs without implying that that attention pattern is inherently beneficial for the task at hand. Moreover, I think the point of the CFG experiments is to point out how transformers can model non-local hierarchical dependencies.

Comment

Can you discuss the delta between previous work and the version of stick-breaking attention you've tested in this work in more detail?

  • Neural Data Router (Csordas et al. 2021) studied this attention mechanism, with a bidirectional modification, calling it Geometric attention, and studied only encoder Transformers on synthetic tasks. Furthermore, there were other architectural changes made on top of the base Transformer model in that study. While ablation studies were done, the interactions between the other changes (e.g. gating) and the change of attention mechanism were not clear.
  • ModuleFormer (Shen et al. 2023) used stick-breaking attention, with the slight difference of including the current token in the attention weights. The goal of that paper was to study the extensibility of a fully modularised transformer, and less so the effects of stick-breaking attention. While ablation studies were done, the interactions between the other changes (e.g. mutual information loss, mixture-of-attention) and the attention mechanism were not clear.

In our paper, we examine the inductive biases of stick-breaking attention (MQRAR experiment), its length generalisation capability (350M model experiments), and training at large scale (1B & 3B models), showing that the length-generalisation capabilities still hold at scale. The other contribution is the Triton implementation of the kernel.

Since there are no positional encodings, is it the case that stick-breaking attention transformers can't distinguish the same token type at different positions at all?

Stick-breaking attention is not permutation equivariant: re-ordering tokens in the input to an attention layer will result in different outputs for each token, unlike the scenario where you do the same for softmax. This property allows the transformer to assign different weights to the same token type at different positions, which softmax attention can only do with the aid of position embeddings.
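A toy check of this property (a NumPy sketch with a single query and no position embeddings anywhere; purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 5
q = rng.normal(size=d)
K = rng.normal(size=(n, d))     # preceding keys, oldest first
V = rng.normal(size=(n, d))

def softmax_out(q, K, V):
    z = K @ q
    a = np.exp(z - z.max())
    return (a / a.sum()) @ V

def stick_breaking_out(q, K, V):
    z = K @ q
    beta = 1.0 / (1.0 + np.exp(-z))
    w, rem = np.zeros_like(beta), 1.0
    for i in reversed(range(n)):            # break the stick from the most recent key backwards
        w[i] = beta[i] * rem
        rem *= 1.0 - beta[i]
    return w @ V

perm = np.arange(n)[::-1]                   # reverse the order of keys/values
print(np.allclose(softmax_out(q, K, V), softmax_out(q, K[perm], V[perm])))                # True
print(np.allclose(stick_breaking_out(q, K, V), stick_breaking_out(q, K[perm], V[perm])))  # False
```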

Does your baseline softmax attention transformer not use any positional encodings? I think that baselines both with and without positional encodings should be compared against.

This was tested in the 350M SlimPajama training setting, see 'NoPE' (No Position Embedding; Kazemnejad et al. 2023, https://arxiv.org/pdf/2305.19466). We will make it clearer in the text that this is what NoPE tests for.

How do you know the reason length generalization improves is not the removal of the positional encodings?

The original NoPE paper does pose a similar question, and suggests some length generalisation ability, but further testing in other papers (and this one) suggests that while NoPE can extrapolate somewhat, it is not the ideal choice (https://arxiv.org/pdf/2310.04418, https://arxiv.org/pdf/2402.09371).

Finally for clarity, the 1B & 3B models do use Softmax + RoPE, as stated in the text and caption for the table.

As for analyzing attention on natural language, I think you could do the following: Use the attention weights to take the weighted average of the number of positions prior to the current position, aggregate this over all positions, heads, and layers for all sequences, and compare these two quantities for stick-breaking and softmax attention. Stick-breaking should be lower. Is that possible?

This is a good idea, but is fairly involved due to the flash attention implementation of both stick-breaking and softmax. We will try to do this analysis in time for the discussion period.

I don't see how the comment from Allen-Zhu & Li (2023) suggests that a recency bias would help for CFGs or natural language...

In order to answer the original question you posed, we were trying to postulate why stick-breaking was beneficial for some tasks. Our original hypothesis, as stated in the paper and backed up by Allen-Zhu & Li's observation about the nature of context-free grammars, was that the 'most adjacent' inductive bias was a good one to incorporate. Stick-breaking possesses this property in a decoder setting, and also has a recency bias.

However, there is some nuance to how 'recency' is characterised. Consider the following scenario:

relative pos: -8  -7  -6  -5  -4  -3  -2  -1   0 (current token)
logits:        5   5  -2  -2  -2  -2  -2  -2   -

In this case, despite the recency of positions -1 to -6, stick-breaking assigns them relatively low attention weights, due to the sigmoid activation. Meanwhile, the more recent of the two positions with a logit of 5, namely -7, gets a much higher weight than position -8. We could say positions -1 to -6 are not a match (low score), and the most recent of the matches (-8 and -7) is the one assigned the most weight.
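A quick numerical check of this example (a NumPy sketch, illustrative only):

```python
import numpy as np

# Logits from the example above, ordered oldest (-8) to most recent (-1).
logits = np.array([5, 5, -2, -2, -2, -2, -2, -2], dtype=float)
beta = 1.0 / (1.0 + np.exp(-logits))

weights = np.zeros_like(beta)
remaining = 1.0
for i in reversed(range(len(beta))):   # break the stick from the most recent token backwards
    weights[i] = beta[i] * remaining
    remaining *= 1.0 - beta[i]

print(np.round(weights, 3))
# pos:     -8     -7     -6     -5     -4     -3     -2     -1
# weight: 0.003  0.464  0.063  0.072  0.081  0.092  0.105  0.119
# -7, the more recent of the two high-logit positions, captures most of the weight.
```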

Hopefully this makes it more apparent why we think this seems like a natural match for finding the 'most adjacent' non-terminal, and also does so without relying on correlations with relative position embeddings, as a position embedding method would require.

Comment

We trace the weighted average of positions for our wikitext experiments for windows of 2048, and only look at the attention weights for the final timestep (pos=2047). Our reasoning was that this would give an even comparison: comparing weighted positions at an earlier position would bias results to lower relative positions. The positions were numbered from -2047 to -1, relative to the current position 0.

Our preliminary results are as follows:

                       Softmax    Stick-breaking
Weighted avg. position  -428.1            -221.9

These initial results suggest that stick-breaking, even on a language modelling task, does attend to more recent tokens, perhaps using depth to achieve its long-context capabilities.

However, there are several more details that would require further consideration:

  • Stick-breaking weights do not sum to 1, which results in a smaller (in magnitude) weighted average. However, at position 2047, most rows sum to > 0.95. We can account for this in several ways:
    • Leave it as is, and consider the remainder to be attending to pos 0, the current token
    • Assign the remaining attention weights to -2048, which means it attends to values further away, which is not really indicative of what is happening
    • Re-normalise weights to 1.
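To make the weighted-average-position metric concrete, here is a sketch of how it could be computed for one query row (not the exact analysis code; `renormalise=True` corresponds to the last option above, and leaving it off corresponds to the first, since the remainder then contributes position 0):

```python
import numpy as np

def mean_attended_relative_position(attn_row, renormalise=False):
    """Weighted average of relative positions for the final query of a window.

    `attn_row` holds attention weights over positions -T..-1 (oldest first)
    for the query at relative position 0. With stick-breaking the row may
    sum to less than 1.
    """
    T = len(attn_row)
    rel_pos = np.arange(-T, 0)          # -T, ..., -1
    if renormalise:
        attn_row = attn_row / attn_row.sum()
    return float(rel_pos @ attn_row)
```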

We may include these results once we've studied them more thoroughly, but we hope you find them informative. If possible, we'd also hope that you could increase the score you've given our paper. Thank you!

Comment

Thanks for the clarifications. I'm glad the contributions of the paper are better contextualized with past work. Thank you also for the preliminary results on the average attention span. I think the paper makes a nice contribution by studying stick-breaking attention in isolation and demonstrating good improvements on multiple benchmarks. On the other hand, the method itself is not novel, and I think the paper could make a better case as to why exactly it helps so much on natural language, so I would rather keep my score at 6.

Review
Rating: 6

This paper is about stick-breaking or geometric attention, an alternative to softmax that has a built-in bias towards more recent tokens. Given a query position j, each key/value position i < j computes a probability of "yes" or "no", and the attention weight on i is the probability that position i is "yes" and all intervening positions are "no".

Stick-breaking/geometric attention was introduced in previous work. It makes position embeddings unnecessary and can be computed in O(log n) parallel time. This paper presents an improved implementation, and shows that stick-breaking/geometric attention improves performance on the multi-query associative recall task and various benchmarks.

Strengths

This softmax alternative is simple and induces a bias towards attending to recent positions in a very natural way.

The experimental results all look strong.

Weaknesses

As far as I can tell, stick-breaking attention is exactly the same as geometric attention (Csordas et al 2021), and stick-breaking was previously introduced by Shen et al (2023). Both papers are cited in the introduction, and the introduction concludes with an accurate list of the novel contributions of the paper. However,

  • The paper's short title "Stick-Breaking Attention" may give the impression that this is the first paper about stick-breaking attention.
  • The abstract does not mention previous work; on the contrary, it says "We propose an alternative attention mechanism."
  • The introduction, probably inadvertently, could be mis-read as saying that geometric attention only has one parameter ("Geometric attention, named after the Geometric distribution, which only has one parameter").

Eq. (2): Putting the remainder of the attention onto position j itself does not seem like the right choice. Probability $(1-\beta_{i,j})$ is the probability of pushing the attention to the left of position i, so $\prod_i (1-\beta_{i,j})$ is the probability of pushing the attention all the way to the left. So it's not surprising that this turned out not to work well. Letting the attention weights sum to less than one (or equivalently, putting the remainder of the attention onto the zero vector) seems like the most sensible thing to do.
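For reference, the weights as described in this thread can be written as $A_{i,j} = \beta_{i,j} \prod_{k=i+1}^{j-1} (1-\beta_{k,j})$ with $\beta_{i,j} = \sigma(z_{i,j})$ and $i < j$; the "with remainder" variant additionally sets $A_{j,j} = \prod_{k<j} (1-\beta_{k,j})$, whereas dropping the remainder simply leaves $\sum_{i<j} A_{i,j} \le 1$. (This is a reconstruction from the discussion; the notation and indexing may differ slightly from the paper's Eq. 1-2.)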

The Flash Attention-like implementation of stick-breaking attention is 20% slower than Flash Attention.

Questions

What is the exact relationship of stick-breaking attention to geometric attention (Csordas et al 2021) and the stick-breaking attention of Shen et al (2023)?

Eq (1):

  • I'd suggest not using \cdot, as it might be misinterpreted as an inner product in this context.
  • The first summation only goes up to i-1, so a position cannot attend to itself, which is different from usual future masking. I didn't see this decision discussed; what's the reason for it?

Line 94: typo, should be $z_{i,j} = z_{i',j}$?

Eq (4): I don't know if it matters, but you could use log1p(exp(z)) instead of log(1+exp(z)).

It may be interesting to note that while softmax with a temperature factor ($\alpha = \mathrm{softmax}(z/T)$) as T -> 0 approaches average hard attention, which is used in many theoretical studies of transformers (e.g., https://arxiv.org/abs/1901.03429, https://arxiv.org/abs/2106.16213), stick-breaking attention ($\beta = \sigma(z/T)$) as T -> 0 approaches strictly future-masked rightmost hard attention, which is the kind of attention studied by https://arxiv.org/abs/2310.13897.
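A small numerical illustration of this limit (a NumPy sketch, not from the paper):

```python
import numpy as np

def stick_breaking_weights(logits):
    beta = 1.0 / (1.0 + np.exp(-logits))
    w, rem = np.zeros_like(beta), 1.0
    for i in reversed(range(len(beta))):    # most recent position first
        w[i] = beta[i] * rem
        rem *= 1.0 - beta[i]
    return w

z = np.array([2.0, -1.0, 3.0, 0.5, -0.5])   # arbitrary logits, oldest first
for T in (1.0, 0.1, 0.01):
    print(T, np.round(stick_breaking_weights(z / T), 3))
# As T -> 0, the weight concentrates on the rightmost position with a positive logit
# (index 3 here), i.e. strictly future-masked rightmost hard attention.
```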

Ethics Review Details

As far as I can tell, stick-breaking attention is exactly the same as geometric attention (Csordas et al 2021), and stick-breaking was previously introduced by Shen et al (2023). Both papers are cited in the introduction, and the introduction concludes with an accurate list of the novel contributions of the paper.

However,

  • The paper's short title "Stick-Breaking Attention" may give the impression that this is the first paper about stick-breaking attention.
  • The abstract does not mention previous work; on the contrary, it says "We propose an alternative attention mechanism."
  • The introduction, probably inadvertently, could be mis-read as saying that geometric attention only has one parameter ("Geometric attention, named after the Geometric distribution, which only has one parameter").

EDITED: The authors have resolved the above concerns.

Comment

Another suggestion would be to use $\mathrm{logsigmoid}(z_{i,j}) + \sum_{k=i+1}^{j-1} \mathrm{logsigmoid}(-z_{k,j})$, which should have a more numerically stable gradient assuming logsigmoid is implemented correctly.

Comment

We've added the derivation of the log-space computation in the appendix.

Comment

The paper's short title "Stick-Breaking Attention" may give the impression that this is the first paper about stick-breaking attention. The abstract does not mention previous work; on the contrary, it says "We propose an alternative attention mechanism."

We will amend the abstract and title appropriately for the camera-ready (See general response).

The introduction, probably inadvertently, could be mis-read as saying that geometric attention only has one parameter ("Geometric attention, named after the Geometric distribution, which only has one parameter").

We will modify this statement to the following: "Consider a parameter $p \in (0,1)$ for the probability of success of a trial. The geometric distribution then gives the probability that $k$ trials are needed for the first success: $(1 - p)^{k-1}p$. But in stick-breaking attention and Geometric attention, each $p$ is assigned ..."

The original phrasing was not written with the intent to mislead, and we apologise for implying that Geometric attention only has one parameter. We take the flagging of concerns for ethics review seriously, and we hope this addresses all of your concerns on the ethics front.

Eq. (2): Putting the remainder of the attention onto position j itself does not seem like the right choice. Probability $(1-\beta_{i,j})$ is the probability of pushing the attention to the left of position i, so $\prod_i (1-\beta_{i,j})$ is the probability of pushing the attention all the way to the left. So it's not surprising that this turned out not to work well. Letting the attention weights sum to less than one (or equivalently, putting the remainder of the attention onto the zero vector) seems like the most sensible thing to do.

We agree, and your hypothesis is a reasonable explanation for the poorer performance. We intend to remove the remainder mechanism for simplicity in the camera-ready version of the paper and train and re-evaluate the experiments without the remainder mechanism. We have already done so for the MQRAR experiment with identical results.

What is the exact relationship of stick-breaking attention to geometric attention (Csordas et al 2021) and the stick-breaking attention of Shen et al (2023)?

The main difference between those two proposals is that Shen et al. 2023 only considers unidirectional decoder-only attention, while Csordas et al. 2021 includes a mechanism for accounting for both “future” and “past” attention, in an encoder setup.

Eq (4): I don't know if it matters, but you could use log1p(exp(z)) instead of log(1+exp(z)).

An interesting detail about the kernel implementation is that we use exp2 and log2 in the attention weight calculations, which we believe have hardware-level optimisations. We multiply and divide by $\ln 2$ in order to convert between base-2 and base-$e$ logarithms.

To your point, though: 1) log1p would be an $e$-based logarithm, and 2) Triton does not expose this API (at least in the version we’re using). This is why we used this particular implementation of softplus. Ultimately, one of the reasons we chose it is that a direct exp(z) will give NaNs at larger values of z.
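For reference, the base-2 rewrite can be sanity-checked with a small NumPy sketch (illustrative only; it omits the thresholding a kernel needs to keep the exp2 from overflowing):

```python
import numpy as np

LOG2_E = 1.0 / np.log(2.0)   # log2(e)

def softplus_base2(x):
    # softplus(x) = log(1 + exp(x)) = ln(2) * log2(1 + 2^(x * log2(e)))
    return np.log(2.0) * np.log2(1.0 + np.exp2(x * LOG2_E))

x = np.array([-5.0, 0.0, 5.0])
print(np.allclose(softplus_base2(x), np.log1p(np.exp(x))))   # True
```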

It may be interesting to note that while softmax with a temperature factor...

Thank you for making us aware of this paper! We did wonder if there were any expressibility differences in the extremes with stick-breaking vs. softmax, and from reading the introduction, it does seem that this paper might give some insight to that.

Another suggestion would be to use logsigmoid...

Logsigmoid is a variant of softplus: $\mathrm{logsigmoid}(x) = -\mathrm{softplus}(-x)$. So we are already manipulating certain properties of logsigmoid in order to minimise usage of exp and log for speedups.

We hope we have addressed all of your concerns and hope you can reconsider the ethics review flag.

Review
Rating: 8

The paper introduces an alternative attention mechanism that incorporates a form of positional information, avoiding the need for PE techniques such as RoPE. The method is based on the stick-breaking process: for each token in the context (the keys), a break point is assigned that determines how much of the remaining stick is allocated to that token. This means that if two tokens have equal logits, the one closer to the current token will receive more of the stick, reflecting positional order. Carrying this process over the entire context results in a sequence of attention weights in lieu of traditional attention weights. The paper argues that this process naturally incorporates a recency bias without PEs. Experiments show that the new method is better at length generalisation in perplexity compared to attention+RoPE.

Strengths

  • The paper presents a novel addition to the zoo of "attention alternatives" by leveraging the stick-breaking process, which performs both "attention" and positional embedding intrinsically. The mechanism has a recency bias, meaning a token can prefer to allocate all its "energy" to a few recent tokens, but it can also skip over them and attend only to far-away tokens.
  • The paper also includes a detailed implementation in Triton for flash-attention-style efficiency and speed-up optimisation, which is hugely appreciated, especially now that efficiency and scalability are so highly valued for massive LLM training.
  • The paper conducts a diverse range of experiments, from throughput (only 20% slower than flash attention) to perplexity and language modelling tasks (MMLU, ARC-c, HellaSwag...), all of which show promising results.
  • The main driver of this work is solving the length generalisation challenge of LLMs, where it shows good results.

Weaknesses

  • The method could be explained more clearly with diagrams, and the formulation should be defined more thoroughly.

Questions

  • What is the $\sigma$ function? It is never defined. Is it the sigmoid?
Comment

Thank you for your kind comments!

What is the $\sigma$ function? It is never defined. Is it the sigmoid?

Yes! We've overlooked adding this detail, and we will add it to the final paper.

Method can be explained more clearly with diagrams, formulation should be defined more thoroughly.

We will revise the implementation section, add more details to the formulation of stick-breaking attention, and improve Figure 1 to better explain the method.

Review
Rating: 6

This paper does a thorough evaluation of stick-breaking attention (also known as geometric attention) for synthetic test tasks and natural language tasks, showing its effectiveness at capturing a preference for locality, as is natural and correct for natural language tasks. The paper examines a fast, numerically stable implementation of stick-breaking attention in Triton, which enables larger scale experiments.

Strengths

  • The paper provides thorough and useful experimental results on the performance of stick-breaking attention, providing good exploration of its effectiveness for artificial and natural language tasks, for length generalization, etc. The experimentation seems well thought through and well done.
  • The paper is generally clear and easy to read.
  • The paper is very honest about what it contributes and what it uses from prior work.
  • The paper provides useful and new empirical results on different forms of attention on different tasks.
  • The paper examines building an efficient, numerically stable Triton implementation of stick-breaking attention.
  • It's good to include comparisons to Gemma2-2B and Qwen1.5-4B so that people can easily see that the results are decent, even though the paper's own results are the apples-to-apples comparisons.

Weaknesses

  • The paper lacks originality in machine learning ideas. Stick-breaking attention has been previously explored by Yikang Shen (in multiple papers) and especially by Csordas et al. (2021), the latter under the name "Geometric attention". "Stick-breaking attention" is a better name for the model used, but the model is exactly the same as in these prior works, limiting the originality of this paper. The value is mainly in the more extensive experimentation, including showing performance on larger scale, standard natural language benchmarks.
  • This paper lacks somewhat in significance because of this. It does have some significance, since it is really good to see that these ideas really do give gains on standard NL tasks like ARC, Hellaswag or RACE, but the basic correctness of the idea had already been established.
  • The differences between many models in Table 2 are fairly small and nothing is said about the detailed validation of the results. Are these from single runs rather than averages from 3-5 runs with different random initialization? How much variance would there be here, how confident can we be that a result of 63.4 is better than 63.1 for Winogrande on Softmax vs. SB w/o remainder correction, for example?
  • How to produce an efficient, numerically stable Triton kernel basically follows the methods of FlashAttention and standard good practice (using log(1 + exp(x)), etc.)

Questions

  • The differences between many models in Table 2 are fairly small and nothing is said about the detailed validation of the results. Are these from single runs rather than averages from 3-5 runs with different random initialization? How much variance would there be here, how confident can we be that a result of 63.4 is better than 63.1 for Winogrande on Softmax vs. SB w/o remainder correction, for example?
  • I took the text around lines 101-107 as suggesting that things should work better doing stick-breaking with remainder, but the actual results go the other way. It would be good to provide more understanding of why this harms rather than helping.
  • For the paper, you should explain lines 300-301. This isn't a question. I figured it out, but the text of the paper should explain MQRAR better for people who haven't seen it.
Comment

We'd like to thank you for reviewing our work in detail.

The paper lacks originality in machine learning ideas. Stick-breaking attention has been previously explored by Yikang Shen (in multiple papers) and especially by Csordas et al. (2021), the latter under the name "Geometric attention". "Stick-breaking attention" is a better name for the model used, but the model is exactly the same as in these prior works, limiting the originality of this paper. The value is mainly in the more extensive experimentation, including showing performance on larger scale, standard natural language benchmarks.

This paper lacks somewhat in significance because of this. It does have some significance, since it is really good to see that these ideas really do give gains on standard NL tasks like ARC, Hellaswag or RACE, but the basic correctness of the idea had already been established.

We understand your concerns about the claims made in the abstract and implied in the title. We will address this by modifying the abstract and the title (see the general response).

The differences between many models in Table 2 are fairly small and nothing is said about the detailed validation of the results. Are these from single runs rather than averages from 3-5 runs with different random initialization? How much variance would there be here, how confident can we be that a result of 63.4 is better than 63.1 for Winogrande on Softmax vs. SB w/o remainder correction, for example?

These are legitimate concerns for benchmark results in many larger models, but having 3-5 runs for 1B and 3B models requires a lot of compute resources. For comparisons of variance across each benchmark, we could include further results from other open source models, to have a better sense of how much each benchmark fluctuates.

Nonetheless, we understand the skepticism with regard to these benchmark results being significantly better, but the results could at minimum be considered equivalent to existing models / standard attention.

We would also like to point out that on top of the regular evaluation benchmarks, we also report length extrapolation results on the RULER suite of needle-in-a-haystack style benchmarks. The results there are fairly significant across all the different variants of NIAH tasks. We will include the results of the full suite of RULER tasks in the supplementary materials.

This suggests stick-breaking affords length extrapolation out-of-the-box without degradation in performance compared to softmax.

How to produce an efficient, numerically stable Triton kernel basically follows the methods of FlashAttention and standard good practice (using $\log(1 + \exp(x))$, etc.)

While the Triton kernel implementation takes inspiration from FlashAttention, the specific formulation of the log-space computation of stick-breaking (as stated in Eq. 3) is not as straightforward as suggested. The main considerations were reducing log and exp computations for speed, and we believe further optimisations could still be made.

For the paper, you should explain lines 300-301. This isn't a question. I figured it out, but the text of the paper should explain MQRAR better for people who haven't seen it.

Thank you for your feedback on this. We will modify the description of this experiment with the following paragraph: "Multi-query Associative Recall (MQAR) is a task that requires a sequence model to ingest a sequence of key-value pairs: A 1 B 2 ..., where the letters are keys and the numbers are values. Later in the sequence, the model has to associate the token A with 1 and predict it, for example. In our variant, Multi-query Repeated Associative Recall (MQRAR), when a key is retrieved it is also reassigned a new value: A 1 ... A 2 A 3. In this case, on the second A, the model has to predict 1, the value previously assigned to A, while at the third A it should retrieve 2, and so on. We provide a full example below..."
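A toy generator in the spirit of this description (purely illustrative; the exact data format, vocabulary, and loss masking used in the paper may differ):

```python
import random

def make_mqrar_sequence(keys="ABCD", n_values=8, n_repeats=6, seed=0):
    """Generate one MQRAR-style sequence.

    Each key is first bound to a value; afterwards, every repeated occurrence
    of a key should be answered with its previously bound value (the target),
    while the token written after it rebinds the key to a new value.
    """
    rng = random.Random(seed)
    tokens, targets, binding = [], [], {}
    # Initial set of key-value pairs.
    for k in keys:
        v = rng.randrange(n_values)
        tokens += [k, v]
        targets += [None, None]            # nothing to recall yet
        binding[k] = v
    # Repeated pairs: recall the previous value, then rebind.
    for _ in range(n_repeats):
        k = rng.choice(keys)
        v_new = rng.randrange(n_values)
        tokens += [k, v_new]
        targets += [binding[k], None]      # target at the key token is the old value
        binding[k] = v_new
    return tokens, targets

tokens, targets = make_mqrar_sequence()
print(tokens)
print(targets)
```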

We hope our response addresses your concerns satisfactorily, and you might increase your score. Thank you!

Comment

Are there concerns and questions you have about our paper? We hope you find our responses satisfactory and can reconsider your score. Thank you!

Comment

Thanks for your responses, which I agree with and accept. I'll move my score up to 6, but still believe that the questions of novelty and significance stop it being an 8.

Comment

We understand. Thank you!

Comment

Reviewers wgvD and MyhJ have cited issues of novelty in our work. While we referenced some of the prior work on stick-breaking as an attention mechanism, the reviewers are concerned that the title and the abstract did not appropriately reflect the position of our paper.

Prior work introduced stick-breaking attention, but also introduced several other concepts and mechanisms, without focusing on the contribution of stick-breaking as the attention mechanism and the inductive biases it introduces. Here, we solely replace the attention layer with stick-breaking attention, train, and evaluate the model on downstream tasks, making it a more focused evaluation of just the stick-breaking mechanism. The results of the larger-scale experiments, with improvements over standard attention + RoPE, specifically on length extrapolation, are a contribution in themselves. Further, we’d also like to emphasise that the approach to scaling the method up is non-trivial. The Triton kernel implementation is important as it makes the method amenable to large-scale training, whereas stick-breaking in its naively implemented form has not been adopted, in part due to its low throughput and high memory usage.

In sum, we think that stick-breaking attention has positive properties that make it better than softmax attention, and our paper contributes the evidence for that, along with an implementation for ease of use.

We agree with the reviewer's concerns, and propose the following changes to the title and abstract to address their issues:

Proposed title change:

Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study

Proposed modification to abstract:

We investigate an alternative attention mechanism based on the stick-breaking process in larger scale settings. The method works as follows: For each token before the current, we determine a break point $\beta_{i,j}$, which represents the proportion of the remaining stick to allocate to the current token...

Comment

These seem like good appropriate changes to me.

AC Meta-Review

The paper explores a new type of attention mechanism for synthetic test tasks and natural language tasks. The proposed attention inherently incorporates positional embeddings, eliminating the need for additional positional encoding techniques like RoPE. The method assigns a breakpoint to each token, prioritizing tokens closer to the current one, thereby embedding a natural recency bias without explicit positional encodings. An efficient, numerically stable implementation in Triton enables large-scale experiments.

The authors followed the reviewers' suggestion to make modifications to the abstract and title (Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study). The AC appreciates the authors' effort in the rebuttal to dispel the reviewers' doubts.

Additional Comments on Reviewer Discussion

The authors openly accepted the reviewers' comments regarding novelty and the acknowledgment and citation of prior work, and have made corrections accordingly.

Final Decision

Accept (Poster)