Early learning of the optimal constant solution in neural networks and humans
Abstract
Reviews and Discussion
This paper proposes that the early-time learning of a constant solution that is input-independent and a pure function of the label distribution is a universal property of learning dynamics in neural networks and humans. It proves that linear networks with biases learn this early-time OCS and slightly more complicated and nonlinear networks exhibit qualitatively similar learning curves. Such learning dynamics are characterized by an initial dip in TNR for top levels of the output hierarchy in a hierarchical learning task, and such learning curves are also present when humans are trained and tested on such a task. Altogether, the paper seeks to provide theoretical and empirical evidence that early-time mimicry of the label distribution may be a general property of learning systems seeking to minimize error on a task.
Strengths
A major strength of the paper is that it is original, and makes claims about a broad range of phenomena. Generally, I think there isn't enough work studying how learning systems -- both artificial and natural -- exhibit qualitatively similar behaviors in various learning regimes. This paper studies linear networks (with and without bias), CNNs, and humans, and its contributions span from machine learning theory of linear networks to human psychological experiments, which is a breadth that is rare as well as nice to see. If it is true that the early time learning of the OCS is shared between a broad range of artificial networks and robustly replicated across humans, this would be an interesting and significant result. The inclusion of human data that qualitatively supports the hypothesis is a major strength.
Weaknesses
The paper is quite confusing in terms of how it is laid out and what the scope of each claim is. In my opinion, it tries to do a bit too much. For instance, tensor trickery related to the NTK feels out of place/a low-order concern when you didn't even take the time to define the hierarchical learning task clearly. I had to look through the Appendix as well as (Saxe, 2019) to get a sense of the task. Some specific weaknesses:
- The scope of the claims being made is not clear. Either you are 1) remarking that the learning curves of linear networks, nonlinear networks, and humans on a certain class of task look similar but the mechanistic reasons are not well understood, and this is interesting and warrants further study, or 2) making a claim that ALL neural networks are marked by this early-time behavior, and that it is driven by mechanisms you identify. I would be more sympathetic to the contributions of this paper if you were making claim (1), but since you use the words "universal" and "mechanistic" in your abstract, I take it you are making claim (2). If this is the main claim, then the evidence presented is inadequate as justification, for the reasons that follow.
- You spend much of the paper talking about linear networks with biases, but in Section 6 make the (extremely strong) claim that such reversion can equivalently arise from input correlations. If it is indeed true that very general covariate structure can give rise to this behavior, then why have you spent most of the paper reasoning about the contrived and simple setting of linear networks with bias? Phrases like "might hence be ubiquitous beyond these datasets" arouse suspicion. Also, how can I believe that early-time learning of OCS is "universal" when it vanishes in even linear networks if you ablate the biases?
- I put very little weight on experiments that work with MNIST and CIFAR. One can conjure up evidence for any hypothesis by studying a suitable version of such datasets with some contrived architecture. They are simple tasks that might have been useful 10 years ago but certainly do not convey any information in 2024, in my opinion. I think this paper is at its strongest when discussing linear networks where claims are made precise, or human studies where the empirical trends found in these (interesting and novel!) experiments can be unpacked. I think trying to convince a reader that this is a universal property by showing toy experiments "in the middle" of linear/humans (eg. CNNs) is not a good strategy, especially since by the end I am not even convinced the mechanism for this phenomenon is the same across even just linear networks and humans. Spending some time providing evidence for this latter claim and dropping any attempts at universality claims justified only by MNIST/CIFAR, would in my opinion be a good idea.
- You remark that "characteristic learning signatures" arising from early time learning of the OCS involve a sort of step-wise learning structure. While I'm aware a variety of toy synthetic settings studied by theorists exhibit this behavior (see [7], for instance), no non-trivial data (ImageNet, C4) gives rise to dynamics that look like this (learning curves smoothly go to low loss rather than being step-wise). So certainly the word "universal" does not describe these types of learning curves that are "characteristic" of early-time learning of OCS.
All in all, I find the presentation and scoping confusing and the claims (where I can understand them) unjustified. I laud the broadness and ambition of the study, but find it ultimately somewhat unconvincing in its present form. Originally, I gave this a rating of 3, but on reflection, I think the inclusion of human data that qualitatively exhibits early-time OCS is on its own interesting enough to bump that up to a 5.
[1] Saxe, Andrew M., James L. McClelland, and Surya Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks." arXiv preprint arXiv:1312.6120 (2013).
[2] Belrose, Nora, et al. "Neural Networks Learn Statistics of Increasing Complexity." arXiv preprint arXiv:2402.04362 (2024).
[3] Kang, Katie, et al. "Deep neural networks tend to extrapolate predictably." arXiv preprint arXiv:2310.00873 (2023).
[4] Bordelon, Blake, Abdulkadir Canatar, and Cengiz Pehlevan. "Spectrum dependent learning curves in kernel regression and wide neural networks." International Conference on Machine Learning. PMLR, 2020.
[5] Simon, James B., et al. "The eigenlearning framework: A conservation law perspective on kernel regression and wide neural networks." arXiv preprint arXiv:2110.03922 (2021).
[6] Bordelon, Blake, Alexander Atanasov, and Cengiz Pehlevan. "How Feature Learning Can Improve Neural Scaling Laws." arXiv preprint arXiv:2409.17858 (2024).
[7] Abbe, Emmanuel, et al. "The staircase property: How hierarchical structure can guide deep learning." Advances in Neural Information Processing Systems 34 (2021): 26989-27002.
[8] Kumar, Tanishq, et al. "Grokking as the transition from lazy to rich training dynamics." arXiv preprint arXiv:2310.06110 (2023).
Questions
- I do not understand the relationship between this work and a broad range of work [4, 5, 6] on spectral learning curves, where it is known that features are learned hierarchically in roughly the order of their power in the generative function. The original Saxe paper [1] demonstrates a similar result in a linear setting. How is the early-time learning of a constant solution related to this learning of eigenmodes in order of decreasing power? Is this a zeroth order term in some sort of a Harmonic expansion of the target function? I would like a clear and detailed mathematical exposition of this fact, either in the author response, or an update to the manuscript. I know, for instance, sometimes networks begin learning by attempting a kernel solution before transitioning to feature learning ("grokking", see [8]) -- viewed this way, are such settings examples of learning the "optimal linear solution" at early time?
- Are these results only true for hierarchical tasks? Again, the scope of where and when this is true is not clear. For example, the MNIST experiment (effect is weaker in Orth. MNIST) makes me think the results only hold in hierarchical settings (again, challenging "universality").
- I don't quite understand the claim around "generic input correlations" -- as I see it, you show that a range of settings where input data (independent of labels) has low-dimensional structure exhibit learning of OCS. But I thought the OCS was about mimicking the structure of labels independent of input, not the other way around? What does the correlation structure of X have to do with learning the optimal distribution over labels (X is the input matrix)?
… Section 6 makes the (extremely strong) claim that such reversion can equivalently arise from input correlations.
While it might appear like a strong claim, we do believe that fairly typical correlations in the input data give rise to early learning of the optimal constant solution. For nonlinear networks, we conduct experiments without bias terms in the architecture (Fig. 5 of the revised draft), so that the effect of the data is clearly isolated in this case.
However, the initial submission did not contain a figure examining a case of deep linear networks without bias terms, but with input correlations that exhibit the constant eigenvector. We have added a solvable case where these conditions hold and included it in Fig. 17 in Appendix A.11. In the bottom row of the figure, we can see that the first SVD mode is indeed exactly equivalent to the OCS mode. The right panel highlights how the network is driven towards the OCS up until the time-point when the second effective singular value (which is quite close in time) is learned. We hope that these additional results in linear networks make our claim about the role of input correlations in OCS learning more precise.
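To illustrate the same mechanism outside the paper, here is a minimal, self-contained sketch (our own toy construction, not the setting of Fig. 17: a single linear layer without any bias terms, trained by plain gradient descent on inputs that share a large positive mean, so that the constant direction dominates the input correlations, and with random placeholder labels). Under the account above, the distance of the outputs to the OCS should dip well before the distance to the targets does:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m, lr = 8, 20, 14, 0.002
# Inputs share a large positive mean, so the constant direction dominates the
# input correlation matrix; labels are random multi-hot placeholders.
X = 3.0 + 0.5 * rng.standard_normal((n, d))
Y = (rng.random((n, m)) < 0.3).astype(float)
y_bar = Y.mean(axis=0)                       # the optimal constant solution (OCS)

W = np.zeros((d, m))                         # no bias term anywhere
for step in range(1, 20001):
    out = X @ W
    W -= lr * X.T @ (out - Y) / n            # plain gradient descent on the MSE
    if step in (1, 20, 20000):
        print(step,
              "dist to OCS:", round(float(np.linalg.norm(out - y_bar)), 2),
              "dist to targets:", round(float(np.linalg.norm(out - Y)), 2))
```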
If it is indeed true that very general covariate structure can give rise to this behavior, then why have you spent most of the paper reasoning about the contrived and simple setting of linear networks with bias?
A bias towards the OCS is observed over the course of learning in human and artificial learners. Prior work by Saxe et al. (2013) has established linear networks as a powerful but theoretically transparent model to study such learning dynamics. Its power stems from the fact that it is agnostic about the specifics of the data: instead, it makes predictions about the learning dynamics solely in terms of the broad statistical features, in our case the leading eigenmode of the data correlation structure.
The inclusion of bias terms is the minimal modification needed to observe OCS learning. Only after developing the theory for the simple case of bias terms in Section 3 is it revealed that their effect formally resembles a change in input correlation. This theoretical insight, which the linear network approach enabled, is then expanded on in the second part of the paper.
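As a complementary toy illustration of the bias-driven route (again our own minimal sketch, not the paper's exact setup: a single linear layer with a bias term, one-hot inputs, random placeholder labels), the bias is effectively updated by every example and therefore pulls all outputs towards the mean label before the per-item weights catch up:

```python
import numpy as np

rng = np.random.default_rng(0)
n_items, n_labels, lr = 8, 14, 0.2
X = np.eye(n_items)                                        # one-hot inputs
Y = (rng.random((n_items, n_labels)) < 0.3).astype(float)  # placeholder labels
y_bar = Y.mean(axis=0)                                     # optimal constant solution

W = 1e-3 * rng.standard_normal((n_items, n_labels))
b = np.zeros(n_labels)
for step in range(1, 2001):
    out = X @ W + b
    err = (out - Y) / n_items                # MSE gradient factor
    W -= lr * X.T @ err                      # each weight row only sees its own item's (1/n-scaled) error
    b -= lr * err.sum(axis=0)                # the bias accumulates the error of every item
    if step in (1, 10, 2000):
        print(step,
              "dist to OCS:", round(float(np.linalg.norm(out - y_bar)), 2),
              "dist to targets:", round(float(np.linalg.norm(out - Y)), 2))
```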
Phrases like "might hence be ubiquitous beyond these datasets" arouse suspicion. Also, how can I believe that early-time learning of OCS is "universal" when it vanishes in even linear networks if you ablate the biases?
We now realise how our wording can suggest that OCS learning will necessarily be observed in every learning setting, which we do not believe to be true. As you correctly point out, the controlled linear network setting exactly identifies such a counter-example: If both bias and input correlations are completely removed, our theory predicts that OCS learning indeed vanishes.
However, we believe that this is not the usual case in real-world learning settings. Instead, architectural bias and data correlations are mild assumptions. As far as we know, bias terms are a common way to shift data and are the default in deep learning libraries.
As for data structure, only the statistics of the leading eigenmode are relevant for OCS learning. We present a theoretical argument in Apx. A.5.6 for why a weak symmetry in the data will already satisfy this premise on the eigenmode. We provide evidence in Fig. 6 (of the revised paper) that such correlations are indeed approximately present in the image datasets we tested.
I put very little weight on experiments that work with MNIST and CIFAR (...)
We share the view that the core contribution of our work lies in developing a theoretical model of OCS learning. In the response above, we have argued that our theory shows that data correlations and architecture are robust factors inducing OCS learning, but that it strictly applies to linear networks only. To be considered a candidate to explain observations in nonlinear learners, we needed to test and ablate these factors also in the nonlinear setting. We agree that such evidence can only be suggestive, and have made this limitation more clear in the discussion section. We hence cannot conclude that OCS learning is inevitably present, as the premises we identify – although plausibly true in many cases – are not necessarily fulfilled.
While we do ablations for artificial learners, for humans we cannot claim what the mechanism behind OCS learning is, but only point to candidate factors. We had addressed this in our discussion: “the ambiguity between architecture and data in driving the OCS does not allow us to determine the underlying mechanism in human learners.” Still, we aim to disentangle this question in future experiments via the systematic variation of input correlations between stimuli presented to human learners, which could point to the factors we identified, or possibly rule them out altogether.
You remark that "characteristic learning signatures" arising from early time learning of the OCS involve a sort of step-wise learning structure. While I'm aware a variety of toy synthetic settings studied by theorists exhibit this behavior (see , for instance), no non-trivial data (ImageNet, C4) gives rise to dynamics that look like this (learning curves smoothly go to low loss rather than being step-wise). So certainly the word "universal" does not describe these types of learning curves that are "characteristic" of early-time learning of OCS.
We agree that our results do not imply that early, step-wise OCS learning will necessarily be observed in any dataset. While the step-wise structure of early bias learning has been reported in some cases (Karpathy, 2019; Ye et al., 2021), it may not be visible in the loss curve in all cases: we observe this in the shallow linear setting in Fig. 2B, and also in our simulations with CelebA seen in Figure 15 in Appendix A.9. We were intrigued by your point and have added the loss to Figure 15 for completeness. There, although the loss curve is not step-wise, the output is still transiently driven towards the OCS, which is the core of our claim. While such evidence is suggestive of our claims being more broadly applicable, we still see the key contribution of our work in understanding elementary components of learning, such as in the simple settings we consider.
Karpathy, A. (2019) A Recipe for Training Neural Networks.
Ye, H. J., Zhan, D. C., & Chao, W. L. (2021). Procrustean training for imbalanced deep learning. ICCV.
Questions
I do not understand the relationship between this work and a broad range of work on spectral learning curves (...)?
Indeed, the linear network setting connects to the kernel learning literature, and we largely agree with your assessment. To detail:
A seminal bridge is made by the spectral learning-curve works you cite, with the kernel function replaced by a linear NTK, i.e. our Prop. 3. When comparing with these works, it is important to distinguish two types of early learning biases: a static ("NNGP") setting in which the bias is incurred when few samples are present but the loss has converged (as discussed in that literature), and a setting in which the bias arises through the learning dynamics, described by the NTK, which Saxe et al. considered.
Still, there is an intuitive connection between the settings: in the dynamical case, the bias results from the kernel-dominant modes giving the largest reduction in error signal and therefore having the highest gradient and learning rate. Similarly, in the static case, using samples to estimate the kernel-dominant modes will lead to the lowest expected error and will therefore be favoured by the optimization. Avidan et al. (2023) provide a formal discussion of the relation between NTK and NNGP.
Finally, feature learning enters this picture if the NTK is dynamic over learning: our setup is situated in this more general case (see, e.g., Atanasov et al., 2021; Kunin et al., 2024). In contrast, lazy learning corresponds to an NTK that stays close to its initialization for the entire course of learning. We show that it is the initialization NTK in Prop. 3 ("in early training") that incurs the bias, until changes in the NTK (corresponding to feature learning) move the output towards the full solution.
So indeed, OCS learning can be understood as a dominant mode in the early NTK, biasing the network towards the corresponding dominant output mode.
Because the mathematics are rather involved and we are unsure where a full exposition would be most helpful, we leave it at this overview for now, but are very happy to provide more detail!
References
Atanasov, A., Bordelon, B. and Pehlevan, C., 2021. Neural networks as kernel learners: The silent alignment effect.
Kunin, D., Raventós, A., Dominé, C., Chen, F., Klindt, D., Saxe, A. and Ganguli, S., 2024. Get rich quick: exact solutions reveal how unbalanced initializations promote rapid feature learning.
Avidan, Y., Li, Q. and Sompolinsky, H., 2023. Connecting NTK and NNGP: A unified theoretical framework for neural network learning dynamics in the kernel regime.
Are these results only true for hierarchical tasks? (...).
The results do not only hold true for the hierarchical task. In linear networks we show cases of OCS learning on non-hierarchical tasks in Appendix A.10 for the case of class imbalance, and in the newly added Appendix A.11 we show OCS learning in the case of uniformly distributed labels under input correlations. For nonlinear networks we show several results for non-hierarchical data in Appendix A.9 of the paper, where models are trained with different loss functions on tasks such as CelebA face attribute detection and a class-imbalanced version of MNIST.
We also think that there could be a misunderstanding about the Orth. MNIST task. In the particular task we specifically removed any input correlations and ablated architectural bias terms to show the absence of full OCS learning. The experiment is intended to highlight how OCS learning depends specifically on these two ingredients.
I don't quite understand the claim around "generic input correlations" -- as I see it, you show that a range of settings where input data (independent of labels (...)?
Indeed, the OCS will just output the average label, independent of input details. Disregarding bias terms here for simplicity, the input data comes into play in that its correlation statistics need to be such that the constant (all-ones) vector is the dominant eigenvector – a property which we show to hold for several image datasets in Fig. 6 (revised PDF); hence "generic" input correlations are sufficient.
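As a rough empirical check of this property, the following sketch can be run on any collection of flattened inputs (this is our own illustration; whether one uses the centred or the uncentred second-moment matrix should follow the paper's convention, which we assume here to be uncentred):

```python
import numpy as np

def constant_mode_alignment(inputs):
    """Check whether the constant (all-ones) direction dominates the input correlations.

    inputs: array of shape (n_samples, n_features), e.g. flattened images.
    Returns the absolute cosine similarity between the leading eigenvector of the
    (uncentred) second-moment matrix and the normalized all-ones vector.
    """
    sigma_xx = inputs.T @ inputs / len(inputs)   # input correlation matrix
    eigvals, eigvecs = np.linalg.eigh(sigma_xx)  # eigenvalues in ascending order
    leading = eigvecs[:, -1]                     # dominant eigenvector
    ones = np.ones(inputs.shape[1]) / np.sqrt(inputs.shape[1])
    return abs(float(leading @ ones))

# Toy example with random positive-valued "images": pixel positivity alone already
# pushes the leading eigenvector towards the constant direction.
fake_images = np.random.default_rng(0).random((500, 28 * 28))
print(constant_mode_alignment(fake_images))      # close to 1 in this toy case
```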
We hope these answers clarify the scope of our claims. Please let us know if you have any further questions!
I thank the authors for their response. I will keep my score.
Thank you for your response. In the meantime, please do not hesitate to let us know if there is anything else we can provide to address any remaining questions and concerns.
Thank you very much for your thorough review and the detailed and constructive feedback provided. We are encouraged by your assessment that our work spans a rare breadth from ML theory and experiments to human behaviour. We are also glad that you found the bridge to human experiments we develop to be a particular strength of our paper.
Weaknesses
The paper is quite confusing in terms of how it is laid out and what the scope of each claim is. In my opinion, it tries to do a bit too much. For instance, tensor trickery related to the NTK feels out of place/a low-order concern when you didn't even take the time to define the hierarchical learning task clearly. I had to look through the Appendix as well as (Saxe, 2019) to get a sense of the task. Some specific weaknesses:
Thank you for this valuable information on the readability of our work. We now clearly realise that a reference to the previous literature alone is not enough for a self-contained presentation of the task, and we have updated the PDF with a dedicated section on the hierarchical task. In an effort to improve the linear flow of the presentation, we have also moved the discussion of the NTK to the Appendix.
The scope of the claims being made is not clear. Either you are 1) remarking that the learning curves of linear networks, nonlinear networks, and humans on a certain class of task look similar but the mechanistic reasons are not well understood, and this is interesting and warrants further study, or 2) making a claim that ALL neural networks are marked by this early-time behavior, and that it is driven by mechanisms you identify. I would be more sympathetic to the contributions of this paper if you were making claim (1), but since you use the words "universal" and "mechanistic" in your abstract, I take it you are making claim (2). If this is the main claim, then the evidence presented is inadequate as justification, for the reasons that follow.
While we agree that we don’t have evidence to make claim (2) in a general sense, we do believe that our findings extend beyond (1) by manipulating the same factors which drive OCS learning across several artificial learners as well as providing general theoretical arguments. We now include an additional setting which directly demonstrates how OCS learning in linear networks can be purely driven by input correlations. In the answers below we detail these points. We have rephrased relevant sections in our paper to better reflect the precise nature of our claims.
This work employs an analytically tractable analysis of learning dynamics via linear networks, focusing on the bias towards the optimal constant solution (OCS) early in training. A 2-layer (and 1-layer) linear network with no bias terms is first studied, whose learning dynamics can be solved exactly (under a commutativity assumption). This analysis is then extended to include a bias term in the first layer, in the case of uncorrelated inputs, and it is shown that the same approach works in this case (since the commutativity condition remains satisfied). Empirical evidence for early convergence to the OCS is then provided, which does not occur for linear networks without bias terms, and a theoretical explanation is provided, showing that the OCS mode dominates the input-output covariance matrix when biases are present. This phenomenon is then related to human learning, showing that humans also display a bias towards the OCS early in learning. Finally, it is demonstrated that early bias towards the OCS is not solely driven by bias terms, and can instead be induced by properties of the data via the input-input correlation matrix.
Strengths
- Though simplicity bias has been observed empirically and is thought to be important for generalization, its origin lacks a theoretical understanding. This paper is significant in providing a concrete theoretical explanation for a particular kind of simplicity bias (OCS learning).
- Solves for parameter dynamics under a relaxed assumption (Proposition 1) compared to previous works (which typically assume uncorrelated inputs).
- Demonstrates that linear networks can provide meaningful predictions of non-linear models, including CNNs.
- The relation to human learning is interesting and unique.
Weaknesses
It would be helpful to include a brief description of the hierarchical learning task in addition to Fig 1 for those unfamiliar with the setup
Questions
Does the theoretical argument nicely extend to linear networks with more than two layers? In the multi-layer case, would a bias term present in an intermediate layer result in a bias towards OCS?
In what ways do the linear network solutions extend Saxe et al. (2014) precisely? Through the inclusion of bias terms, and a relaxed assumption (Proposition 1)?
We thank you for your thoughtful comments and constructive feedback. We are delighted that you found our work a significant contribution to our understanding of particular types of simplicity biases. We are also encouraged by the fact that you found our connections to human learning and nonlinear models interesting and unique.
Weaknesses
It would be helpful to include a brief description of the hierarchical learning task in addition to Fig 1 for those unfamiliar with the setup
Thank you for pointing this out. We agree with you on this point and have added a clear exposition of the hierarchical learning task to the main text of the paper that makes explicit reference to Figure 1. We hope that this change will improve the presentation of the paper.
Questions
Does the theoretical argument nicely extend to linear networks with more than two layers? In the multi-layer case, would a bias term present in an intermediate layer result in a bias towards OCS?
While the two-layer case captures the essential factors for OCS learning, our analysis generalises to the case of multiple layers in a notationally involved, but conceptually straightforward manner. The starting point to see this is the NTK in Proposition 3, which readily generalises to the multi-layer case (see Shi et al., 2022). The general pattern of Proposition 3 extends: bias terms in all layers affect OCS learning. The closer to the output a bias is located in the architecture, the more immediately it will affect the output statistics. Adding more layers attenuates the effect of bias terms in earlier layers on the output through downscaling by the initialization scale of deeper layers (the terms $\sigma_{\mathbf{W}^\ell}^2$ in Proposition 3).
Shi, J., Shea-Brown, E. and Buice, M., 2022. Learning dynamics of deep linear networks with multiple pathways. Advances in neural information processing systems, 35, pp.34064-34076.
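To make the layer-placement intuition above tangible, here is a small toy sketch (our own construction with assumed shapes and hyperparameters, not the paper's experiment): a two-layer linear network with small weight initialization, trained with the bias either in the first or in the second layer. Under the argument above, the second-layer bias should pull the outputs towards the OCS much faster, because the first-layer bias only reaches the output through the small second-layer weights.

```python
import numpy as np

n_items, n_labels, hidden, lr, init, steps = 8, 14, 32, 0.1, 1e-2, 150

def distance_to_ocs_after_training(bias_layer):
    rng = np.random.default_rng(1)
    X = np.eye(n_items)                                        # one-hot inputs
    Y = (rng.random((n_items, n_labels)) < 0.3).astype(float)  # placeholder labels
    y_bar = Y.mean(axis=0)                                     # optimal constant solution
    W1 = init * rng.standard_normal((n_items, hidden))
    W2 = init * rng.standard_normal((hidden, n_labels))
    b1, b2 = np.zeros(hidden), np.zeros(n_labels)
    for _ in range(steps):
        H = X @ W1 + b1
        out = H @ W2 + b2
        err = (out - Y) / n_items
        gW2, gb2 = H.T @ err, err.sum(axis=0)
        gH = err @ W2.T
        gW1, gb1 = X.T @ gH, gH.sum(axis=0)
        W1 -= lr * gW1
        W2 -= lr * gW2
        if bias_layer == 1:
            b1 -= lr * gb1    # bias far from the output: its effect passes through small W2
        else:
            b2 -= lr * gb2    # bias at the output: directly shifts the output mean
    return float(np.linalg.norm(out - y_bar))

print("bias in layer 1, distance to OCS:", round(distance_to_ocs_after_training(1), 2))
print("bias in layer 2, distance to OCS:", round(distance_to_ocs_after_training(2), 2))
```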
In what ways do the linear network solutions extend Saxe et al. (2014) precisely? Through the inclusion of bias terms, and a relaxed assumption (Proposition 1)?
We here discuss how we differ from Saxe et al. (2014) in particular. For a discussion of our general contribution, please refer to the global response above.
Indeed, we needed to generalise their assumptions to make progress on our results. Saxe et al. had mostly considered whitened inputs, as have several subsequent works like Braun et al. (2022), Kunin et al. (2024), or Shi et al. (2022). Proposition 1 relaxes this assumption: it provides a simple condition that allows us to derive learning dynamics for several interesting datasets which contain input correlations.
Second, to the best of our knowledge, we for the first time extend predictions made by the formalism directly to human learners, where prior work had conjectured similarities (Rogers & McClelland, 2005; Saxe et al., 2019). We were able to directly demonstrate some behavioural similarities in terms of OCS learning. We argue in Appendix A.5.6 that this is likely due to symmetries in the data that make the theory's premises approximately apply in many settings.
References
Rogers, T.T. and McClelland, J.L., 2005. A parallel distributed processing approach to semantic cognition: Applications to conceptual development. In Building object categories in developmental time (pp. 353-406). Psychology Press.
Saxe, A.M., McClelland, J.L. and Ganguli, S., 2019. A mathematical theory of semantic development in deep neural networks. Proceedings of the National Academy of Sciences, 116(23), pp.11537-11546.
Dear reviewer ogfv, as the discussion period ends quite soon, please let us know if you have any other questions!
Thank you for addressing my concerns and questions. I will maintain my score.
Thank you again for your review and for maintaining your positive assessment of our work!
The paper shows that in the early phases of training neural networks learn the “optimal constant solution” (OCS), i.e., they mirror the statistics of the labels while being invariant to the input points. It considers deep linear networks with bias terms and derives the exact learning dynamics based on Saxe et al. (2014; 2019). It then proves that deep linear networks learn the OCS early in training. To support its findings, the paper performs experiments on both deep non-linear neural networks and human learners that display qualitative agreement with the theoretical analysis.
Strengths
- The study contributes to the broader understanding of simplicity biases in machine learning, specifically in the early learning phases.
- The paper’s exact characterization of the learning dynamics of deep linear networks with bias terms is a novel contribution. This analysis is theoretically sound and provides new insights on how simplicity emerges during the initial training phase.
- The theoretical findings are supported by learning experiments involving both artificial neural networks and human learners, bridging machine learning and neuroscience.
Weaknesses
- The most significant weakness of this work is the lack of a definition and clear description of the studied “hierarchical task”. While the authors describe it as a "hierarchical category learning task" with hierarchical levels, there’s no clear explanation of the precise structure or production rules, such as the nature of the relationships between input categories and output labels. Without knowledge of the distribution of X, Y, and precise definitions of the hierarchical levels, I find it practically impossible to fully understand the study, particularly the empirical aspects. The link between the model of data and the OCS is also not discussed.
- The continuous correct-rejection scores, which are crucially used to infer OCS learning, are not defined within the main text.
- The figures and the corresponding captions are poor and challenging to interpret. Figure 1 contains 6 panels that lack sufficient standalone explanations. In Figure 2, the reader does not know what different colors represent.
- Proposition 3 must assume some initialization statistics for the weights that are not specified.
- In general, the paper presentation is dense, often lacking detailed explanations or explicit definitions. It also frequently refers the reader to the appendix, detracting from readability.
In summary, the paper has potential, offering some interesting insights into early learning in neural networks and connections to human cognition. However, I think it requires significant rewriting and, in particular, improvements in clarity, structure, presentation, and detailed explanation and discussion of the empirical results.
Typos:
(L253) The word “singular” is repeated twice.
Questions
See comments above.
We thank you for your thoughtful comments and feedback. We are glad that you found our work to be a useful and theoretically sound contribution to our understanding of simplicity biases in initial training phases. Furthermore, we are delighted that our approach which bridges human experiments, ML experiments, and theory was received as a strength of our work.
Weaknesses
The most significant weakness of this work is the lack of a definition and clear description of the studied “hierarchical task”. While the authors describe it as a "hierarchical category learning task" with hierarchical levels, there’s no clear explanation of the precise structure or production rules, such as the nature of the relationships between input categories and output labels. Without knowledge of the distribution of X,Y , and precise definitions of the hierarchical levels, I find it practically impossible to fully understand the study, particularly the empirical aspects.
The hierarchical task requires learning a mapping from one-hot input vectors to the output vectors depicted in Fig. 1B. Each output vector is “three-hot”, i.e. it has three active entries/labels. The hierarchical structure arises from the similarity between output vectors: some labels are more general and correspond to more than one input vector, while labels corresponding to the bottom of the hierarchy are specific to a single input vector. The task is motivated by the literature on semantic cognition and leverages the fact that semantic information is usually hierarchically structured (Rogers & McClelland, 2004; Saxe et al., 2019).
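For concreteness, here is a minimal sketch of such a "three-hot" label construction (our own hypothetical binary tree with 8 items and 2/4/8 labels per level, and assuming one active label per level; the paper's actual hierarchy is the one depicted in Fig. 1B):

```python
import numpy as np

# Hypothetical binary hierarchy over 8 items:
# 2 "top" labels, 4 "middle" labels, 8 "bottom" labels -> 14 output units in total.
n_items = 8
n_top, n_mid, n_bottom = 2, 4, 8

def three_hot_label(item):
    """Return the 'three-hot' target vector for one item: one active label per
    hierarchy level (top, middle, bottom)."""
    y = np.zeros(n_top + n_mid + n_bottom)
    y[item // 4] = 1                       # top label: shared by 4 items
    y[n_top + item // 2] = 1               # middle label: shared by 2 items
    y[n_top + n_mid + item] = 1            # bottom label: unique to the item
    return y

X = np.eye(n_items)                        # one-hot inputs
Y = np.stack([three_hot_label(i) for i in range(n_items)])
print(Y.mean(axis=0))  # the OCS: top-level labels have the highest average activation
```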
We now realise that a reference to the literature alone as in our initial submission is insufficient for a self-contained presentation of our approach. We have revised the paper and have added a dedicated section that introduces the task.
The link between the model of data and the OCS is also not discussed.
Our work identifies two complementary factors driving OCS learning: bias terms in the architecture, and a specific property of correlations in the input data. The latter factor describes the interaction between data structure and OCS learning. We argue in Appendix A.5.6 on theoretical grounds that the necessary correlation property (the presence of a constant eigenvector) requires only mild symmetry that is satisfied for the hierarchical task in particular, but extends to a more general class of datasets. This argument is empirically supported by Figures 5 and 6 (in the revised PDF).
The continuous correct-rejection scores, which are crucially used to infer OCS learning, are not defined within the main text.
Thank you for highlighting this. We have now moved the definition of the metric and the associated intuition from the appendix into the main text under Section 3 Setup, which should help to increase the interpretability of results.
The figures and the corresponding captions are poor and challenging to interpret. Figure 1 contains 6 panels that lack sufficient standalone explanations. In Figure 2, the reader does not know what different colors represent.
We have updated all figure captions with the intention of making them self-contained. In particular, we have updated the caption of Figure 1 with more extensive explanations in our revised paper; we hope that the revision is sufficiently clear for standalone interpretation.
The different colours in Figure 2 represent the asymptotic magnitudes of the effective singular values for the deep and shallow linear networks, respectively. Darker colours represent higher asymptotic values. We have addressed this concern in a revised draft that contains this description in the figure caption.
Proposition 3 must assume some initialization statistics for the weights that are not specified.
Thank you, we overlooked this point. We now state the initialisation statistics as an assumption in the revised paper.
General presentation
In general, the paper presentation is dense, often lacking detailed explanations or explicit definitions. It also frequently refers the reader to the appendix, detracting from readability.
To reduce the overall density of the paper we have moved the subsection containing the NTK analysis from the main text to the appendix. Instead, our updated draft includes detailed and explicit definitions of our main learning task and metrics, as well as more detailed figure captions, which we hope will aid readability.
We aimed to provide extensive technical detail for our claims while complying with the space constraints of ICLR. To include references at the places where they are relevant, we resorted to pointing to the Appendix while keeping abbreviated statements in the main text. We acknowledge that this makes the text dense, but found it to be a frequently chosen compromise for the venue. However, in an attempt to improve the clarity of the main text, we have expanded on key explanations of the task, metrics, and figures as described.
In summary, the paper has potential, offering some interesting insights into early learning in neural networks and connections to human cognition. However, I think it requires significant rewriting and, in particular, improvements in clarity, structure, presentation, and detailed explanation and discussion of the empirical results.
We are glad that you find the paper to have potential and that you found the connections between early simplicity (OCS) biases and human cognition interesting. We appreciate that the lack of clarity in the figure captions hindered the clear interpretation of our results. We have made an effort to improve the clarity of the captions, and we have also defined the continuous correct-rejection scores in the main paper. Please let us know if there are further steps we can take to increase the clarity and interpretability of results.
Typos: (L253) The word “singular” is repeated twice.
Thank you for the careful read! We removed the typo in the revised PDF.
Timothy T. Rogers and James L. McClelland. Semantic Cognition: A Parallel Distributed Processing
Approach. MIT Press, 2004
Andrew M. Saxe, James L. McClelland, and Surya Ganguli. A mathematical theory of semantic
development in deep neural networks, 2019
Dear Authors,
Thank you for the responses. Could you clarify whether you plan to update the submission? The current pdf appears to be the original one without all the mentioned modifications.
Best,
Reviewer VeTy
Thanks for your rebuttal. I have reviewed the work again, along with the responses from the authors. My initial concerns have been addressed. As a result, I have raised my score and now recommend acceptance.
As a final note, the references appear to have been reordered, which makes them difficult to find. I encourage the authors to adhere to the guidelines and arrange the references alphabetically by author.
Dear reviewer VeTy,
We are glad that your concerns are addressed and we thank you for your generous increase in score! We are very grateful for your time spent reviewing and believe that your comments and feedback have greatly enhanced the quality of our paper. We will make sure to order the references alphabetically in the final version.
Best wishes,
The authors
Dear VeTy
Many apologies. It appears that our last upload did not go through. Thank you for pointing this out. The new pdf with all modifications should now be up.
Best wishes,
The authors
Thank you again for your feedback. We would like to emphasize that the author-reviewer discussion period ends soon. Based on your comments we have made substantial revisions to the paper and have provided a detailed response. Hence, we kindly request you to please consider a reevaluation of your score if your concerns are addressed. Please let us know if you have any additional questions or concerns!
The paper theoretically and empirically examines early learning of the optimal constant solution (OCS) in neural network training. Under a hierarchical label setting, the authors analyze the learning dynamics of linear networks with bias terms, demonstrating that OCS learning happens in the early training stages. Empirical results and human studies confirm that this phenomenon extends beyond linear networks to more complex nonlinear networks and human learners.
Strengths
The paper introduces an intriguing exploration of the simplicity bias towards the optimal constant solution in neural network training, with potential relevance to human societal patterns. This work could interest both the machine learning community and researchers in other fields, with possible extensions to topics like fairness and stereotyping.
Weaknesses
My major concern about the paper is about its clarity:
- Plots are not explained well, e.g., what does "output activation top" represent in Figure 3?
- Even after reading Appendix A.4, I'm still a bit confused about the definition of the TNR. In equation 7, the level index appears on the left side of the equation but not on the right; should the term on the right-hand side carry the level index to denote the target label for a specific level? Besides, this definition seems important and should potentially be moved to the main paper.
- The hierarchical MNIST dataset is only explained in the Appendix, and it's not quite clear how the hierarchical label structure is established for the dataset.
Questions
- Do the results in the paper indicate that the bias towards the OCS solution could eventually vanish after long enough training?
- Fig. 3 and Fig. 5 seem to reflect different outcomes: Fig. 3 shows an initial score drop for all hierarchical levels, while only the top level score drop is prominent in Fig. 5. What would be the cause of this? Besides, do TNR and f_tn refer to the same metric? If so, it's better to maintain consistent notations.
- Line 897, "Given responses in y or in and target vectors", a typo?
- Due to the hierarchical label design used in the paper, it seems well related to the domain of multi-label learning, I would suggest to discuss some related work in this domain.
Thank you for your review and thoughtful comments. We are delighted that you thought of our results as intriguing and we are encouraged by your assessment that our work is of interest for researchers beyond the machine learning community.
Weaknesses
My major concern about the paper is about its clarity:
- Plots are not explained well, e.g., what does "output activation top" represent in Figure 3?
Thank you for highlighting this point. We understand that our figure captions were not sufficiently self-contained and we have now revised all figure captions to include more detail. For the particular case of Figure 3, “Output activation top” denotes the activation of a single output unit of the model corresponding to the top level of the output hierarchy; we plot its response to all inputs. While we do explain the figure in the main text in Section 4.1, Indifference, we have now also clarified this in the figure caption.
- Even after reading Appendix A.4, I'm still a bit confused about the definition of the TNR. In equation 7, the level index appears on the left side of the equation but not on the right; should the term on the right-hand side carry the level index to denote the target label for a specific level? Besides, this definition seems important and should potentially be moved to the main paper.
We can see where the confusion comes from. The metric is slightly unusual, but was chosen as it allowed us to directly compare network and human responses for signatures of OCS learning. Key is the mapping from each hierarchy level to the relevant start and end indices in the output vector. We have made this mapping more clear in a revised version of the metric. Following your suggestion, we have now also moved this metric to the main text of the paper.
- The hierarchical MNIST dataset is only explained in the Appendix, and it's not quite clear how the hierarchical label structure is established for the dataset.
We develop our theory starting from a task which considers perfect within-class similarity, i.e. each output label is only associated to a single (one-hot) input. For the experiments with humans and CNNs, we consider the more realistic case where there is natural variation in each input class.
For the "hierarchical MNIST" task we used the ten digit classes provided by MNIST and then sampled 8 classes randomly. For each image in each class, we then replaced the default one-hot label corresponding to each class with the corresponding hierarchical, “three-hot” label seen in Fig. 1B. E.g., all images corresponding to MNIST digit “1” might be assigned some randomly chosen “three-hot” output vector . We added this description to the revision of the manuscript in Section 5, Setup.
Questions
Do the results in the paper indicate that the bias towards the OCS solution could eventually vanish after long enough training?
Your question addresses an important subtlety of our work. To answer it precisely, it is helpful to distinguish between the OCS mode as part of the network SVD in equation 2 and the biasing of outputs towards the OCS. In linear networks, the OCS mode is learned early in training and persists thereafter as part of the network function. However, it is indeed the case that the OCS output is transient. The reason for this is that the OCS mode is a necessary part of the final correct network output, but not sufficient: other modes still need to be acquired during subsequent learning. In consequence, the OCS mode is increasingly complemented by other modes, keeping its strength while the output shifts from the biased OCS towards the optimal solution.
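Schematically (a simplified sketch in the spirit of equation 2, written in an augmented-input picture where a constant component is appended to the input; this is our shorthand, not the paper's exact expressions):

```latex
% The network map decomposes into modes with time-dependent effective singular
% values s_alpha(t); the OCS mode (alpha = 1) is learned first and then persists.
\[
  f(\mathbf{x}, t) \;=\; \sum_{\alpha} s_\alpha(t)\, \mathbf{u}_\alpha \mathbf{v}_\alpha^{\top} \tilde{\mathbf{x}},
  \qquad
  s_1(t)\,\mathbf{u}_1 \mathbf{v}_1^{\top} \tilde{\mathbf{x}} \;\approx\; \bar{\mathbf{y}}
  \quad \text{for every input } \mathbf{x} \text{ once } s_1 \text{ has grown.}
\]
% Early in training only s_1(t) is appreciable, so the output sits near the OCS
% \bar{y} for all inputs; later the remaining modes s_{\alpha > 1}(t) grow and
% contribute the input-dependent part, moving the total output towards the optimal
% solution while the OCS mode itself keeps its strength.
```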
In nonlinear networks we cannot make such precise statements. However, work by Kang et al. has demonstrated that complex, nonlinear models trained on various datasets do revert to the OCS when generalising out of distribution, and that network biases play an important role in this behaviour. Prior results thus suggest to us that in deep nonlinear networks, the OCS similarly remains a component of the network function even when the model has acquired the desired input-output mapping and fits the training data.
Kang, Katie, et al. "Deep neural networks tend to extrapolate predictably." arXiv preprint arXiv:2310.00873 (2023).
Fig. 3 and Fig. 5 seem to reflect different outcomes: Fig. 3 shows an initial score drop for all hierarchical levels, while only the top level score drop is prominent in Fig. 5. What would be the cause of this? Besides, do TNR and f_tn refer to the same metric? If so, it's better to maintain consistent notations.
We appreciate that the difference between Fig. 3 and Fig. 5 is not entirely clear. Models in both figures are trained on the hierarchical task seen in Fig. 1. The difference between Fig. 3 and Fig. 5 stems from practical constraints that arise when comparing models with humans: humans delivered discrete responses via button clicks while networks produced continuous responses. To compare these learners, we were required to discretize neural network responses by treating outputs as logits for a distribution from which we sampled three discrete responses. This way we obtained “discretized” network outputs, which gives rise to the observed differences between Fig. 3 and Fig. 5. As noted correctly, sensitivity to the OCS in this setting is mostly evident in an initial tendency to overselect labels associated with the top level of the output hierarchy regardless of the provided input stimulus.
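For concreteness, one plausible way such a discretization could look (a hypothetical sketch with assumed level-to-index slices; the paper's exact sampling procedure is described in its appendix and may differ) is to treat each hierarchy level's output units as logits and sample one response per level:

```python
import numpy as np

# Hypothetical level -> output-unit index mapping (the actual layout follows Fig. 1B).
LEVEL_SLICES = {"top": (0, 2), "middle": (2, 6), "bottom": (6, 14)}

def discretize(output, rng):
    """Sample one discrete response per hierarchy level from the continuous outputs,
    treating the units of each level as logits of a categorical distribution."""
    responses = np.zeros_like(output)
    for lo, hi in LEVEL_SLICES.values():
        logits = output[lo:hi]
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        responses[lo + rng.choice(hi - lo, p=probs)] = 1
    return responses

rng = np.random.default_rng(0)
continuous_output = rng.standard_normal(14)
print(discretize(continuous_output, rng))   # a "three-hot" discrete response
```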
Please note that in the revised version of the manuscript Figure 5 is now Figure 4.
Besides, do TNR and f_tn refer to the same metric? If so, it's better to maintain consistent notations.
The metrics are indeed computed equivalently for both figures. As noted above, the only difference is that the metrics in Figure 3 are computed on continuous responses while the metrics in Figure 5 are computed on discrete responses. Thus, the metric in Figure 5 is akin to the classical true negative rate computed in classification tasks, which led us to choose different names in the figures. We now see how this could cause confusion and we have updated the PDF to use the notation “true negative rate f_tn” throughout.
Line 897, "Given responses in y or in and target vectors", a typo?
Thanks for the thorough read! We have fixed the typo in the revised version of the paper.
Due to the hierarchical label design used in the paper, it seems well related to the domain of multi-label learning, I would suggest to discuss some related work in this domain.
Thank you for highlighting this connection. We agree that the hierarchical learning task bears resemblance to multi-label learning. In these settings label imbalance is also a frequent concern. We have added a brief discussion of multi-label learning and how it might relate to the hierarchical task in a new Appendix A.12.
We would like to emphasise that the discussion period ends quite soon. We have made major revisions based on your comments and provided detailed responses to your questions. Hence, we kindly ask you to check the revisions and the rebuttal, and to consider adjusting your score if your concerns are addressed. Please let us know if you have any additional questions!
Thanks to the authors for the response and apologize for my late reply.
The updated manuscript and the authors' thorough response have addressed my concerns regarding clarity, and I remain positive about the paper.
Thank you again for your review and for reaffirming your positive assessment of our paper! If you have any additional questions do not hesitate to send us a message.
Global response
We would like to thank all reviewers for their valuable comments, and the time and effort spent reviewing our submission. We were delighted and encouraged by the many positive aspects of our work pointed out by the reviewers. Reviewers highlighted the interdisciplinary approach taken in the paper spanning theory, human experiments, and ML experiments (Rev. VeTy, ogfv, oCMD). Furthermore, reviewers emphasised that our work is a significant contribution to our understanding of simplicity biases in early learning (Rev. ogfv, S1SK, VeTy). Also, several reviewers found our experiments with human participants a particularly valuable and unique point of the work (Rev. oCMD, ogfv).
Clarity and presentation
Despite these positive aspects some reviewers had concerns about the clarity and presentation of the paper. We thank the reviewers for highlighting these and have responded to all these points in individual responses. We have taken the following steps to mitigate these concerns and to enhance the clarity and presentations of our paper.
- Incomplete explanation of the hierarchical task (Rev. VeTy, ogfv, oCMD). We previously thought that Figure 1 in conjunction with a reference to previous literature might be sufficient to explain the hierarchical task to the reader. We now realise that a simple self-contained section is needed to fully appreciate the learning problem. We have added such an explanation under the heading The hierarchical task in Section 3 of the updated paper to reflect this shared concern.
- Definition of the True Negative Rate/correct-rejection scores in the main text of the paper (Rev. VeTy, S1SK). We appreciate that the metric should have been defined in the main text of the paper, as it is crucially used to infer OCS learning in models and human participants, and we are very sympathetic to this concern. We have now added a definition of the metric alongside intuitions to the Setup in Section 4 of the main paper to ease interpretation of results and overall clarity. Reviewer S1SK raised that “True Negative Rate” and f_tn appear to refer to the same metric while not having consistent notation. We initially chose the different names because the metrics for human learners and networks in Figure 5 were computed on discrete responses while the metrics in Figure 3 were computed on continuous responses, although the mathematical expressions agree. For discrete responses the metric more closely resembles a classic true negative rate, which led us to choose different names. We have updated the draft to adopt consistent notation throughout the manuscript.
- Figure captions (Rev. VeTy, S1SK). Two reviewers highlighted that they found our figure captions not sufficiently self-contained. We have updated all figure captions in the revised draft to contain sufficient detail that should allow for standalone interpretation.
- Denseness of presentation (Rev. VeTy, oCMD). It was pointed out to us that our presentation was at times dense and hard to follow, implying that our goal of writing a comprehensive main text resulted in compromised clarity. To arrive at a better structure, we moved more auxiliary results to the appendix, including the NTK analysis. As we still believe that these results provide important insight, we keep references to them in the main text.
Scope of claims (Rev. oCMD)
Reviewer oCMD in particular expressed concerns about the scope and support of our claims as a primary weakness of our work. We appreciate that being precise here is crucial and would like to delineate the exact scope of our claims. We have revised our language throughout in the updated PDF to clarify the exact scope of our claims.
What we do not claim:
We do not claim that OCS learning will inevitably be observed regardless of possible architectures and datasets. For example, in the controlled setting of linear networks, we identify cases in which OCS learning will not occur: linear networks will not display OCS learning in the absence of input correlations and architectural bias terms. In nonlinear networks we demonstrate the same phenomenon empirically, i.e. the absence of OCS learning in absence of bias terms and input correlations.
What we do claim:
- Linear networks. We do claim that OCS learning in linear networks emerges via two distinct factors entering the output dynamics in the same way: (a) bias terms in the model architecture, and (b) particular input correlations, i.e. whenever the input correlation matrix contains a constant eigenvector. To clearly highlight point (b), we have now isolated a setting with solvable learning dynamics in which only input correlations drive OCS learning and added it to Appendix A.11 of the revised paper.
- Nonlinear networks. We observe that OCS learning occurs in nonlinear networks, and that the same factors as in linear networks causally control OCS learning in the settings we consider (see Figure 5 and Figure 6 of the revised paper). Two findings support these observations. First, the results from Kang et al. show that for various network architectures and datasets, networks revert to the OCS in generalisation settings. This reversion implies that the OCS is a component of the learned network function. Second, our theory does not depend on details of the dataset; rather, it only relies on a very coarse property of the dataset correlation statistics, its leading eigenmode. We argue in Appendix A.5.6 that if mild symmetries are present, this will hold true for many datasets. Together, this empirical and theoretical evidence suggests that OCS learning and the factors behind it are promising candidates to explain observations beyond linear networks.
- Human learners. For human learners, we state in our discussion that we cannot make any specific mechanistic claim beyond the empirical observation that they do display OCS learning. However, we think that future experiments which leverage the recording of neural data or the targeted manipulation of input correlations in experimental stimuli could disentangle between the driving factors we identified, or possibly rule them out altogether.
References
Kang, Katie, et al. Deep neural networks tend to extrapolate predictably. arXiv preprint arXiv:2310.00873 (2023).
This paper studies early learning of the optimal constant solution (OCS) in neural network training. Under a hierarchical label setting, the paper analyzes the learning dynamics of one- and two-layer linear networks with biases, demonstrating that OCS learning happens in the early training stages. Empirical results and human studies confirm that this phenomenon extends beyond linear networks to more complex nonlinear networks and human learners.
This paper studies a very important question about the simplicity biases of neural network training and identifies an interesting phenomenon. However, it is missing citations and discussion of some highly relevant work on the spectral bias of NNs [1-3]. In particular, [1] shows the early learning of the optimal linear function (1st-order); the optimal constant function studied in the current paper is the 0th-order version of that result. Although these results are not the same, the high-level arguments and intuitions are similar. The current paper's main argument is essentially that the constant all-ones vector carries significant weight in the spectrum of the input covariance matrix. The fact that this implies early learning of the OCS follows directly from arguments made in previous literature and is not novel. The authors are encouraged to discuss these connections in the next version and to make it clear what this paper's novel contributions are in the context of existing literature.
[1] Hu, W, Xiao, L, Adlam, B, Pennington, J. The Surprising Simplicity of the Early-Time Learning Dynamics of Neural Networks. NeurIPS 2020.
[2] Cao, Y, Fang, Z, Wu, Y, Zhou, D, Gu, Q. Towards Understanding the Spectral Bias of Deep Learning. IJCAI 2021.
[3] https://en.wikipedia.org/wiki/Frequency_principle/spectral_bias
Additional comments from reviewer discussion
The reviewers' main initial concerns were around the lack of clear definitions of some important concepts in the paper, such as the hierarchical task and the True Negative Rate/correct-rejection scores. The authors improved the presentation of these definitions during the rebuttal period.
Reject