PaperHub
7.3
/10
Poster4 位审稿人
最低4最高5标准差0.5
4
4
5
5
2.8
置信度
创新性3.3
质量3.3
清晰度2.3
重要性2.5
NeurIPS 2025

On the Robustness of Transformers against Context Hijacking for Linear Classification

OpenReviewPDF
提交: 2025-05-01更新: 2025-10-29

摘要

关键词
in-context learningtransformersrobustnessdeep learning theorylearning theory

评审与讨论

审稿意见
4

The paper theoretically explores context hijacking within an in-context linear classification problem, utilizing linear transformers. It designs context tokens as factually correct query-answer pairs where queries are similar to the final query but have opposite labels. The authors develop a theoretical analysis on the robustness of linear transformers as a function of model depth, training context lengths, and the number of hijacking context tokens. The key result is a formal equivalence between L-layer transformers and L-step gradient descent from general initialization. The authors derive optimal learning rates and show that deeper models perform finer-grained updates, yielding exponentially stronger robustness against hijacking attacks. A derived error bound explains why models like GPT-4 are more robust than GPT-2. Empirical results confirm the theory. This work offers a principled foundation linking model depth to robustness.

优缺点分析

Strengths

  1. This is the first rigorous framework analyzing context hijacking in transformers, reframing the problem via a clean equivalence to multi-step gradient descent.

  2. Demonstrates that deeper models achieve exponential robustness gains, offering concrete design guidance beyond vague depth heuristics.

  3. Theoretical predictions align precisely with experiments, explaining real-world trends like GPT-4’s superior robustness over GPT-2.

Weaknesses

  1. Analysis is limited to linear classification under isotropic Gaussian assumptions, which limits applicability to realistic NLP tasks.

  2. Focuses solely on linear attention-only transformers, omitting nonlinearities, layer norm, and MLPs essential in practice.

  3. The analysis relies on the assumption that transformers implement a gradient descent optimizer, and the authors note that more complicated meta-optimization algorithms could lead to different theoretical results. While this is a common approach in theoretical transformer works, it might simplify the true underlying mechanisms.

问题

1 The analysis assumes repeated hijacking examples that are linearly projected to lie on the decision boundary. Can the framework extend to more realistic hijacking attacks where context examples are diverse, semantically richer, or adversarially optimized? 2 The current analysis excludes MLPs, normalization, and nonlinearities. Do you expect your robustness characterization to hold for standard Transformers (e.g., GPT-style) if they approximately perform gradient-like steps? 3 Given the reliance on Gaussian-distributed features, how well do you expect your theoretical findings to transfer to real-world NLP tasks with structured and non-isotropic embeddings?

局限性

The theoretical framework is restricted to linear classification under isotropic Gaussian assumptions, which may not generalize well to real-world embeddings or tasks involving long-range semantics.

The analysis is limited to linear attention-only transformers without MLP, layer norm, or softmax, which reduces relevance to practical architectures.

最终评判理由

the responses addressed my questions well. I prefer to keep my current rating.

格式问题

None

作者回复

We appreciate your constructive questions and suggestions! We address them as follows:

Q1: Simplified linear tasks, Gaussian data, and transformers' architecture.

A1:

We acknowledge that there exist some gaps between the linear classification tasks, Gaussian data, and linear attention-only transformers adopted in our theoretical framework, and real-world complex models and complex tasks. However, we also point out that due to technical challenges, studying linear tasks with linear transformers on Gaussian data is a standard theoretical setting, and has been widely considered in many existing theoretical works regarding transformers, particularly in the in-context learning literature [1-8].

In addition, we would like to emphasize that the purpose of this work is to provide insights towards 'context-hijacking', a theoretically unexplored phenomenon. Proper mathematical simplification allows us to derive clear and precise quantitative characterizations of transformers' in-context learning capacities relative to their depths, rather than being biased by technical challenges. Specifically, our clear mathematical characterization effectively reveals that: when learning from context, shallow transformers are more 'aggressive', whereas deep transformers are more 'conservative'. While the exact mathematical form may not directly transfer to more complicated situations, our findings regarding depth-dependent learning strategies offer valuable theoretical insights that can guide future investigations into non-linear tasks, more complex data structures, or other aspects of multi-layer transformers' ICL. In fact, these theoretical conclusions are further empirically supported by our experimental results on more complex scenarios (see A2), guaranteeing the generalization capacity of our theoretical framework.

Q2: Limits applicability to realistic NLP tasks: can the framework extend to more realistic hijacking attacks where context examples are diverse, semantically richer or adversarially optimized?

