Does learning the right latent variables necessarily improve in-context learning?
We test whether explicit latent-variable inference leads to better performance in in-context learning and identify the benefits and shortcomings of such methods.
Abstract
Reviews and Discussion
This paper investigates whether explicitly inferring task-relevant latent variables improves in-context learning (ICL) in Transformer models. The authors introduce an explicit model that enforces structured inference of latent variables and compare it with a standard implicit model that learns ICL end-to-end without explicit latent variable inference.
The authors find that explicitly inferring task latents does not improve generalization. The explicit model effectively extracts latent variables but struggles to utilize them for robust predictions. This suggests that the challenge in ICL is not just learning task latents but also correctly leveraging them in downstream prediction.
Questions for Authors
Thank you for your insightful work on the role of latent variable inference in in-context learning. I have a few questions regarding your methodology and findings:
Recovering True Latents in Implicit Models
- In your experiments, you show that the explicit model successfully extracts the correct task latents. However, have you considered whether a standard Transformer (implicit model) might also recover similar latent representations, perhaps in the final-layer embeddings or through a linear probe analysis? This would help clarify whether the explicit model is truly unique in this regard.

Dependence of Latents on Query Input
- The explicit model prevents the query x_query from directly attending to the context. However, does this design fully isolate the necessity of learning the correct latent variables? If the true latent is inherently dependent on x_query (e.g., in tasks where the latent determines the relationship between query and context), wouldn't this architectural constraint potentially hinder optimal inference?

Failure to Leverage Learned Latents for Prediction
- One of your key findings is that even when the explicit model correctly infers task latents, it does not generalize better. Could this be an issue with optimization dynamics (e.g., the choice of Adam or gradient descent) rather than an inherent failure to use the latents? Given that the last layer acts as a simple classifier, why does training another classifier on the same inferred latents yield better results? This suggests a mismatch between latent inference and prediction that is not fully explained.
Claims and Evidence
I am unsure whether the chosen setting and the assumption that a single latent variable can effectively summarize all context information is the best way to explain how Transformers behave. In particular, enforcing this assumption by preventing x_query from attending to other context tokens may impose an artificial constraint that does not align with how Transformers naturally process information.
Methods and Evaluation Criteria
See previous section.
Theoretical Claims
No theory in this paper.
Experimental Design and Analysis
To some extent, see the "Claims and Evidence" section.
Supplementary Material
I just skimmed through it, looking at the details of the dataset and model in Sections C and B.
Relation to Existing Literature
NA
Essential References Not Discussed
NA
Other Strengths and Weaknesses
The paper explores a potentially interesting question regarding the role of explicit latent variable inference in in-context learning. However, in its current form, the study feels incomplete in several ways.
First, the assumption that a single latent variable can fully summarize context information independently of x_query seems unnatural. Transformers process information dynamically, and restricting direct attention from x_query to other context tokens may artificially constrain the model's behavior rather than isolating a meaningful causal mechanism. This setup may not fully capture how in-context learning operates in standard architectures.
Additionally, while the negative result is valuable, the study does not sufficiently rule out alternative explanations. For example:
- Implicit models may also recover task-relevant latents in their final-layer representations, but this is not systematically tested.
- The failure of the explicit model to leverage inferred latents for prediction is not fully explained—is this due to architectural constraints, optimization dynamics, or a fundamental limitation of explicit inference?
- The bottleneck structure itself may limit information flow, rather than revealing a true failure of latent inference to improve generalization.
These concerns are further detailed in the "Questions for Authors" section. Addressing them could significantly strengthen the clarity and impact of the paper.
Other Comments or Suggestions
I find it difficult to pinpoint a clear takeaway from this paper. While it presents several interesting observations, the main message remains unclear.
We thank the reviewer for providing valuable and constructive feedback.
Single latent variable can fully summarize context information independently of x_query seems unnatural
It is important to note that the dimensionality of the latent variable in explicit models is kept sufficiently large to encode the true latents for all the tasks considered (except GP). The assumption that the latent variable is independent of x_query is standard across the field:
- Deep learning models learn the parameters of a neural network from training data to generalize to unseen test observations. This is the setup for vision, natural language, and tabular tasks, where the learned parameters are kept independent of the test inputs.
- The same assumption is also used in representation and unsupervised learning.
- Multiple works on ICL also claim this to be the underlying mechanism (Eq. 1 in [5,6])
Since most machine learning approaches train a model independent of the test set, they follow this independence assumption. In contrast, test-time training methods can be seen as not following this assumption [4].
restricting direct attention from x_{query} to other context tokens may artificially constrain the model’s behavior rather than isolating a meaningful causal mechanism
In all the tasks (except GP Regression), the correct causal mechanism is to infer the true latents solely from the context and then leverage them for prediction. This indeed constrains the hypothesis class, and theoretical results suggest that reducing the hypothesis class leads to tighter bounds as long as the true solution is realizable [1], which it is, since the independence assumption is satisfied in the true data-generating distribution, i.e., the true latents depend only on the context and not on x_{query}.
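For concreteness, the assumed structure can be written as the following factorization (a sketch in generic notation of our choosing, with z the task latent and C the context set, rather than the paper's exact symbols):

```latex
% Explicit (latent-bottleneck) view of the predictive distribution:
% the latent z is inferred from the context C alone, and the query is
% answered using only (x_query, z).
\[
p(y_{\text{query}} \mid x_{\text{query}}, C)
  = \int p(y_{\text{query}} \mid x_{\text{query}}, z)\, p(z \mid C)\, dz,
  \qquad z \perp x_{\text{query}} \mid C .
\]
```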
Implicit models may also recover task-relevant latents
Indeed they can, but the search space is larger. Given the number of tokens and layers, it is not clear from which combination of representations one should decode the latents. Different places to probe can provide vastly different inferences, whereas in the explicit model there is only one natural place to probe, as it is precisely trained to infer the latents. We refer the reviewer to Figures 4(b) and 9 where we investigate this with probes in both implicit and explicit models and show that the counterfactual performance is worse in implicit models, showing that the latents are either not sufficiently encoded or simply difficult to find.
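As an illustration of the kind of probing referred to above, the sketch below fits a linear probe from intermediate representations to the ground-truth latents. The array shapes, variable names, and use of scikit-learn are our own illustrative assumptions, not the paper's actual probing code.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def probe_location(hidden_states, true_latents, train_frac=0.8):
    """Fit a linear probe from one candidate location's representations
    to the ground-truth task latents and return held-out R^2.

    hidden_states: (num_sequences, d_model) activations from one
                   (layer, token position) in the model.
    true_latents:  (num_sequences, d_latent) ground-truth latents per task.
    """
    split = int(train_frac * len(hidden_states))
    probe = Ridge(alpha=1.0).fit(hidden_states[:split], true_latents[:split])
    preds = probe.predict(hidden_states[split:])
    return r2_score(true_latents[split:], preds)

# For an implicit model one would sweep over many (layer, position) pairs;
# the explicit model has a single natural location: the bottleneck output.
# scores = {(l, p): probe_location(H[l][:, p, :], Z) for l in layers for p in positions}
```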
failure of the explicit model to leverage inferred latents for prediction
We conjecture that this is due to optimization dynamics or lack of inductive biases since our prediction model is expressive enough to represent a linear mapping, but is not able to do so when combined with the problem of inferring latents, in OOD settings. Note that this is similar to standard studies in OOD generalization that highlight failure cases of deep learning methods.
bottleneck structure itself may limit information flow, rather than revealing a true failure of latent inference
Most works on inductive biases limit information flow and can aid generalization if this is reflected in the data too, e.g., modular systems [2] limit information flow by blocking non-activated experts, information bottlenecks through KL bounds [3], etc. In our work, we consider tasks where inferring the correct latents should block additional information flow from the context to the query. Could the reviewer clarify if they meant something else by "true failure of latent inference"?
why does training another classifier on the same inferred latents yield better results?
In one case we keep the last layer fixed and learn just the latent inference (known prediction) while in the other, we learn both latent inference and prediction jointly. Our experiments demonstrate that the latter leads to sub-optimalities, i.e. joint training of prediction parameters and latent variable inference causes problems.
We hope that our response has resolved the reviewer's concerns and would be happy to provide further clarifications.
[1] Shalev-Shwartz, Shai, and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
[2] Andreas, Jacob, et al. "Neural module networks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
[3] Tishby, Naftali, and Noga Zaslavsky. "Deep learning and the information bottleneck principle." 2015 ieee information theory workshop (itw). Ieee, 2015.
[4] Sun, Yu, et al. "Test-time training with self-supervision for generalization under distribution shifts." International conference on machine learning. PMLR, 2020.
[5] Müller, Samuel, et al. "Transformers can do bayesian inference." arXiv preprint arXiv:2112.10510 (2021).
[6] Han, Seungwook, et al. "Emergence of Abstractions: Concept Encoding and Decoding Mechanism for In-Context Learning in Transformers." arXiv preprint arXiv:2412.12276 (2024).
Thank you for the clarification.
Regarding the x_query and the explicit model: I understand the setup you're referring to. My question (in a simplified setting to make my point) is this — suppose the optimal algorithm is k-nearest neighbors. If we don’t allow x_query to be compared to other in-context examples, doesn’t that make it harder for the model to recover the k-NN algorithm? The reason I think this matters is that this kind of comparison between tokens seems like a central capability that enables transformers to perform so well.
Another point — and this might be due to my misunderstanding — is about the "true latent variable." If the model is implementing some algorithm in its forward pass, like gradient descent (as in Von Oswald et al., 2023) or others, then isn’t the ICL capability not really about learning the correct latent variables? Could you clarify on this?
In summary, the main message of the paper is still unclear to me. The only concrete takeaway I see is that ICL is not due to learning some latent variable plus a classifier on top. But out of the many possible explanations for why ICL works, this paper seems to rule out just one specific type. Therefore, I'm keeping my score as is.
We thank the reviewer for engaging in discussion and hope that our response alleviates their concerns.
suppose the optimal algorithm is k-nearest neighbors. If we don’t allow x_query to be compared to other in-context examples, doesn’t that make it harder for the model to recover the k-NN algorithm?
We completely agree that if the optimal prediction were k-NN (or any non-parametric model), then the comparison between tokens is needed and hence an implicit model might do better (e.g., GP regression). Similarly, if the true underlying model were parametric, the optimal algorithm could be to infer the parameters, and allowing comparison between tokens could make learning this solution harder. Thus, the two different choices above incentivize different solution structures; which one is preferable depends on the task.
Importantly, a very broad class of problems do have an underlying true model that can be described with some latent variables, e.g.
- inferring the latent rules of games like chess,
- inferring semantics like objects and relationships from visual scenes,
- image classification, where there is a true mapping that defines whether something is a cat or not,
- properties of molecules, where principles of science (chemistry and physics) govern the properties,
- properties of galaxies, governed by underlying parameters (often the target in simulation-based inference),
and so on. This is the class of problems that we focus on: for tasks with true underlying latent variables, is explicitly inferring those latents useful for better prediction? We do not claim that all tasks require or are modeled through some underlying latents, but a lot of them are, and those form the basis of our hypothesis and analysis.
Even further, if the optimal algorithm is k-NN, then the explicit model could also infer the same solution by learning the voronoi tessellation (https://en.wikipedia.org/wiki/Voronoi_diagram) corresponding to the observations.
If the model is implementing some algorithm in its forward pass, like gradient descent (as in Von Oswald et al., 2023) or others, then isn’t the ICL capability not really about learning the correct latent variables?
We thank the reviewer for bringing this point up and realize that there may have been a potential misunderstanding. We first note that [1] does model the latent variables (the weights of the linear model, in their setting), but it does so in a complex and distributed manner by modeling this latent variable inference (through gradient descent!) and the prediction function jointly with an implicit model.
Note that if ICL is implementing some algorithm (gradient descent, Bayesian posterior estimation, etc.), then there needs to be some object (latent variable) that is being optimized or whose posterior is being inferred. Essentially, [1] shows that the implicit model in linear transformers for linear regression can be seen as composing context aggregation and prediction (which is by design in explicit models but can be inferred separately in this specific case of implicit models) such that the context aggregator infers the latents through gradient descent.
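To spell out this decomposition for the linear-regression case, the latent being optimized is the weight vector of the regression model; a gradient-descent context aggregator in the spirit of [1] (written here in our own generic notation, not the paper's) would compute:

```latex
% Context aggregation: infer the latent w from the context via T gradient
% steps on the in-context squared loss, then predict with the inferred latent.
\[
w_{t+1} = w_t - \eta \sum_{i=1}^{N} \left( w_t^\top x_i - y_i \right) x_i ,
\qquad
\hat{y}_{\text{query}} = w_T^\top x_{\text{query}} .
\]
```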
We note that, in contrast, we do not make any assumptions about what algorithm a context aggregator can use to infer the latents, we only test whether they do. Since we do not make strong assumptions about the task, it is not easy (maybe even impossible!) to break down the implicit model into context aggregation and prediction, as [1] does for the specific case of linear transformers and linear regression.
We thank the reviewer for their insight and hope our response answers their questions. We would greatly appreciate an increase in rating if the concerns have been addressed.
[1] Von Oswald, Johannes, et al. "Transformers learn in-context by gradient descent." International Conference on Machine Learning. PMLR, 2023.
This paper delves into the mechanism by which Transformers do In-Context Learning (ICL). A common belief is that Transformers do ICL through some statistical shortcuts and hence cannot generalize well on OOD tasks. The authors test this hypothesis by minimally modifying the architecture to encourage the model to explicitly aggregate the information in the context into a task representation before performing the final prediction of the label of the query. By modifying the architecture, the authors are able to compare a standard Transformer (the so-called implicit model) and the modified one (the explicit model). They conduct a series of experiments on various ICL tasks, including linear regression, non-linear regression and classification, and some reasoning tasks. They show that 1. the explicit model does not outperform the implicit model across various tasks, in both ID and OOD cases; 2. the explicit model can learn the task representation very well, so the reason it is poor in OOD cases is that it does not learn the final prediction part well. Replacing the final prediction function with an oracle leads to much better performance.
In general, the goal, experiments, and results are pretty well presented in this paper. The experiments are sufficient and clear. I like the style of this paper. Although they focused on a small question, they studied it deeply.
Questions for Authors
/
Claims and Evidence
I have one suggestion.
- The OOD task for linear/nonlinear regression/classification relies only on sampling the query input from a distribution with a larger variance. This is not sufficient, since this type of 'query shift' can be well tolerated even by a single layer of linear self-attention [1]. So I would suggest trying more complex distribution shifts and seeing what happens for both models.
[1]. Trained Transformers Learn Linear Models In-Context.
Methods and Evaluation Criteria
/
Theoretical Claims
/
Experimental Design and Analysis
/
Supplementary Material
/
Relation to Existing Literature
/
Essential References Not Discussed
/
Other Strengths and Weaknesses
/
Other Comments or Suggestions
/
Ethics Review Issues
/
We thank the reviewer for their feedback and appreciate that they found the paper to be a clear, well written, in-depth analysis of the subject.
To alleviate their concerns regarding OOD generalization, we refer them to Figures 2(c) and 5 where we test for compositional generalization instead of just shifting the query. In these tasks, novel combinations of underlying latents are provided during inference as opposed to solely changing the query (see Training and Evaluation on Page 5).
Finally, we provide additional analysis where we test for OOD generalization in linear regression and classification by changing the distribution of the underlying weight vectors to be sampled from a normal distribution with larger variance. Our results, provided in Table 1 here (https://anonymous.4open.science/r/explicit-implicit-rebuttal-B263/explicit-implicit-rebuttal.pdf), indicate that it is not the case that explicit models are able to generalize better than implicit ones in such OOD scenarios. Even further, we see that while known prediction is a bit better, it still lags behind the implicit models because this setting is OOD for the context aggregator while the queries remain in-distribution.
We hope that our response has resolved the reviewers' concerns and would be happy to provide further clarifications.
This paper notes that when we do in-context learning, it is likely that the network is, in some sense, learning about the structure of the task. This paper considers task spaces that are explicitly low-dimensional, such as linear regression, where the in-context examples provide information about the underlying regression problem. To encourage the network to use this low-dimensional structure, they add a bottleneck to the transformer. However, they find that this does not improve performance. Thus, the paper is in many ways presenting a "negative result".
Questions for Authors
N/A
Claims and Evidence
Yes.
Methods and Evaluation Criteria
Yes.
Theoretical Claims
N/A (no such theoretical claims).
Experimental Design and Analysis
I am satisfied that the evidence they present supports their claims.
Supplementary Material
No.
Relation to Existing Literature
Good connections drawn in the Introduction. Related work is in the Appendix, which I don't really mind, but some might object to.
Essential References Not Discussed
None to my knowledge.
Other Strengths and Weaknesses
My central issue with this paper is that it is --- ultimately --- a negative result.
That can be fine if the negative result is sufficiently interesting, surprising and convincingly argued.
However, I don't think that's the case here. In particular, while the authors clearly expected to improve in-context learning by introducing a bottleneck, I believe that this would be a very rare view in the field. That's basically because transformer based LLMs work really, really well. In particular, they work really well at in-context learning, and they work really well at a broader array of tasks that share character with in-context learning (even completing the next token using a pre-trained model requires bringing in lots of information from the context). Now, perhaps the essential component of transformers is self-attention. And of course, self-attention is the opposite of a bottleneck, as it allows each token to attend to any previous tokens. So I'm really not sure who would expect introducing a bottleneck to improve performance. Indeed, whenever we introduce bottlenecks of any form into attention (sliding window, quantised KVs etc.) you get worse performance.
Additionally, the interpretability results aren't that interesting as:
- They're restricted to a network with bottlenecks that no one (presumably, not even the authors) would use in practice.
- Their experiments are limited to settings with known latent variables (interpretability is most interesting when we don't know the latent variables).
Now, if you buy that the result is expected, then I'm really not sure that the result is suitable for ICML. I would instead recommend that the authors consider a venue such as TMLR, which has two key criteria for acceptance: "Are the claims made in the submission supported by accurate, convincing and clear evidence?" "Would some individuals in TMLR's audience be interested in the findings of this paper?" The paper does clearly meet these thresholds. But ICML requires something more (which was the entire point of setting up TMLR).
Other Comments or Suggestions
N/A
We thank the reviewer for their feedback, but strongly disagree that our paper is a negative result that is “not sufficiently interesting, surprising or convincingly argued”. Besides the fact that several reviewers (1Akw, 1soF, TjQ5) found our motivation (detailed below) and study interesting and convincingly argued, ICML has a track record of accepting negative results [1].
improve in-context learning by introducing a bottleneck, … very rare view in the field.
We believe the reviewer has misunderstood the motivation of our work. Our goal is not to provide a new architecture but to investigate the hypothesis that Transformers suffer from sub-optimal performance primarily due to insufficient latent variable inference. To systematically test this hypothesis - which is an open debate (Lines 19-52 RHS provide ample citations on both sides) - we use an architecture biased towards latent variable inference using a bottleneck (Lines 52-89 LHS). If Transformers' performance is linked to explicitly inferring task latents, then biasing the model towards this solution ought to improve OOD generalization. If it doesn't, we can conclude that explicit latent variable inference isn't sufficient to improve ICL. We do not suggest that our architectures, or bottlenecks, are the answer; they are simply minimal interventions to test the importance of correct latent variable inference in ICL.
Even though our goal is not to introduce new architectures, it is not a "rare view" in ML to improve generalization via bottlenecks, for which ample evidence exists:
- MoE architectures [2] use only a subset of parameters
- Perceiver models [3] introduce bottlenecks through learned latent variables
- Retrieval and memory augmented methods [4] introduce bottlenecks by selecting a subset of data for context
- Information bottleneck [5] are well studied for improving generalization
- Parametric assumptions (e.g., training a neural network on data and then throwing away the data) introduce a bottleneck, as opposed to methods like kNN that retain the entire dataset (note the bottleneck here is the trained model, which has strictly less information than the data it was trained on).
In our own study, the explicit model with a known prediction function — which is more bottlenecked than both the explicit and implicit models — outperforms both on various tasks. In light of these works, we believe that the story is more nuanced than "introducing bottlenecks (or inductive biases) reduces performance".
transformer based LLMs work really well.
We agree with the reviewer that they do, but this should not disincentivize research into improving or analyzing them.
limited to settings with known latent variables
Our goal was never to provide a method for interpretability. To investigate the hypothesis that latent variable inference is not the problem, we need tasks where we can evaluate whether we are inferring the latents well. We rely on counterfactual predictions, a commonly used interpretability tool, as the metric to evaluate the extent to which the models infer task latents (hence the requirement for ground-truth latents). Our analysis highlights the difficulty of inferring task latents from implicit models, even though they perform well. This study thus contributes to a body of empirical evidence that allows us to conclude that improving task latent inference by itself is not the key to improved ICL generalization.
interpretability is most interesting when we don't know the latent variables.
We disagree because in such cases, a metric for what is more interpretable is either unavailable or there is a noisy, potentially incorrect, proxy for it which leads to mis-interpretations [6]. Thus, to rigorously test our hypotheses, we relied on tasks with known latent variables.
We hope that our detailed response has addressed the reviewer’s concerns and we would be happy to engage in further discussion to understand and resolve further questions. We hope that our response sufficiently highlights why the hypothesis we study is well motivated and interesting.
[1] Karl, F., Kemeter, L. M., Dax, G., & Sierak, P. (2024). Position: embracing negative results in machine learning. arXiv preprint arXiv:2406.03980.
[2] Liu, Aixin, et al. "Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model." arXiv preprint arXiv:2405.04434 (2024).
[3] Jaegle, Andrew, et al. "Perceiver io: A general architecture for structured inputs & outputs." arXiv preprint arXiv:2107.14795 (2021).
[4] Lewis, Patrick, et al. "Retrieval-augmented generation for knowledge-intensive nlp tasks." Advances in neural information processing systems 33 (2020): 9459-9474.
[5] Tishby, Naftali, and Noga Zaslavsky. "Deep learning and the information bottleneck principle." 2015 ieee information theory workshop (itw). Ieee, 2015.
[6] Doshi-Velez, Finale, and Been Kim. "Towards a rigorous science of interpretable machine learning." arXiv preprint arXiv:1702.08608 (2017).
The paper addresses a key question in in-context learning: whether explicit latent variable learning leads to better generalization, especially out-of-distribution (OOD). The conclusion is that the explicit bottleneck architecture does not help in terms of generalization.
Questions for Authors
Update after the rebuttal, and after the direct message from the author to AC:
I appreciate the authors for conducting extensive experiments varying the number of parameters. I apologize for initially overlooking the additional results provided in the attached files. I have now carefully reviewed the full set of results and summarized the key information below:
For linear regression: the result is mixed.

| Implicit Model | Matching Explicit Model in terms of params | OOD Query (which is better) | OOD Latent (which is better) |
|---|---|---|---|
| 4 layers | N/A | N/A | N/A |
| 6 layers | (explicit-MLP) L_context=4, L_prediction=4 | explicit | implicit |
| 8 layers | (explicit-MLP) L_context=4, L_prediction=8 | explicit | implicit |
| 8 layers | (explicit-MLP) L_context=6, L_prediction=4 | explicit | implicit |
| 8 layers | (explicit-Tsf) L_context=4, L_prediction=4 | explicit | implicit |

For linear classification, sinusoid, MLP regression, MLP classification: the implicit model is better.

| Implicit Model | Matching Explicit Model in terms of params | OOD Query (which is better) | OOD Latent (which is better) |
|---|---|---|---|
| 4 layers | N/A | N/A | N/A |
| 6 layers | (explicit-MLP) L_context=4, L_prediction=4 | implicit | implicit |
| 8 layers | (explicit-MLP) L_context=4, L_prediction=8 | implicit | implicit |
| 8 layers | (explicit-MLP) L_context=6, L_prediction=4 | implicit | implicit |
| 8 layers | (explicit-Tsf) L_context=4, L_prediction=4 | implicit | implicit |

For RAVEN: the result is mixed.

| Implicit Model | Matching Explicit Model in terms of params | OOD (which is better) |
|---|---|---|
| 4 layers | N/A | N/A |
| 6 layers | (explicit-MLP) L_context=4, L_prediction=4 | implicit |
| 8 layers | (explicit-MLP) L_context=4, L_prediction=6 | explicit |
| 8 layers | (explicit-Tsf) L_context=4, L_prediction=4 | explicit |

For the gene task, the result is not interpretable since no explicit runs have a similar number of params to the implicit runs.
That said, under a similar total number of parameters, the implicit model generally performs better, though there are a few exceptions.
The following concern still remains, and I hope to see results averaged over more trials to clarify this point:
Additionally, I am somewhat confused. In the paper, the OOD task for linear regression is defined as OOD query. However, the new results suggest that for OOD query, the explicit model performs better when the latent parameters are learned explicitly under a similar parameter count. Specifically, in Figure 2, the blue block (representing the explicit model as Transformer) should correspond to the Explicit-Tsf model with L_context=4 and L_prediction=4. However, this seems to conflict with the conclusion now presented. I hope I'm not misunderstanding or misinterpreting anything here—perhaps this result is based on a single run, and the variance might explain the inconsistency. If that's the case, it would be helpful to clarify whether these results are averaged over multiple runs or represent individual trials.
In addition, I would like to see more fine-grained explicit model configurations that match the number of parameters of the implicit models. Among the results provided during the rebuttal, only a few explicit setups are comparable in parameter count—one setup roughly matches the 6-layer implicit model, and one or two setups roughly match the 8-layer implicit model.
Overall, I see promising signals that the implicit model outperforms in general, especially on ID tasks. However, I would also encourage the authors to investigate fairer comparisons by matching ID performance—e.g., tuning the parameter count so that both models achieve similar ID results, and then evaluating on OOD. Currently, in many settings (GP, MLP regression) the implicit model has better ID performance than the explicit model, and the source of the OOD gap needs to be investigated.
Lastly, I would suggest that the authors moderate the strength of their conclusions. Since it is difficult to cover all possible setups, it would be helpful to explicitly limit the scope of the findings and include a discussion paragraph acknowledging that the results are task- and architecture-specific.
At this point, I fully understand the authors’ efforts and frustrations during the intense rebuttal phase. I have updated my score to 4, and I hope the authors will seriously consider the comments above. Once again, I sincerely appreciate the authors’ engagement and detailed responses throughout the discussion.
Claims and Evidence
Yes, the paper designs experiments that test the intended hypothesis well.
Methods and Evaluation Criteria
Yes.
Theoretical Claims
NA
Experimental Design and Analysis
- Is there any difference in the result when using an encoder (as in the paper) vs. using a decoder? Why did the author choose to use an encoder in particular?
- For the regression problem, why did you choose the OOD task to be scaled x, instead of scaled w, since by bottlenecking the latent variable, the explicit model should probably be better at any OOD of the latent variable? Could you check the other OOD task for the regression problem, especially the ones that have OOD on the latent variable?
- Could you explain how many layers of Transformer are used in the implicit and explicit model (including the context model and the predictor head)? Would there be any impact regarding the depth of the Transformer itself on the OOD generalizability? For instance, the deeper model may be better at OOD than the shallower model.
- I see from the appendix that you consider two options in the predictor, one is a 4-layer TF, another is the MLP. Could you explain clearly which model you used to report the result? And why do you think using a shallower Transformer for the context extraction is sufficient for the explicit model to generalize?
- By the ablation study on "explicit models learn to infer the correct latent variable, but not how to use it", the author suggests that by forcing the model to explicitly predict the correct latent variable, the explicit model generalizes better. I wonder if the effect of "dimension" in the output also matters. If I understand correctly, the y_q is one-dimensional, while the latent variable is 1 dimension for the linear regression, and 2 dimensions for the other tasks. Would it be possible to put the nonlinear & sinusoidal to also 1 dimension in x when conducting the Figure 3 experiment? This could potentially rule out the error in terms of the dimensionality.
- Also in Figure 3, why is the classification task not beneficial from predicting the latent variable directly?
Supplementary Material
I checked the experimental setup.
Relation to Existing Literature
This paper discusses the key hypothesis that learning the latent variable is not beneficial to the overall OOD performance of the transformer. This is an interesting observation for understanding how Transformers learn to solve these synthetic problems.
Essential References Not Discussed
Not that I am aware of.
Other Strengths and Weaknesses
Strengths: The paper uses extensive experiments designed to investigate the hypothesis of whether learning an explicit latent variable helps in the transformer's generalization ability.
Weakness: It is in general hard to draw a fair comparison in the OOD experiments. For instance, as I mentioned in the experiment section, the authors choose 8 layers for the implicit TF, 4 layers for the explicit TF when extracting the latent z, and an MLP / 4-layer TF for the predictor. However, as far as I know, the depth of the model strongly affects the model's OOD performance in the regression task. It is not justified why the explicit model is designed like that (with 4 layers allocated to context extraction), and it would be beneficial if more explicit model designs were tested to accompany the results, for instance one where the context extractor matches the implicit model's depth.
Other Comments or Suggestions
NA
We thank the reviewer for acknowledging the value of our work and providing constructive criticism. Throughout this comment, we will refer to additional experiments that are provided here: https://anonymous.4open.science/r/explicit-implicit-rebuttal-B263/explicit-implicit-rebuttal.pdf
Is there any difference in the result when using an encoder (as in the paper) vs. using a decoder?
The only difference is that a decoder would use a causal mask, as opposed to our setting. Since we feed tokens as [x_i, y_i] pairs, where each pair defines a token, we cannot leverage a causal mask (refer to our response to Reviewer 1Akw). Our training loss, however, is an unbiased estimator of a similar next-token prediction decoder loss (modeling y_i given [x_1, y_1, ..., x_{i-1}, y_{i-1}, x_i] for all i in parallel) but is more expressive since context points can attend in the anti-causal direction as well. We chose this setup as it provides more supervision than feeding a token as either x_i or y_i, and it is a common choice in many related works [1,2].
Could you check the other OOD task for the regression problem, especially the ones that have OOD on the latent variable?
We thank the reviewer for this great suggestion and already conduct similar experiments with OOD latents in Figures 2(c) and 5, where we test for compositional generalization instead of OOD queries. In these tasks, novel combinations of underlying latents are provided during inference as opposed to solely changing the query (see Training and Evaluation on Page 5).
Inspired by the reviewer's suggestion, we conduct additional analysis where we test for OOD generalization in linear regression and classification by changing the distribution of the underlying weight vectors to be sampled from a normal distribution with larger variance. Our results, provided in Table 1 of the additional experiments, indicate that it is not the case that explicit models are able to generalize better than implicit ones in such OOD scenarios. Even further, we see that while known prediction is a bit better, it still lags behind the implicit models because this setting is OOD for the context aggregator while the queries remain in-distribution.
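To make the OOD-latent setting concrete, below is a minimal sketch of how such data could be generated for linear regression; the variance values, dimensions, and function name are illustrative assumptions rather than the exact configuration used in the paper.

```python
import numpy as np

def sample_linear_regression_task(n_context, d, w_std=1.0, noise_std=0.1, rng=None):
    """Sample one linear-regression ICL task.

    ID tasks draw the weight vector (the task latent) with the training w_std;
    the OOD-latent evaluation simply increases w_std so that the latent comes
    from a wider distribution while the inputs x stay in-distribution.
    """
    rng = rng or np.random.default_rng()
    w = rng.normal(0.0, w_std, size=d)                  # task latent
    x = rng.normal(0.0, 1.0, size=(n_context + 1, d))   # context + query inputs
    y = x @ w + rng.normal(0.0, noise_std, size=n_context + 1)
    return x[:-1], y[:-1], x[-1], y[-1], w               # context, query, latent

# ID training:      sample_linear_regression_task(32, 8, w_std=1.0)
# OOD-latent eval:  sample_linear_regression_task(32, 8, w_std=3.0)
```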
Could you explain how many layers of Transformer are used in the implicit and explicit model
We use 8 layers in our experiments, where the explicit model splits prediction and context aggregation evenly (4 layers each). We conduct additional experiments where we compare an 8-layered implicit model to an explicit model with 8-layered context aggregation, and refer the reviewer to Figures 1 and 2 in additional experiments. Our results indicate that even with a larger context aggregation model, the same results hold.
I wonder if the effect of "dimension" in the output also matters
We point the reviewer to Figure 6 in the main paper, which studies the role of dimensions. In general, the trend is consistent with increasing the complexity of the task, whether it be through the dimensionality of the input, the latents, or the output. We also refer to Figure 4 of the additional experiments, where we study the role of the output dimensionality in linear regression.
why is the classification task not beneficial from predicting the latent variable directly?
For the linear classification task, we believe that all the models have saturated their performance (note > 97%). For the nonlinear classification, when we fix the prediction function, there is an infinite set of latents that could lead to the same functional form, but this solution space is quite entangled and convoluted (note that this refers to permutation and scaling symmetries that leave the functional form unchanged). In contrast, it might be easier to explore alternate solutions by changing the prediction function to have a smoother landscape of possible latent variables.
We hope that our response has resolved the reviewers' concerns and would be happy to provide further clarifications.
[1] Hollmann, Noah, et al. "Tabpfn: A transformer that solves small tabular classification problems in a second." arXiv preprint arXiv:2207.01848 (2022).
[2] Müller, Samuel, et al. "Transformers can do bayesian inference." arXiv preprint arXiv:2112.10510 (2021).
Thank you for the detailed response. After reading it and reviewing the additional results, I have decided to maintain my score.
Reason for not a higher score: While I appreciate the added experiment comparing implicit and explicit models, I still find it difficult to draw a fully fair comparison in the OOD setting. In the new explicit setup, both the context aggregator and the predictor are implemented as Transformers, which increases the total number of parameters compared to the implicit baseline.
To better control for model capacity, one reasonable comparison would be to keep the context aggregator the same as in the implicit model and use a lightweight predictor such as an MLP. This would reduce the parameter overhead and make the setup more directly comparable to the implicit model. On the other hand, the current explicit setup—where the aggregator is identical to the implicit model and the predictor is the known function—can be viewed as an upper bound on performance when the latent variable is provided. From Figure 2 in the additional experiments, this setup achieves comparable or slightly better performance than the implicit model.
More broadly, if the explicit model consists of two components (aggregator + predictor) and the goal is to investigate generalization within Transformer architectures, then how the model capacity is divided between these components can significantly affect performance. Evaluating only a single configuration (e.g., a 50/50 split) does not rule out the possibility that other allocations (e.g., 60/40, 70/30, etc.) may yield better results. This sensitivity is likely task-dependent, as different tasks may benefit from different capacity allocations between aggregation and prediction. I acknowledge that conducting such a sweep is non-trivial and resource-intensive, which makes it difficult to draw a strong negative conclusion from the current results.
Reason for not a lower score: The paper presents extensive experiments and raises several insightful questions that contribute meaningfully to the understanding of task structure and generalization.
We thank the reviewer for engaging in discussion and hope that our response answers the questions and clarifies the concerns raised.
Through our rebuttal experiments on a larger context aggregation model, we had validated the hypothesis regarding explicit latent variable inference and OOD generalization. While in our original setup we did test with both a Transformer and a lightweight MLP as the predictor (refer to Figure 2 in the original paper), we agree with the reviewer that a more in-depth study of the sensitivity of task performance to the complexity of different model parts is important.
To alleviate the reviewer's concern, we run a large-scale analysis with different number of layers in the context aggregator and prediction module for both the implicit as well as the explicit models (both MLP and Transformer predictor). The results are highlighted here: https://anonymous.4open.science/r/explicit-implicit-rebuttal-B263/explicit-implicit-suppl.pdf, and show that across a suite of different tasks, explicit models do not show an improved performance over implicit ones, thereby further validating our hypothesis and making our claim stronger.
We thank the reviewer for their insight and would greatly appreciate an increase in rating if their concerns have been addressed.
This paper investigates whether explicitly inferring latent variables of an underlying task improves in-context learning performance in transformers. They find that explicit modeling of latent variables does not necessarily improve performance compared to standard implicit models. They also find that while the explicit model does learn latent variables, the main problem with generalization is its prediction function is not properly trained.
Questions for Authors
My main question is related to the learning of a good prediction function.
-
Do you have any intuition as to why the prediction function is not trained well enough even though it appears like the latents are being modeled properly? Is it because the ID training data is sufficiently different from the OOD data? Perhaps it depends on the task.
-
It's a bit surprising that explicit training does not learn a "good" prediction function. Is there some way to algorithmically compare the learned prediction function to the optimal one? This might be a nice way of characterizing the failure modes. Would freezing the first half of the explicit model after a while and training only the predictor function approach the optimal one? (It seems like they should be mutually reinforcing in a sense)
-
In Lines 136-142 (left column), you mention you do not train with next-token prediction. Is there a reason this wouldn't work? Do you think this limits your results/claims in any way?
Claims and Evidence
Yes, the authors are careful to make claims that are backed by sufficient evidence.
Methods and Evaluation Criteria
Yes, there are several appropriate datasets the authors use to show the generality of their findings.
Theoretical Claims
I did not check any proofs, but also did not see any theoretical claims in the paper.
Experimental Design and Analysis
The experimental design is solid. The experiments are clearly explained, and the answers to questions I had about details while reading I was able to find in the appendix.
Supplementary Material
Yes, I read the entire appendix.
Relation to Existing Literature
This work finds that causing the transformer architecture to explicitly infer latent variables of few-shot prompts is not sufficient for good ICL generalization. There have been a few works that suggest LLMs infer latent variables while doing ICL [1,2]. But this work showed little difference in performance between the standard transformer and the explicit model. This is an interesting data point among previous and related findings. For instance, while [3] & [4] provide evidence suggesting that LLMs trained on natural text seem to implicitly model latent variables, one contribution of this paper might be that it provides a way to perhaps understand the failure modes of these models - that it can similarly be attributed to a poor "prediction function", which seemed to be the problem with the failure cases of the explicit model in this work. There are some nice contributions of this work that might help us understand the architectural and algorithmic shortcomings of the models we use.
[1] Xie, et al. An explanation of in-context learning as implicit bayesian inference. ICLR 2022. (https://openreview.net/forum?id=RdJVFCHjUMI)
[2] Wang, et al. Large Language Models Are Latent Variable Models: Explaining and Finding Good Demonstrations for In-Context Learning. NeurIPS 2023. (https://proceedings.neurips.cc/paper_files/paper/2023/hash/3255a7554605a88800f4e120b3a929e1-Abstract-Conference.html)
[3] Hendel, et al. In-Context Learning Creates Task Vectors. EMNLP 2023. (https://aclanthology.org/2023.findings-emnlp.624/)
[4] Todd, et al. Function Vectors in Large Language Models. ICLR 2024. (https://openreview.net/forum?id=AwyxtyMwaG)
Essential References Not Discussed
This work is well-placed among previous and concurrent work. I wanted to point out two works that I think are related, but were not cited, which study the idea of in-context learning centering around latent variables [2,5]. [2] supports the view of in-context learning as Bayesian inference over an internal latent. The work in [5] shows an example of how simple implicit transformer models seem to learn latent variables, as others suggest is happening in larger LLMs. Maybe a discussion of how [5] relates to this paper's study of implicitly vs. explicitly modeling latent variables would be helpful to provide more context for how we might interpret both results.
[2] Wang, et al. Large Language Models Are Latent Variable Models: Explaining and Finding Good Demonstrations for In-Context Learning. NeurIPS 2023. (https://proceedings.neurips.cc/paper_files/paper/2023/hash/3255a7554605a88800f4e120b3a929e1-Abstract-Conference.html)
[5] Han, et al. Emergence of Abstractions: Concept Encoding and Decoding Mechanism for In-Context Learning in Transformers. (https://arxiv.org/abs/2412.12276)
Other Strengths and Weaknesses
This is a solid paper. The proposed setting of explicitly modeling latent variables is nice, and the experiments are carefully designed to test specific hypotheses. The results are thoughtfully presented without overclaiming and are supported by evidence across a variety of tasks.
The main weakness I'd say of the paper is that since prediction function failure was the main reason the explicit model didn't work, there could be more discussion about how to improve this during training beyond the current text: (i.e., "supplemented with significant inductive biases in the prediction function"). Could you provide some examples of what this might look like for different tasks? Does next-token prediction (e.g. language modeling) have sufficient inductive biases that might mediate some of these problems, or do you think it's purely architectural?
Other Comments or Suggestions
Here's a list of minor typos I found while reading through:
- Line 131: "shown gray" -> shown in gray
- Line 782: "Hodgkin-Hoxley" -> Hodgkin-Huxley?
Note after Rebuttal period: As before the rebuttal, I am leaning towards accepting this paper. This paper does have its limitations, but the things I learned from this paper I think outweigh any reservations I may have had. The rebuttal by the authors answered my questions, and I feel they have also addressed the concerns of the other reviewers as well.
We thank the reviewer for acknowledging the value of our work and providing constructive criticism.
Additional References
We appreciate the reviewer bringing two relevant papers to our attention and will include them in the final version.
- Wang et al. (2023) show that inferring latents is helpful but rely on a finite set of tasks and latents that are shared across different batches of observations. In contrast, our approach considers an uncountable set of tasks. They also assume data beyond the current context that belongs to the same task, while we only consider the current context as defining the latent.
- Han et al. (2024) show that performance on in-distribution (ID) tasks highly depends on how well the task latent is encoded. However, they also consider only a finite set of tasks and do not focus on OOD generalization. In their language, we show that if the concept decoding part of the model doesn't have the proper inductive biases, it will only learn how to use the encoded concepts ID.
While we show that explicit models do not outperform implicit ones even though they correctly infer the latents, it could mean either the implicit model uses a different mechanism or infers the latents in a potentially distributed, uninterpretable way depending on the task. Answering this is an important future work, and the reviewer points out relevant works that aim to do so in certain settings.
why the prediction function is not trained well enough even though it appears like the latents are being modeled properly
Our hypothesis is that a learnable prediction function doesn't have appropriate inductive biases, whether coming from its architecture or the training data. For example, an MLP prediction function can learn to be linear in the training regime while being arbitrarily nonlinear outside of it, leading to suboptimal OOD performance. A similar argument can be made for other tasks. We will add a discussion about this in the draft.
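As a toy illustration of this failure mode (our own sketch, not an experiment from the paper), one can fit an off-the-shelf MLP to a linear function on a bounded range and observe how its extrapolation drifts away from the linear rule outside that range:

```python
import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
x_train = rng.uniform(-1, 1, size=(2000, 1))   # training inputs restricted to [-1, 1]
y_train = 2.0 * x_train.ravel()                # ground-truth rule is linear: y = 2x

mlp = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=2000, random_state=0)
mlp.fit(x_train, y_train)

x_test = np.array([[0.5], [2.0], [5.0]])       # in-range and out-of-range queries
print(mlp.predict(x_test))                     # close to 2x at 0.5; typically drifts at 2.0 and 5.0
print(2.0 * x_test.ravel())                    # the linear rule the model should follow OOD
```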
algorithmically compare the learned prediction function to the optimal one
We refer to Figure 11 for an example of a learned prediction function evaluated away from the training distribution. In addition, we refer to Figure 4 (https://anonymous.4open.science/r/explicit-implicit-rebuttal-B263/explicit-implicit-rebuttal.pdf) which illustrates the performance of explicit models with MLP prediction OOD on the sinusoid task.
you mention you do not train with next-token prediction. Is there a reason this wouldn't work? Do you think this limits your results/claims in any way?
We could have trained our models using next-token prediction by feeding in tokens as [x_1, y_1, x_2, y_2, ...], where each argument inside the brackets constitutes a token. In this case, we could have done an augmented version of next-token prediction where we only consider the losses corresponding to the y tokens (predicting the next query point is non-informative since the points are sampled iid). However, this requires the model to additionally learn which y corresponds to which x. We instead feed [x_i, y_i] pairs as tokens so the model does not need to learn that information; it is provided as input. However, this prohibits learning via next-token prediction solely computationally and not algorithmically. In practice, our training loss is an unbiased estimator of predicting y_i given [x_1, y_1, ..., x_{i-1}, y_{i-1}, x_i] for all i in parallel, while being more expressive by allowing anti-causal communication within the context. This choice of modeling is common across a number of related works [1,2].
Additional insights into inductive biases to improve prediction function
We believe that there are a number of directions for inductive biases that could be worth pursuing. One direction is the architectural design of the prediction function, with or without task-specific knowledge baked in (e.g. using as the predictor for sinusoidal regression leads to perfect OOD generalization). Alternately, one could also look at optimization strategies (eg. alternate optimization instead of jointly optimizing, freezing one part of the network, etc.) that could lead to better convergence of the prediction function. As the reviewer rightly points out, next token prediction also provides an inductive bias towards this goal. We defer an in-depth analysis into the inductive biases as well as improving the prediction function as future work.
We hope that our response has resolved the reviewer's concerns and would be happy to provide further clarifications.
[1] Hollmann, Noah, et al. "Tabpfn: A transformer that solves small tabular classification problems in a second." arXiv preprint arXiv:2207.01848 (2022).
[2] Müller, Samuel, et al. "Transformers can do bayesian inference." arXiv preprint arXiv:2112.10510 (2021).
The paper initially received mixed reviews.
- Reviewers 1Akw and TjQ5 were strongly supportive. However, TjQ5's review lacked detail, and the reviewer did not participate further, such as by providing clarifications or joining the discussion. Therefore, I downweighted Reviewer TjQ5's score accordingly.
- Reviewer 1soF is leaning positive but critical of certain comparisons.
- Reviewer gppE found the results unsurprising and despite the rebuttal thinks that the paper does not meet the ICML standards.
- Reviewer vn7o found the setup and main message unclear, even after the authors' rebuttal.
In the reviewer discussion phase, reviewers 1soF and vn7o again expressed their concerns. Following a note by the authors indicating that 1soF might have overlooked recent results, the reviewer updated their score and now recommends acceptance.
While reviewers gppE and vn7o have expressed concerns, the majority recommend acceptance. I suggest that the authors address the concerns raised by these reviewers in the final version of the paper, should it be accepted.