How Transformers Learn Structured Data: Insights From Hierarchical Filtering
Abstract
Reviews and Discussion
The authors propose a synthetic way to evaluate how Transformers learn interactions on trees that are generated with different positional correlations. They show that the Transformer learns to approximate the algorithm used to generate the synthetic data. They show that the Transformer learns "longer range" interactions in the deeper layers and more local structure in the earlier ones.
Questions to Authors
Regarding the weakness, why do you find the result that the model seems to align itself with the way the data is generated surprising? Is this not immediate from the fact that you have trained a model that manages to fit the data well using SGD?
Are you suggesting that a Transformer could learn to fit this data in some other way? While I find this result still experimentally interesting I do not necessarily find it very surprising.
Regardless, I still find the paper interesting and the methodology sound; for this reason, I am leaning more towards acceptance.
Claims and Evidence
I think the claims are very clearly backed up with convincing evidence.
I believe that the title "How transformers learn structured data" is perhaps a bit of a strong claim. While the paper goes in this direction, it sounds like a pretty sweeping claim that this is what the paper solves. The paper is studying the behaviour in a very synthetic and controlled system. Of course I understand the need to study these kinds of behaviours under such constraints, but I find the title strongly worded.
Methods and Evaluation Criteria
The evaluation criteria are solid.
Theoretical Claims
There are no theoretical claims.
Experimental Design and Analysis
The experimental design is in fact a contribution of the work and I believe it is an interesting contribution.
Supplementary Material
I have not reviewed the supplementary material.
Relation to Prior Literature
As this is slightly separate from my main area of research, I am unsure how it relates to the literature in the surrounding area. There are, however, a number of (relatively dated) studies showing that machine learning models learn more complex dependencies in their deeper layers. In some sense, I feel that this work goes in this direction as well.
Essential References Not Discussed
I am not very familiar with this exact kind of literature that aims to learn probabilistic models with Transformers, so I cannot comment on essential references that are not discussed.
Other Strengths and Weaknesses
Strengths: I think the setup is quite clear and the results make sense. The authors provide an extensive set of experiments.
Weaknesses: The main weakness I see is that I do not find the conclusions particularly surprising. In fact, I would assume that the claim "which provides evidence of an equivalence in computation to the exact inference algorithm" follows from the learning process having been successful, in the sense that if your loss is low during SGD, I would imagine that the model has learnt to align itself with the underlying algorithm that generates the data.
Other Comments or Suggestions
This is a style suggestion but I would put the bibliography before the appendix as this is most standard as far as I am aware.
We thank the reviewer for their valuable feedback, and address their comments and questions.
On the referee’s concerns towards the ‘surprise’ of our results: Let us clarify why we believe that our findings are not trivial. While one of the paper’s conclusions is indeed that, in the end, the transformer learns what it’s trained to learn, we think that the novelty and surprise lie in uncovering how it does that. In particular:
- The model shows excellent generalization performance, pointing to the fact that it has learned the complex data model without overfitting the training set, which is not trivial.
- It does so in a way that almost perfectly reproduces the output of the exact algorithm at the logits level, without ever being explicitly trained to do so, as it is only given hard labels in training and no calibrated supervision (a minimal sketch of this comparison is given at the end of this response). This demonstrates that the model closely mimics the exact algorithm, even on entirely out-of-sample inputs.
- It shows some very interpretable characteristics that are closely related to the natural implementation of Belief Propagation that we propose in the appendix, spontaneously organizing the computation in a hierarchical fashion, progressively going up the tree, which is not a priori required in our overparametrized context.
- It progressively includes higher and higher levels in the hierarchy during training instead of e.g., suddenly ‘grokking’ to the optimal outcome after having seen enough examples.
Points (2), (3), and (4) go beyond what the model is trained to do via SGD, which just points at predicting the right class label or a masked symbol. We think that this highlights our contribution towards model interpretability and a genuine understanding of how and what the architecture learns beyond pure performance; as argued in the response to referee ivGY, an important aspect of our work is mechanistic interpretation. Note that points (3) and (4) also allow us to make contact with recent works on ‘simplicity bias’ in successful machine learning architectures; see also our answer to referee Rk87. Finally, in our context, we can also understand why, e.g. an insufficiently trained model would provide sub-optimal performance due to it incorporating only some of the spatial correlations in the data to make its prediction, which we believe is a rather rare occurrence in a complex data setting such as this one.
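As referenced in point (2), here is a minimal sketch of the kind of logit-level comparison we have in mind; the names `model_logits` and `bp_marginals` are placeholders for the trained model's outputs and the BP oracle's posteriors on a test set, not functions from our codebase.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def calibration_gap(model_logits, bp_marginals):
    """Mean total-variation distance between the model's predicted
    distribution and the exact BP posterior over a held-out test set."""
    p_model = softmax(model_logits)               # shape (n_samples, n_classes)
    return 0.5 * np.abs(p_model - bp_marginals).sum(axis=-1).mean()
```

A small gap of this kind, measured on inputs never seen during training and with only hard labels used as supervision, is the sense in which we say the transformer reproduces the exact algorithm rather than merely fitting the labels.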
Transformer architectures have become highly successful in deep learning, achieving state-of-the-art performance in various NLP and computer vision tasks. However, it is still not fully understood how transformers learn from different types of data. This paper takes a step toward a better understanding of transformers' ability to learn from structured data, specifically hierarchical structures, through attention heads. The paper first introduces a complete hierarchical generative process, starting with a root symbol and iteratively generating child nodes using a probability function governed by a transition matrix. A filtering mechanism is applied by introducing a parameter k, where k=0 means that all pairs of child nodes are conditioned on their respective parents, implying strong correlations. For k>0, the children at level k are generated conditionally with respect to the root. This process of generating sequences allows for exact inference using a dynamic programming approach. In the second step, the paper selects an encoder-only transformer model with l layers. The goal is to demonstrate that the encoder-only transformer can approximate the exact inference algorithm when trained on root classification and masked language modeling tasks. In the root classification task, each generated sequence is labeled with the root. The results show that, in any combination of training and testing with k=0 or k=0, 1,..., l, the transformer achieves the same accuracy as the Belief propagation (BP) method and approximates its inference. Similar results are observed in the masked language modeling task, where the model is pretrained by learning to predict masked tokens, optimizing a loss function. Visualization of the attention matrix further supports the claim that transformers can learn hierarchical structures effectively.
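For concreteness, the unfiltered (k=0) generative process summarized above can be sketched as follows; the parameter values, variable names, and the generic random transition tensor are illustrative assumptions based on this summary, not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(0)
q, depth = 4, 4   # vocabulary size and number of tree levels (illustrative values)

# Parent-to-children transition tensor: T[a, b, c] = P(left=b, right=c | parent=a).
T = rng.random((q, q, q))
T /= T.sum(axis=(1, 2), keepdims=True)

def sample_children(parent):
    flat = rng.choice(q * q, p=T[parent].ravel())   # joint draw of the two children
    return flat // q, flat % q

def generate():
    """One sample from the full (k=0) hierarchy."""
    root = int(rng.integers(q))
    level = [root]
    for _ in range(depth):
        level = [child for node in level for child in sample_children(node)]
    return root, level   # hidden root label and the 2**depth observed leaves

root, leaves = generate()
# The k-filtered variants described above would instead condition the upper k
# levels on the root, weakening long-range correlations; the exact rule is the
# one defined in the paper.
```

Exact inference of the root (or of a masked leaf) given the leaves can then be carried out by dynamic programming on the tree, which is the BP oracle the transformer is compared against.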
Questions to Authors
What is your opinion on structure learning by transformers from a broader aspect, like motif learning?
Claims and Evidence
The paper addresses its claims in several ways: experiments in the two scenarios of root classification and MLM, and a comparison of the transformers' performance with BP. In addition, visualization of the attention layers illustrates the way the hierarchy is captured by transformers. However, the approach is limited to one type of hierarchy and a single dataset.
Methods and Evaluation Criteria
Despite being limited to one hierarchy type, the approach follows reasonable steps.
Theoretical Claims
One of the main drawbacks of the paper is the lack of theoretical analysis, especially given that the experiments are narrow in terms of data type.
Experimental Design and Analysis
The experimental results are valid and cover several aspects of the problem.
Supplementary Material
I skimmed through the appendix, and it seems quite comprehensive. However, several parts of the paper repeatedly refer to the appendix, which could be considered an issue.
Relation to Prior Literature
The idea presented in this paper is closely related to graph-based machine learning and foundation models for graphs. Most existing graph foundation models rely on GCNs or graph transformers as their main components, which may not be as scalable as traditional transformers. This work highlights the potential of using transformers for structure learning and can thus draw attention to this promising direction in the field.
Essential References Not Discussed
N/A
Other Strengths and Weaknesses
– The core idea of the paper is interesting and helpful for the community, as proving the structure-learning capability of a transformer can lead to developing a general foundation model that learns from both structured and unstructured data.
– The paper only studies a very narrow type of structure and, at the same time, lacks any theoretical justification.
– For a reader who is not fully familiar with the literature, the introduction is a bit hard to understand; the writing could thus be improved.
Other Comments or Suggestions
If possible, studying other types of hierarchies would be interesting.
We appreciate the reviewer’s feedback and address here the weaknesses and questions they have raised.
- On the weakness about the lack of theoretical analysis: We are indeed unable to derive precise analytical results in our paper (as the complexity of our data models implies that we do not even have access to a closed-form expression of the distribution). However, we would argue that our setting does carry very significant theoretical justification and control to understand what is (or is not) being learned. Having access to the optimal oracle in the form of the full BP algorithm, as well as all its factorized counterparts, is indeed what allows us to show how the transformer progressively learns to solve the task during training, and that it does so with correctly calibrated predictions. Moreover, it is the knowledge of the exact inference algorithm that allows us to propose a plausible implementation within the transformer architecture and to verify that what is being learned in practice is compatible with it. As such, while we do agree with the referee that the type of structure we consider is indeed quite narrow, our choice is motivated by solid theoretical grounding. As argued in the response to referee Rk87, our work should be taken in the context of mechanistic interpretation, for which the state of the art is centered on simple tasks such as histogram counting. Our setting offers an important stepping stone toward the understanding of models dealing with complex data and lies closer to natural language processing, where transformers are ubiquitous but not well understood (even in simplified models of language, see the discussion on context-free grammars with reviewer Rk87).
- Clarity of the introduction: We are definitely interested in feedback to improve the clarity of our work for a wider audience, and we thank the referee for pointing this out. If possible, we would greatly appreciate further details from the referee regarding which part of our introduction was particularly difficult to understand. We assume that it concerns the results around Context Free Grammars (i.e., the second paragraph); is this correct? If so, we would be glad to attempt a rewriting in the next iteration of reviews.
- On the possibility of working with data with other types of hierarchies: We agree with the referee that leveraging a similar type of analysis beyond fixed topology binary trees would be of great interest. As mentioned in our conclusion, we believe this is a clear direction for future work. Indeed, our current paper is, in our opinion, already quite dense for the 8-page format of ICML.
- On the link to broader structure learning: We thank the reviewer for pointing us towards motif learning, which we were not familiar with. It indeed appears that our hierarchical model may be an interesting setting to explore this idea, as blocks of symbols sharing common ancestors can naturally be interpreted as motifs. Therefore, the fact that in our model learning essentially takes place through the identification of larger and larger clusters, reconstructing higher and higher levels of the hierarchy, does support the idea that motif learning may facilitate sequence memorization, for instance. We will add a mention of some references on the topic, such as Wu, S., Thalmann, M., & Schulz, E. (2023), in the conclusion of the next iteration of our paper.
This paper investigates how a vanilla transformer encoder learns to infer latent hierarchical structure from data. The authors introduce a synthetic hierarchical tree-structured data model with a tunable filtering parameter that controls the depth of correlations in the sequence. Using this controlled setting, they train vanilla encoder-only transformers on two tasks: (i) predicting the hidden root label (root classification) and (ii) masked token prediction (MLM), and compare the models’ behavior to the optimal Bayesian inference (belief propagation, BP) on the tree. The key findings are that transformers achieve near-optimal accuracy on these tasks and produce well-calibrated probability predictions closely matching the BP oracle even on novel inputs. The network appears to learn hierarchical dependencies gradually during training, first capturing short-range (local) correlations and then longer-range ones.
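To fix ideas, the BP oracle used in the root-classification comparison amounts to a standard sum-product upward pass on the tree; the sketch below assumes a full binary tree, a single parent-to-children tensor T, and a root prior, in my own notation rather than the authors' implementation.

```python
import numpy as np

def bp_root_posterior(leaves, T, prior):
    """Exact root posterior on a full binary tree via the upward sum-product pass.
    leaves: observed symbols at the bottom level (a power-of-two count);
    T[a, b, c] = P(left=b, right=c | parent=a); prior: distribution over root symbols."""
    q = T.shape[0]
    msgs = [np.eye(q)[s] for s in leaves]          # leaf messages: one-hot on observations
    while len(msgs) > 1:
        # Combine each pair of sibling messages into a message to their parent.
        msgs = [np.einsum('abc,b,c->a', T, msgs[i], msgs[i + 1])
                for i in range(0, len(msgs), 2)]
    post = prior * msgs[0]
    return post / post.sum()
```

The calibration claim is then that the transformer's softmax over root labels closely tracks this posterior on held-out sequences, not merely that it predicts the same argmax.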
Questions to Authors
- Did you observe the sequential learning of hierarchical correlations across multiple training runs or random seeds? In other words, is this progression reliably reproducible, and how sensitive is it to factors like learning rate or initialization?
- What would happen if the number of transformer layers is set larger than needed (i.e., over-parameterised)? This is interesting since, as seen in works like Mixture-of-Experts, not all parameters may be necessary to solve a given task.
- How do you ensure that the leaf nodes are order-dependent? In natural language, the grammar tree depends on the actual words at the leaves, but in the synthetically crafted model, there is no inherent guarantee of order dependency. Could you elaborate on how this aspect is handled or justified in your framework?
Claims and Evidence
The central claims appear to be generally supported by convincing evidence, including: (i) a transformer can approximate the exact tree inference algorithm (BP) and produce calibrated posterior probabilities, and (ii) transformers learn hierarchical correlations in a progressive manner during training. No major claims appear unsupported.
Methods and Evaluation Criteria
The methods and evaluation setup are appropriate and well-designed for the research questions.
Theoretical Claims
This paper is primarily empirical and appears to build upon established theoretical frameworks.
Experimental Design and Analysis
No particular issues found.
Supplementary Material
I did not run the code provided in the supplementary material but tried to check if the code aligns with the vanilla transformer implementation in the paper. No obvious issues found.
Relation to Prior Literature
This work lies within the literature on the interpretability of transformers on structured tasks, which has been explored in the context of formal languages and syntactic trees.
Essential References Not Discussed
Not found.
Other Strengths and Weaknesses
This work distinguishes itself by presenting a controlled framework that manipulates hierarchical structures through a tunable filtering parameter. However, a limitation is its divergence from the complexities encountered in real-world transformer applications, such as those used in language models. Expanding the discussion to explore potential implications and applications in practical settings could further strengthen the contribution. Additionally, the paper does not compare the transformer’s performance to alternative approaches (besides the BP oracle). For instance, could a simple feed-forward network that takes all tokens as input solve the root classification task? Or might an RNN, such as an LSTM with sufficient capacity, also approximate BP? Including such comparisons would help emphasise the unique advantages of the transformer architecture.
Other Comments or Suggestions
The term "structured data" in the title is quite broad and can imply a wide variety of complex, real-world patterns. However, in this work, the focus is on a synthetic binary tree scenario with a tunable filtering parameter.
We thank the reviewer for their feedback, and answer the points that they have raised.
- On the weakness point about the data being far from real-world one: We do agree that the data we used is far from, say, natural language. However, the complexity of real data strongly limits one’s understanding of what the model does given the lack of an objective ground truth. By putting ourselves in a simplified and controlled setting, we believe we uncovered a nontrivial way transformers can learn (see also the responses to reviewers Rk87 on context-free grammars and ivGY on the theoretical grounding of our work).
- Regarding the possible comparison with other models apart from BP: We agree that a comparison with other machine learning models is possible and interesting, but believe it would fall in the category of performance comparison. While there is no reason to believe that other architectures could not optimally solve the problem given enough data, the attention mechanism offers a significant advantage in terms of mechanistic interpretability, which is one of the central focuses of our work. Moreover, the transformer architecture has been chosen for its relevance, as it is ubiquitous in applications toward the analysis of sequences (text, amino acids…), and is known to be able to effectively implement algorithms. Note that, as mentioned in the response to reviewer Rk87, the works by Wyart and co-workers and Mei have demonstrated the ability of CNNs to implement belief propagation in similar hierarchical data models. However, the fact that the architecture and its convolutional filters mirror the tree structure of the data model limits the generality of their findings. Finally, note that the versions of BP obtained from the factorized graphs can be seen as other approximate algorithms that we compared the transformer to. In the root prediction task, for instance, the fully factorized BP corresponds to the well-known Naive Bayes estimator.
- On the reviewer’s comment about the term “structured data”: We in fact agree with them, although we had hoped that the subsequent qualification “hierarchical filtering” clarified the narrower context of our work. Nonetheless, we would be willing to change our title to, e.g., ‘structured sequences’ if they believe that it better describes the scope of the paper.
We now answer the specific questions formulated by the referee:
- Robustness across different random seeds and model weight initializations: We indeed ran experiments with several different random seeds and saw no qualitative differences. Given the large number of training epochs, we did not find particular sensitivity to learning rates and initializations, and attempts at different learning rate schedules did not yield significant differences either. As our paper already includes a large number of figures, we did not think it judicious to also include the learning dynamics for different instances, but we are open to including them if the referee deems it necessary.
- Using a larger-than-needed number of layers: Experiments were carried out with up to 6 transformer layers, yielding the same qualitative results in terms of training dynamics and sample complexity. Matching the number of layers to the depth of the data hierarchy, though, leads to the most interpretable attention maps, which can be seen in Fig. 4. There, we indeed show the attention maps resulting from models trained on k-filtered data models, whose correlation structure involves fewer levels than the network has layers; as such, all intermediate cases in Fig. 4 fall into this category of a larger-than-needed number of transformer layers. As shown in the maps, what occurs is that the computation may be "diluted" over more layers than necessary, while some attention maps remain close to unused due to the presence of skip connections in the architecture. Consider e.g. the second-to-last row of Fig. 4: the required mixing is carried out in the first two transformer layers, while the last two do not contribute.
- Order dependence of the leaves: The hierarchical model producing the data is in fact fully order-dependent, as the parent-to-children production rules are described by a transition tensor that is not symmetric (with overwhelming probability). Our data therefore behaves like natural language, in which order counts. On the transformer side, we employ standard positional embeddings to explicitly add this information to the representation of each leaf, allowing the architecture to accommodate the order dependence of our model just as in natural language processing.
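As a toy illustration of this point, with a generic random tensor standing in for our transition tensor and arbitrary symbol indices:

```python
import numpy as np

rng = np.random.default_rng(1)
q = 4
T = rng.random((q, q, q))
T /= T.sum(axis=(1, 2), keepdims=True)   # T[a, b, c] = P(left=b, right=c | parent=a)

a, b, c = 0, 1, 2
print(T[a, b, c], T[a, c, b])   # generically different: swapping the two children
                                # changes the probability, so leaf order matters
```

The same asymmetry propagates down the tree, so permuting the leaves changes the likelihood of a sequence.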
The paper studies how simple transformer models learn to perform root and leaf inference (corresponding to classification and masked-language modeling tasks) on a synthetic generative hierarchical model of data on a regular tree of depth l, belonging to the class of context-free grammars. For such a model, exact inference can be done using belief propagation (BP). The paper demonstrates that (i) Transformers learn the same marginals as BP; (ii) Increasingly deep levels of the grammar are learned sequentially by the transformer as training time increases; (iii) Probing experiments suggest that the transformer reconstructs the grammar’s structure across its layers. Finally, the authors (iv) propose a theoretical implementation of BP in an l-layer transformer architecture.
Update after rebuttal
Given the rebuttal and the additional results provided on the efficient implementation of the belief propagation algorithm within the considered transformer architecture, I have raised my score from 2 to 3 and now lean toward acceptance. I did not raise the score to 4, as I still find several of the paper’s contributions to be largely incremental.
Questions to Authors
- Can you comment on the discussion above on BP vs inside-outside, and on the sufficiency of l layers in contrast to previous approaches?
- Can you elaborate more about how you place your work among the existing literature and what you think are the most important novel scientific contributions for publishing it at ICML?
Claims and Evidence
The claims in the submission are convincing and scientifically supported. The paper presents both theoretical and empirical evidence for transformers implementing BP-like inference.
Methods and Evaluation Criteria
The methods are well justified. The use of a synthetic data model enables controlled and interpretable experiments.
Theoretical Claims
The paper provides a theoretical implementation of BP inside an l-layer transformer architecture, which appears to be mathematically correct.
Experimental Design and Analysis
The experiments are scientifically sound and adequately support the paper’s claims.
Supplementary Material
I reviewed the supplementary material of the paper.
Relation to Prior Literature
The paper is of incremental nature with respect to previous literature in the area. Findings similar to those obtained here with the transformer, about learning probabilistic graphical models on l-level trees, have already been obtained for CNNs (Cagnetta et al., 2024; Mei, 2024; Cagnetta & Wyart, 2024). While the paper provides insights into transformers approximating BP, the authors should clarify whether their work offers new insights beyond previous CNN-based results or confirms those results in a different architecture. Cagnetta & Wyart (2024) also study numerically and theoretically the progressive learning of hierarchical correlations with transformers, although as a function of the number of training points, while the present submission empirically studies the problem as a function of training time. Finally, Zhao et al. (2023) and Allen-Zhu & Li (2023) study theoretically and empirically whether and how transformers implement the optimal inference algorithm when learning context-free grammars on non-regular trees (i.e., the inside-outside algorithm).
Essential References Not Discussed
N/A
Other Strengths and Weaknesses
Other strengths:
The paper is well written and organized. It considers an interesting and timely problem, namely how machine learning models learn the hierarchical structure of data such as language.
Other weaknesses:
One of the paper’s main novel contributions is providing an implementation of the optimal inference algorithm for the considered model, i.e., BP, in an l-layer transformer architecture. Previous results on implementations of the inside-outside algorithm – a generalization of BP that relaxes the assumption of a fixed tree topology – required more layers. However, it remains unclear whether the new proposed construction is more efficient only because of the fixed-topology assumption, which is unrealistic in practice.
Furthermore, the authors claim their solution to be “efficient”. However, from Appendix F, it seems that they do not control the number of neurons required in the MLPs to update messages, just leveraging the universal approximation property – which can, however, require an exponentially large number of neurons in the input dimension.
Other Comments or Suggestions
The appendix should be moved after the references.
We thank the reviewer for carefully reading our work and providing valuable feedback.
On the efficiency of the BP implementation within the transformer and the role of the MLPs: We had omitted an additional point from the Appendix, namely a precise proposition for performing the update of Eq. 22 through a two-layer fully connected network with hidden neurons. We have added the explanation back in the revised version, which we cannot yet reupload. Unfortunately, the character limit prevents us from detailing the construction here, but should the referee request it, we could add it in our next reply.
We now answer their two questions directly.
- The reviewer is correct in pointing out that here BP is equivalent to a simplification of the inside-outside algorithm (IO) for context-free grammars (CFGs) on a fixed topology, and is not more efficient per se. We argue that this is a feature of our setting, as this difference means that the algorithmic complexity of optimal inference is linear in the sequence length, whereas it is cubic for the IO. As pointed out in Khalighinejad et al. (2023), this cubic complexity means that there must either be some approximation in the transformer implementation of the IO, self-attention having a complexity that is only quadratic in the sequence length, or the network depth must be scaled with the context length (see the implementation proposed in Zhao et al. (2023)). Therefore, the evidence brought forth by Zhao et al. (2023) and Allen-Zhu & Li (2023) points toward transformers learning something very close to the IO for CFGs, but there is still a major open question as to what they do precisely. Relative to CFGs, the fixed topology is advantageous as it allows us to understand more precisely how transformers closely align with the exact inference algorithm, notably by leveraging the filtering procedure described in our work, which is not easily generalizable to CFGs. While we agree that our assumption is not as realistic in practice, it thus allows us to go much further in the mechanistic interpretation of the transformer. On the other hand, we believe our setting is at least as realistic as many mechanistic interpretation works (considering, e.g., histogram counting on integers).
- We believe that our work is at the crossroads between different bodies of literature. First, as described above, it studies a complex yet completely controlled task that allows mechanistic interpretation. Second, it builds upon the body of work of Wyart et al. on hierarchical models with, in our opinion, a central improvement through the introduction of the hierarchical filtering procedure. This filtering uniquely allows us to study the learning dynamics and therefore to also understand how insufficiently trained architectures might fail—here using an incomplete correlation structure. In doing so, it allows us to make contact with the expanding literature on staircases in learning dynamics and simplicity biases in machine learning architectures (Refinetti et al., 2023; Rende et al., 2024; Bardone & Goldt, 2024). In a nutshell, we believe that our most significant scientific contribution is to provide a truly comprehensive study of how a markedly non-trivial task, which shares similarities with practical natural language processing problems, is implemented in transformers.
Finally, on their comments:
the authors should clarify (...) new insights beyond previous CNN-based results (…)
In these CNN architectures, the mechanistic interpretation is somewhat trivial, as the convolution filters are made to mirror the tree structure. This is not the case in the transformer architecture, which has to learn this structure. We demonstrate that it implements it incrementally through its attention layers and progressively in training. Transformers are also ubiquitous for sequence-like data (text, amino acids…) whereas CNNs are seldom employed in this context, highlighting the importance of understanding how transformers digest long-range correlations.
Cagnetta & Wyart (2024) also study (...) theoretically the progressive learning of hierarchical correlations with transformers (...) while the present submission empirically studies the problem (...)
While the study of Cagnetta & Wyart (2024) is very interesting, we would highlight that, as it relies on a signal-to-noise ratio analysis that is tractable for their uniform transitions, it does not provide predictions for any specific architecture. While more empirical in spirit, we would again emphasize that our filtering strategy allows one to probe the behaviour of transformers, be it as a function of the sample complexity or of the training time. We therefore believe that the two approaches are complementary and not redundant.
Thanks for your response.
Since the authors mention that the omitted parts are already prepared, I would encourage them to share them at this stage.
We thank the reviewer for their interest in our proposed implementation.
A possible, non-parsimonious way to perform the update of Eq. 22 with a two-layer fully-connected network with hidden neurons is the following.
In the first layer, one can readily select the appropriate entries in the embedding vector to output the corresponding three terms for all relevant pairs.
Then, for each transition, the argument of the sum in Eq. 22 can be obtained from these terms.
The trace over the paired indices is then performed by the second layer of the fully-connected block. For each transition, it reads the three corresponding hidden units and multiplies them by the same learned weights (using the appropriate positional embedding entry), while the summation is done as usual. In practice, only a small number of these weights are non-zero for the transition tensors we consider. Note that this exact operation would require squared activations, but it can be approximated by a piecewise-linear network. The other updates that are to be carried out by the MLPs (Eqs. 24-26) have the same reweighted-sum structure, and can therefore also be approximated with hidden units implementing the correct transition rates.
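As a toy, self-contained illustration of the last point (this is not the construction from the Appendix, whose exact Eq. 22 we do not reproduce here): a transition-weighted sum of products of incoming messages can be written with linear maps plus squared units via the polarization identity, which a piecewise-linear network can in turn approximate on a bounded domain.

```python
import numpy as np

def product_via_squares(x, y):
    """Polarization identity: recover x*y from squared activations only."""
    return 0.25 * ((x + y) ** 2 - (x - y) ** 2)

q = 3
rng = np.random.default_rng(0)
T = rng.random((q, q, q))               # stand-in transition tensor (illustrative)
m1, m2 = rng.random(q), rng.random(q)   # stand-in incoming messages (illustrative)

# Generic BP-style update: a transition-weighted sum of products of messages.
exact = np.einsum('abc,b,c->a', T, m1, m2)
approx = np.einsum('abc,bc->a', T, product_via_squares(m1[:, None], m2[None, :]))
assert np.allclose(exact, approx)
```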
We hope that this proposed implementation will convince the reviewer that there is no need for an exponential number of neurons in the MLP to approximate the Belief Propagation algorithm in a transformer, and thank them again for their question.
The authors build a fully‑controlled playground in which sequences are generated by a hierarchical probabilistic grammar on a fixed binary tree. Since exact inference on this model is belief‑propagation (BP), one can compare transformers to a known oracle and propose an explicit BP‑in‑attention construction. Two tasks, root classification and masked token prediction, show that a 4‑layer vanilla encoder trained by SGD achieves BP‑level accuracy and produces calibrated logits that are similar to the oracle's on unseen data. Training‑time probes reveal a staircase dynamic: the network first captures leaf‑level correlations, then progressively longer‑range dependencies. Attention maps make this process visible, and the authors’ hierarchical filtering technique (artificially limiting which levels carry correlation) confirms that deeper layers specialise to deeper tree levels.
The paper presents a cohesive mechanistic interpretation by combining oracle comparisons, training dynamics, layer-wise probes, and explicit algorithmic hypotheses aligned with attention patterns. However, findings rely on fixed-topology trees and no benchmarks on alternative architectures (CNNs, RNNs) are provided, limiting the demonstrated transformer-specific advantage. Initial reviewer concerns about novelty and efficiency were mitigated by additional clarifications and experiments provided in the rebuttal. Reviewers appreciate the extensive experimental analysis and useful insights, despite the synthetic scope, making the paper suitable for acceptance.