A2: Theoretical extension.

Our theoretical framework can be extended to more general cases. For example, we can extend it to an out-of-distribution case, where the hijacking context examples and query follow different distributions (so it would be possible that context and query have opposite labels with similar embeddings). Notice that our conclusion that models' depth determines the optimal learning rates still holds (Theorem 3.3). It is natural when there exist significant distinctions between two distributions, deep transformers' 'conservative' learning strategies make them robust to hijacking.

Experiments on practical LLM architectures and real-world data distributions.

We strongly agree with your opinion and we realize the importance of generalizing our results to more realistic architectures. We conduct extensive supplementary experiments on LLMs of varying depths across diverse topic tasks to demonstrate the validity of our conclusions in real-world contexts. Our dataset is constructed as follows.

Dataset construction and settings.

  • First, we will design a fact retrieval problem. It is a direct question, such as "Of all the sports, Maria Sharapova is most professional in which one? The answer is". We want the model to predict the next token is "tennis".

  • Next, we will choose a topic that is factually correct. For the example above, we can choose the topic that "Maria Sharapova is not a professional in rugby".

  • Finally, we will add factually correct context prefixes of varying lengths before the question. Each sentence of this context prefix will describe the topic that has been determined from a different perspective and with different words. That is, paraphrase the hijacking context instead of repeating them, which makes context examples more diverse and semantically richer. In our example, these sentences could be "Maria Sharapova's tennis skills do not translate well to rugby", "The physical demands of rugby are not ones with which Maria Sharapova is familiar", etc. The model is then asked the same question. If the model predicts "tennis", then it is correct. If the model predicts "rugby", we call this "label flipping".

We design four datasets with different topics, including city, country, sports and language. And the number of samples in each dataset ranges from hundreds to thousands. We divide the context hijacking into eight different levels according to the length of the context prefix, from level 1 to level 8, which means the context has 10 to 80 sentences. We filter out questions that are too difficult based on the model's own capabilities and the difficulty of the questions, which means that the model could always correctly answer direct questions without hijacking context. We conduct experiments on Qwen2.5 base models of different sizes (depths) and corresponding instruction fine-tuned versions. The tables below show the label flipping rates of different models for different levels of context hijacking.

Experiment results

  1. City.
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.13200.20980.25980.30960.34870.40200.43370.4906
Qwen2.5-1.5B (28 Layers)0.02870.05890.10050.14710.17950.19500.21610.2411
Qwen2.5-3B (36 Layers)0.02300.04370.05870.07800.09220.10060.10670.1164
  1. Country.
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.40940.57690.59360.59770.61730.66030.65000.6692
Qwen2.5-1.5B (28 Layers)0.31250.39060.50000.51670.55470.54690.57810.5781
Qwen2.5-3B (36 Layers)0.17080.18680.18930.21980.22530.25270.25000.2555
  1. Sports:
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.74890.75830.76210.77380.77080.77880.79140.8006
Qwen2.5-1.5B (28 Layers)0.52550.58420.58560.58910.59870.59100.60200.6136
Qwen2.5-3B (36 Layers)0.11030.11770.13020.13360.13810.14030.13980.1484
  1. Language:
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.35090.50770.56660.61070.63770.64410.63830.6399
Qwen2.5-1.5B (28 Layers)0.07320.12790.17370.22840.24850.29220.28530.3013
Qwen2.5-3B (36 Layers)0.04350.07220.07400.09220.10430.10240.11160.1090

We can find that in practical LLMs, longer hijacking context will significantly increase the label flipping rate (leading to lower accuracy), while increasing the model depth can well alleviate this problem. The experiment results are consistent with our theoretical conclusions, indicating that our theoretical results can be generalized to deeper and larger LLMs in practice. And we are able to transfer our theoretical findings to practical NLP tasks with structured and non-isotropic embeddings. Additionally, we find that instruction fine-tuning (due to character limitations, please refer to the rebuttal to reviewer APEh) can improve the model's robustness to context hijacking in most cases, but the effect is not significant, which provides new insights for future work, such as adversarial optimization. This suggests that our work can provide insights into real-world problems.

We will provide all the experiment results and detailed experimental settings in our revised paper. We believe that our experiments on real-world tasks and architectures fully validate the applicability of our conclusions and hope that these results address your concerns.

Q3: Does the current conclusion still hold for more complicated architectures if they approximately perform gradient-like steps? How about more complicated meta- optimization algorithms?

A3:

Our robustness characterization remains valid for any scenario where transformers exhibit gradient-like learning behavior. The foundation of our framework lies in establishing depth-dependent optimal learning rates (Theorem 3.3). Crucially, this core conclusion persists under your assumptions that they approximately perform gradient-like steps. Therefore, our theoretical predictions about depth-mediated robustness continue to hold effectively.

Recent studies have proposed new perspectives on in-context learning, such as transformers' ability to approximate second-order optimization methods like Newton's method [9]. While the exact formulation may differ, we argue that our core theoretical insight in Theorem 3.3, the depth-dependent learning behavior, remains valid across different optimization paradigms. It is natural as long as the connection between the depth and the steps of optimizations still holds.

[1] Von Oswald, et al. Transformers learn in-context by gradient descent. ICML 2023.

[2] Ahn, et al. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023.

[3] Zhang, et al. In-context learning of a linear transformer block: benefits of the mlp component and one-step gd initialization. NeurIPS 2024.

[4] Frei, et al. Trained transformer classifiers generalize and exhibit benign overfitting in-context. ICLR 2025.

[5] Zhang, et al. Trained transformers learn linear models in-context. JMLR 2024.

[6] Mahankali, et al. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. ICLR 2024.

[7] Chen, et al. How transformers utilize multi-head attention in in-context learning? a case study on sparse linear regression. NeurIPS 2024.

[8] Huang, et al. Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought. ICLR 2025.

[9] Fu, et al. Transformers Learn to Achieve Second-Order Convergence Rates for In-Context Linear Regression. NeurIPS 2024.

评论

I thank the authors for the rebuttal. I have read through the rebuttal. I think it makes sense overall. I will keep my score of 4.

评论

Thanks for your positive comment! Please do not hesitate to let us know if you have further concerns.

审稿意见
4

This paper investigates the phenomenon of context hijacking in Transformer-based large language models, where the presence of factually correct context can mislead a model into incorrect predictions. The authors formalize context hijacking as a theoretical linear classification problem and analyze it through a multi-layer linear transformer framework, drawing a connection between transformer layers and multi-step gradient descent optimization. A key theoretical result is that deeper transformers inherently offer greater robustness to context hijacking because they allow for finer-grained optimization, reducing the negative impact of misleading contexts. The paper rigorously derives optimal conditions for transformer initialization and learning rates as functions of model depth and context length, confirming these insights empirically through synthetic experiments. The primary contributions include a comprehensive theoretical framework for analyzing context hijacking, mathematical proofs of the relationship between transformer depth and robustness, and empirical validations that reinforce the theoretical findings.

优缺点分析

Strenghts

  • Novel Theoretical Framework: The paper introduces the first comprehensive theoretical analysis of context hijacking in transformers using a linear classification setup, which is both novel and practically relevant.

  • Solid Mathematical Grounding: The authors rigorously connect transformer layers with multi-step gradient descent. This explicit connection provides deep insights into the internal mechanics of transformers.

Weaknesses

  • Linear Simplification: While the linear classification setup enables clear theoretical insights, it limits the generalization of the results to complex, real-world nonlinear scenarios and tasks.

  • Assumption on Data Distribution: The analysis heavily relies on simplified assumptions such as isotropic Gaussian distributions for inputs and uniform distributions for certain parameters. Real-world contexts typically deviate from these idealized conditions.

  • Limited Empirical Scope: The numerical experiments primarily use synthetic data and simplified transformer structures (linear transformers). It remains unclear whether these findings generalize directly to practical, large-scale models.

问题

It is a well-established fact that linear transformers perform less well on in-context tasks [1]. So, why was this architecture chosen over the more traditional quadratic transformer?

[1] Aksenov et al, Linear Transformers with Learnable Kernel Functions are Better In-Context Models

局限性

yes

最终评判理由

Thank you to the authors for the clarifications and additional experiments. Some of my initial misunderstandings stemmed from not being deeply familiar with the area of mechanistic interpretability, and the authors' response has helped resolve these. The new experiments also shed light on the practical relevance of the problem studied in the paper. That said, I believe the paper could benefit from an even broader set of experiments, particularly those exploring more diverse architectures and practical scenarios. Therefore, I will raise my score to 4.

格式问题

作者回复

We thank the reviewer for detailed comments and suggestions.

Q1: Simplified assumptions on linear tasks and Gaussian data distributions.

A1:

We would like to first point out that studying linear tasks with Gaussian inputs is a standard setting and has been widely considered in many existing theoretical works regarding transformers, particularly in the in-context learning literature [1-8].

In addition, we emphasize that our work is the first to rigorously propose a theoretical framework - through formulating it into linear classification tasks - to investigate 'context hijacking', a theoretically unexplored phenomenon even under standard settings. With this simple yet intuitive mathematical formulation, we establish clear and precise quantitative characterizations in terms of the ICL capacities of transformers with respect to their depths. Specifically, our clear mathematical characterization effectively reveals that: when learning from context, shallow transformers are more 'aggressive', whereas deep transformers are more 'conservative'. While the exact mathematical form may not directly transfer to more complicated situations, our findings regarding depth-dependent learning strategies offer valuable theoretical insights that can guide future investigations into non-linear tasks, more complex data structures, or other aspects of multi-layer transformers' ICL. These theoretical conclusions are further supported by our experimental results on more complex scenarios (see A2).

Additionally, we emphasize that a key technical strength of our framework lies in its inherent ability to accommodate distributional shifts between: (1) context data and query data; (2) training data and test data. This characteristic makes our framework readily adaptable for studying both adversarial attacks and out-of-distribution performance. In summary, we believe this represents a good starting point.

Q2: Limited Empirical Scope: It is unclear whether these findings generalize directly to practical, large-scale models.

A2: Experiments on practical LLM architectures and real-world data distributions.

Although we experiment with nonlinear settings and the GPT-2 architecture (Appendix I.1 and I.3), encouraged by your suggestions, we realize the importance of generalizing our results to more realistic architectures. We conduct extensive supplementary experiments on LLMs of varying depths across diverse topic tasks to demonstrate the validity of our conclusions in real-world contexts. Our dataset is constructed as follows.

Dataset construction and settings.

  • First, we will design a fact retrieval problem. It is a direct question, such as "Of all the sports, Maria Sharapova is most professional in which one? The answer is". We want the model to predict the next token is "tennis".

  • Next, we will choose a topic that is factually correct. For the example above, we can choose the topic that "Maria Sharapova is not a professional in rugby".

  • Finally, we will add factually correct context prefixes of varying lengths before the question. Each sentence of this context prefix will describe the topic that has been determined from a different perspective and with different words. That is, paraphrase the hijacking context instead of repeating them. In our example, these sentences could be "Maria Sharapova's tennis skills do not translate well to rugby", "The physical demands of rugby are not ones with which Maria Sharapova is familiar", etc. The model is then asked the same question. If the model predicts "tennis", then it is correct. If the model predicts "rugby", we call this "label flipping".

We design four datasets with different topics, including city, country, sports and language. And the number of samples in each dataset ranges from hundreds to thousands. We divide the context hijacking into eight different levels according to the length of the context prefix, from level 1 to level 8, which means the context has 10 to 80 sentences. We filter out questions that are too difficult based on the model's own capabilities and the difficulty of the questions, which means that the model could always correctly answer direct questions without hijacking context. We conduct experiments on Qwen2.5 base models of different sizes (depths) and corresponding instruction fine-tuned versions. The tables below show the label flipping rates of different models for different levels of context hijacking.

Experiment results

  1. City.
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.13200.20980.25980.30960.34870.40200.43370.4906
Qwen2.5-1.5B (28 Layers)0.02870.05890.10050.14710.17950.19500.21610.2411
Qwen2.5-3B (36 Layers)0.02300.04370.05870.07800.09220.10060.10670.1164
  1. Country.
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.40940.57690.59360.59770.61730.66030.65000.6692
Qwen2.5-1.5B (28 Layers)0.31250.39060.50000.51670.55470.54690.57810.5781
Qwen2.5-3B (36 Layers)0.17080.18680.18930.21980.22530.25270.25000.2555
  1. Sports:
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.74890.75830.76210.77380.77080.77880.79140.8006
Qwen2.5-1.5B (28 Layers)0.52550.58420.58560.58910.59870.59100.60200.6136
Qwen2.5-3B (36 Layers)0.11030.11770.13020.13360.13810.14030.13980.1484
  1. Language:
ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.35090.50770.56660.61070.63770.64410.63830.6399
Qwen2.5-1.5B (28 Layers)0.07320.12790.17370.22840.24850.29220.28530.3013
Qwen2.5-3B (36 Layers)0.04350.07220.07400.09220.10430.10240.11160.1090

We can find that in practical LLMs, longer hijacking context will significantly increase the label flipping rate (leading to lower accuracy), while increasing the model depth can well alleviate this problem. The experiment results are consistent with our theoretical conclusions, indicating that our theoretical results can be generalized to deeper and larger LLMs in practice. Additionally, we find that instruction fine-tuning (due to character limitations, please refer to the rebuttal to reviewer APEh) can improve the model's robustness to context hijacking in most cases, but the effect is not significant, which provides new insights for future work, such as adversarial optimization. This suggests that our work can provide insights into real-world problems.

We will provide all the experiment results and detailed experimental settings in our revised paper. We believe that our experiments on real-world tasks and architectures fully validate the applicability of our conclusions and hope that these results address your concerns.

Q3: It is a well-established fact that linear transformers perform less well on in-context tasks. So, why was this architecture chosen over the more traditional quadratic transformer?

A3: Conceptual inconsistency

We want to clarify that the linear transformer [9] mentioned in the question is conceptually inconsistent with the linear transformer in our paper, or that the problems they aim to solve have different focuses.

As defined in the reference paper (Section 3.1), a linear transformer is a model that uses a kernel function to approximate the standard attention computation. The "linear" here means that the computational complexity of the approximation is linear in the sequence length. Its purpose is to improve computational efficiency, although this comes at the expense of some in-context learning performance.

In our definition (Section 2.2), a linear transformer is a model that removes the activation function from the standard transformer. The "linear" here means that the forward propagation of the model does not involve nonlinear calculations. This is a very common setup in transformer theory research [1-8, 10], aiming to create a mathematically tractable model that can be used to explain real-world problems from a theoretical perspective. We will discuss in detail the differences and connections between our work and the reference paper in the revision.

[1] Von Oswald, et al. Transformers learn in-context by gradient descent. ICML 2023.

[2] Ahn, et al. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023.

[3] Zhang, et al. In-context learning of a linear transformer block: benefits of the mlp component and one-step gd initialization. NeurIPS 2024.

[4] Frei, et al. Trained transformer classifiers generalize and exhibit benign overfitting in-context. ICLR 2025.

[5] Zhang, et al. Trained transformers learn linear models in-context. JMLR 2024.

[6] Mahankali, et al. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. ICLR 2024.

[7] Chen, et al. How transformers utilize multi-head attention in in-context learning? a case study on sparse linear regression. NeurIPS 2024.

[8] Huang, et al. Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought. ICLR 2025.

[9] Aksenov, et al, Linear Transformers with Learnable Kernel Functions are Better In-Context Models. ACL 2024.

[10] Ren, et al. Learning and Transferring Sparse Contextual Bigrams with Linear Transformers. NeurIPS 2024.

评论

Thank you to the authors for the clarifications and the additional experiments. Most of my concerns have been addressed, and I will raise my score to 4.

评论

We're glad to see that our response address your concerns and appreciate your decision to increase you rating! Thank you once again for your efforts in helping improve the quality of our work.

审稿意见
5

I really enjoyed this paper. The paper discusses context hijacking. The terminology works as follows. If a context is adversarial, i.e., the prompt contains false information and adding the context to a prompt hurts the LM's ability to do the task, then that is, in a sense, expected. We do want LLMs to react to context. On the other hand, sometimes adding a correct statement to the context also derails the model. This is called context hijacking.

As I understand it, the idea is to come up with a theoretical model of context-hijkacing. There is a simple model that samples contexts and hijacked contexts. A number of propositions are proposed about building transformers that model his.

优缺点分析

This paper has a great premise and is well polished. I enjoyed reading it. The weaknesses are that it's quite difficult to get through. I am sure this is my fault, as I don't have the right background. The whole idea seems to be to cast the context hijacking problem as one of optimization. I was unable to follow most of the technical details.

问题

  • Is there any way to make the paper more accessible to readers who don't have a strong background in optimization? The answer to this could be no
  • Can the techniques be used for other, similar problems?

局限性

None

最终评判理由

I think this is a great paper and I would love to see it in the conference. I was happy with the authors' response.

格式问题

None

作者回复

Thank you for your recognition of our work and your constructive questions!

Q1: More accessible way to understand?

A1: Thanks for your suggestion. We will try our best to elaborate on our framework with detailed explanations, critical take-home messages, and without any mathematical formulas.

  1. Background of in-context learning (ICL). Empirical evidence shows that modern LLMs can adapt to diverse tasks solely based on the provided context data, without requiring any task-specific fine-tuning. For example, the same model can solve mathematical equations when given a math problem, yet seamlessly switch to translating text when presented with a sentence in another language, all without explicit retraining for either task. This is the so-called in-context learning capacity of transformers.

  2. A theoretical perspective to understand ICL: single attention layer conducts one-step gradient descent with context data. One theoretical explanation regarding this phenomenon is that the forward propagation of each layer in transformers is equivalent to conducting one-step gradient descent on context data ([1, 2, 3, 4] and Proposition 3.1 of our paper). You can imagine that every time a pre-trained transformer performs inference for a given query token, it internally constructs an 'implicit linear model' to generate the output. The initial parameters of this model are fixed and determined solely by the pre-trained transformer, independent of the context or query. However, while the initialization remains the same for all 'implicit linear models' corresponding to different tasks/context data/input tokens, when each layer of transformers forward propagates the input query, the parameters of the 'implicit linear model' will be updated one step based on context data. Intuitively speaking, every forward pass through one attention-layer is effectively a learning step, where the model "adapts" to the context data.

  3. Depth-dependent optimal learning rates of transformer models. Since the model’s depth determines the number of 'implicit gradient descent' steps available, the optimal learning rate achieving the training loss minimization would vary according to different depths of transformers. Specifically, the optimal learning rate of 'implicit gradient descent' on context data scales inversely with the depth of the transformer model (Theorem 3.3). This is intuitive: A shallow transformer must use a larger learning rate to ensure the 'implicit linear model' learns sufficiently from context data within limited steps. In contrast, a deep transformer, with more iterations available, opts for a smaller learning rate to enable fine-grained adjustments when near the optimum. In short, when learning from context data, shallow models are more 'aggressive', while deep models are more 'conservative'.

  4. Context-hijacking and robustness of transformers in terms of depth. For now, we can explain the phenomenon of 'context hijacking' and our findings of robustness against hijacking. Take Figure 1 as an example. We can conjecture that the context example 'Rafael Nadal is not good at playing basketball', and the query 'Rafael Nadal's best sport is' share similar embeddings in the representation space. When the model learns from contexts through 'implicit gradient descent', the embedding similarities cause the model to conflate their semantic meanings, leading to the incorrect answer 'basketball'. In some sense, context-hijacking occurs when the 'context data' and 'query' belong to different distributions: while their embeddings appear similar, their actual meanings are completely opposite. (In classification terms, this would be cases where inputs (xx) are similar but labels (yy) are contradictory.) Because the context data and query follow different underlying patterns, learning from such contexts becomes ineffective or even detrimental to model performance. This explains why deep transformers show greater robustness against hijacking: shallow models' aggressive learning approach makes them highly susceptible to interference from hijacking examples, whereas deep models employ a more cautious, incremental learning approach that allows for progressive refinement and better resistance to corrupting influences.

We hope that such an outline can help you better understand the logic of our theoretical framework, and we will incorporate these discussions into our revisions.

Q2: Can the techniques be used for other, similar problems?

A2: Yes, our framework can be extended to several similar problems.

  1. Since we employ the 2\ell_2 loss for classification tasks, a natural extension would be to adapt it to regression tasks.

  2. A key technical advantage of our framework is its inherent capacity to handle distributional shifts between: (1). context data and query data; (2). training data and test data. This makes our framework particularly suitable for investigating:

  • Adversarial attacks: A prevalent scenario where context data and query data exhibit distributional discrepancies. Specifically, the context data may contain factually incorrect information.

  • Out-of-distribution (O.O.D.) cases: A common situation where training context data and test context data follow different distributions. This typically occurs when applying a pretrained model to specific downstream tasks.

[1] Von Oswald, et al. Transformers learn in-context by gradient descent. ICML 2023.

[2] Ahn, et al. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023.

[3] Zhang, et al. In-context learning of a linear transformer block: benefits of the mlp component and one-step gd initialization. NeurIPS 2024.

[4] Bai, et al. Transformers as statisticians: Provable in-context learning with in-context algorithm selection. NeurIPS 2023.

评论

We are truly grateful for your recognition of our work! As the discussion phase draws to a close, we wish to ensure that our responses have addressed your concerns. It would be our pleasure to provide any further clarification you may desire. We greatly appreciate your valuable time and feedback.

评论

Thanks a lot! This helps clear up some of my concerns. I am going to spend some more time with the paper tonight and I'll get back to you with some more questions should I have any. I hope the paper gets in!

审稿意见
5

This paper provides a theoretical framework and empirical support for analyzing the robustness of linear transformers against “context hijacking,” a phenomenon where factually correct yet misleading context examples can disrupt model predictions. The authors model in-context learning in linear transformers as multi-step gradient descent and derive how robustness scales with model depth, context length, and hijack token count. Their theory demonstrates that deeper models exhibit stronger robustness, which aligns with experimental results.

优缺点分析

Strength

  1. The work identifies and rigorously formalizes the problem of context hijacking, a subtle yet impactful threat to LLM robustness.

  2. The paper offers a clear and elegant connection between transformer depth and robustness via optimal multi-step gradient descent modeling.

  3. Numerical simulations back the theoretical predictions and explain why deeper transformers resist hijacking more effectively.

Weakness

  1. Limited generalization to nonlinear settings: Though nonlinear experiments are mentioned (Appendix I.3), their depth is limited (only showing a depth analysis). It remains unclear how well the conclusions carry over to architectures used in practice.

  2. Realism of assumptions: The study is limited to linear transformers and synthetic data. While this allows for tractable theory, it diverges from practical LLM architectures and real-world data distributions, potentially reducing the direct applicability of findings. More empirical experiments are encouraged to address this. Even using synthetic data, try to paraphrase the hijacking contexts instead of repeating them.

问题

N/A

局限性

yes

最终评判理由

I increased the score because the authors connect the paper better to real-world situations

格式问题

N/A

作者回复

Thank you for your informative feedback! We address your comments as follows:

Q1: It remains unclear how well the conclusions carry over to architectures used in practice. More empirical experiments are encouraged to address this.

A1:

Common settings for theoretical work.

From a theoretical perspective, it is a common method to abstract real-world problems into linear problems for analysis, because linear problems have sufficient representation power, supported by many previous works [1-9].

Experiments on practical LLM architectures and real-world data distributions.

Then, based on your suggestions, we realize the importance of verifying our conclusions in a more realistic architecture. We conduct extensive supplementary experiments on LLMs of varying depths across diverse topic tasks to demonstrate the validity of our conclusions in real-world contexts. Our dataset is constructed as follows.

Dataset construction and settings.

  • First, we will design a fact retrieval problem. It is a direct question, such as "Of all the sports, Maria Sharapova is most professional in which one? The answer is". We want the model to predict the next token is "tennis".

  • Next, we will choose a topic that is factually correct. For the example above, we can choose the topic that "Maria Sharapova is not a professional in rugby".

  • Finally, we will add factually correct context prefixes of varying lengths before the question. Each sentence of this context prefix will describe the topic that has been determined from a different perspective and with different words. That is, paraphrase the hijacking context instead of repeating them. In our example, these sentences could be "Rugby is not a sport that Maria Sharapova is adept at playing", "Maria Sharapova's tennis skills do not translate well to rugby", "The physical demands of rugby are not ones with which Maria Sharapova is familiar", etc. The model is then asked the same question. If the model predicts "tennis", then it is correct. If the model predicts "rugby", we call this "label flipping".

We design four datasets with different topics, including city, country, sports and language. And the number of samples in each dataset ranges from hundreds to thousands. We divide the context hijacking into eight different levels according to the length of the context prefix, from level 1 to level 8, which means the context has 10 to 80 sentences. We filter out questions that are too difficult based on the model's own capabilities and the difficulty of the questions, which means that the model could always correctly answer direct questions without hijacking context. We conduct experiments on Qwen2.5 base models of different sizes (depths) and corresponding instruction fine-tuned versions. The tables below show the label flipping rates of different models for different levels of context hijacking.

Experiment results

  1. City.

Qwen2.5-Base:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.13200.20980.25980.30960.34870.40200.43370.4906
Qwen2.5-1.5B (28 Layers)0.02870.05890.10050.14710.17950.19500.21610.2411
Qwen2.5-3B (36 Layers)0.02300.04370.05870.07800.09220.10060.10670.1164

Qwen2.5-Instruct:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B-Instruct (24 Layers)0.11810.16930.21710.26840.29820.34730.38360.4068
Qwen2.5-1.5B-Instruct (28 Layers)0.02210.03190.05430.08230.10250.09590.09830.0958
Qwen2.5-3B-Instruct (36 Layers)0.01800.03370.04530.05260.05700.06900.06990.0765
  1. Country.

Qwen2.5-Base:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.40940.57690.59360.59770.61730.66030.65000.6692
Qwen2.5-1.5B (28 Layers)0.31250.39060.50000.51670.55470.54690.57810.5781
Qwen2.5-3B (36 Layers)0.17080.18680.18930.21980.22530.25270.25000.2555

Qwen2.5-Instruct:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B-Instruct (24 Layers)0.66020.74110.76240.77160.78840.78840.80620.7827
Qwen2.5-1.5B-Instruct (28 Layers)0.48050.51320.51710.51880.52000.52860.52320.5329
Qwen2.5-3B-Instruct (36 Layers)0.12140.19470.26270.26810.27450.28070.28090.2872
  1. Sports:

Qwen2.5-Base:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.74890.75830.76210.77380.77080.77880.79140.8006
Qwen2.5-1.5B (28 Layers)0.52550.58420.58560.58910.59870.59100.60200.6136
Qwen2.5-3B (36 Layers)0.11030.11770.13020.13360.13810.14030.13980.1484

Qwen2.5-Instruct:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B-Instruct (24 Layers)0.54190.55770.59680.61220.65540.70350.71310.7115
Qwen2.5-1.5B-Instruct (28 Layers)0.25000.29410.36270.42650.44120.47960.48530.4900
Qwen2.5-3B-Instruct (36 Layers)0.14930.18060.19670.20000.21030.21610.22200.2418
  1. Language:

Qwen2.5-Base:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B (24 Layers)0.35090.50770.56660.61070.63770.64410.63830.6399
Qwen2.5-1.5B (28 Layers)0.07320.12790.17370.22840.24850.29220.28530.3013
Qwen2.5-3B (36 Layers)0.04350.07220.07400.09220.10430.10240.11160.1090

Qwen2.5-Instruct:

ModelLevel 1Level 2Level 3Level 4Level 5Level 6Level 7Level 8
Qwen2.5-0.5B-Instruct (24 Layers)0.30490.45100.52050.56110.57120.56390.56620.5733
Qwen2.5-1.5B-Instruct (28 Layers)0.06580.11420.12280.12870.13850.14100.13970.1452
Qwen2.5-3B-Instruct (36 Layers)0.01680.02880.03130.03530.03960.03830.04090.0375

We can find that in practical LLMs, longer hijacking context will significantly increase the label flipping rate (leading to lower accuracy), while increasing the model depth can well alleviate this problem. The experiment results are consistent with our theoretical conclusions, indicating that our theoretical results can be generalized to deeper and larger LLMs in practice. Additionally, we find that instruction fine-tuning can improve the model's robustness to context hijacking in most cases, but the effect is not significant, which provides new insights for future work, such as adversarial optimization. This suggests that our work can provide insights into real-world problems.

We will provide all the experiment results and detailed experimental settings in our revised paper. We believe that our experiments on real-world tasks and architectures fully validate the applicability of our conclusions and hope that these results address your concerns.

[1] Von Oswald, et al. Transformers learn in-context by gradient descent. ICML 2023.

[2] Ahn, et al. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023.

[3] Zhang, et al. In-context learning of a linear transformer block: benefits of the mlp component and one-step gd initialization. NeurIPS 2024.

[4] Frei, et al. Trained transformer classifiers generalize and exhibit benign overfitting in-context. ICLR 2025.

[5] Zhang, et al. Trained transformers learn linear models in-context. JMLR 2024.

[6] Mahankali, et al. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. ICLR 2024.

[7] Chen, et al. How transformers utilize multi-head attention in in-context learning? a case study on sparse linear regression. NeurIPS 2024.

[8] Ren, et al. Learning and Transferring Sparse Contextual Bigrams with Linear Transformers. NeurIPS 2024.

[9] Huang, et al. Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought. ICLR 2025.

评论

Thanks for the detailed explanations and valuable new results. I think the new results are valuable to connect the conclusions in this paper to practice. I have increased my score to accept to reflect the quality improvement in the paper.

评论

We are very grateful for your recognition of our work! Your constructive suggestions are very valuable to our work.

评论

Dear reviewers, Please make sure you acknowledge that you have read the authors' rebuttal, which is mandatory this year. Also, please discuss with the authors if you have questions or comments on their rebuttal.

Thanks,

AC

最终决定

This paper theoretically investigates the behavior of Transformers under the so-called context hijacking for linear classification. The authors first establish the equivalence between the L-layer transformers and L-step gradient descent, then derive the optimal learning rates and initialization of the equivalent L-step gradient descent problem, which are shown to be a function of the number of layers and the length of training context. Based this result, the authors analyze the robustness of Transformers under the impact of the number of layers, training context length and test context length. This theorical framework explains a common observation of Transformers -- a deep transformer model is more robust to context hijacking than a shallow transformer model since a deep transformer model employs a more cautious, fine-grained learning approach that allows for progressive refinement and better resistance to corrupting influences. Overall, this is an interesting work. Some major concerns raised by the reviewers include

  1. lack of experiments on practical LLM architectures and real-world data
  2. simplified linear tasks, data distribution and linear transformer architecture.

These concerns have been successfully cleared by the authors in their rebuttal with added new experimental results on Qwen2.5 model and elaborated explanation on the necessity and usefulness of using a simplified mathematical treatment. The authors need to improvement the submission with new results and promised modifications in the revised version.