PaperHub
6.8 / 10
Poster · 4 reviewers
Ratings: 2, 5, 5, 5 (min 2, max 5, std 1.3)
Confidence: 3.0
Novelty: 3.3
Quality: 2.8
Clarity: 2.5
Significance: 3.0
NeurIPS 2025

Composing Global Solutions to Reasoning Tasks via Algebraic Objects in Neural Nets

OpenReview · PDF
Submitted: 2025-05-05 · Updated: 2025-10-29
TL;DR

A semi-ring structure exists in 2-layer neural nets trained with L2 loss on reasoning tasks over Abelian groups (e.g., modular addition), enabling global solutions to be constructed analytically from non-optimal partial ones rather than by gradient descent.

Abstract

We prove rich algebraic structures of the solution space for 2-layer neural networks with quadratic activation and $L_2$ loss, trained on reasoning tasks over Abelian groups (e.g., modular addition). Such rich structure enables the analytical construction of globally optimal solutions from partial solutions that only satisfy part of the loss, despite its high nonlinearity. We coin the framework CoGS (Composing Global Solutions). Specifically, we show that the weight space over different numbers of hidden nodes of the 2-layer network is equipped with a semi-ring algebraic structure, and the loss function to be optimized consists of sum potentials, which are ring homomorphisms, allowing partial solutions to be composed into global ones by ring addition and multiplication. Our experiments show that around 95% of the solutions obtained by gradient descent match our theoretical constructions exactly. Although the constructed global solutions require only a small number of hidden nodes, our analysis of gradient dynamics shows that overparameterization asymptotically decouples the training dynamics and is beneficial. We further show that training dynamics favor simpler solutions under weight decay, and thus high-order global solutions such as perfect memorization are unfavorable. The code is open-sourced at https://github.com/facebookresearch/luckmatters/tree/yuandong3/ssl/real-dataset.
Keywords

landscape analysis; modular addition; gradient dynamics; reasoning; symmetry; representation learning

Reviews and Discussion

Review (Rating: 2)

This work aims to study how neural networks with one set of hidden neurons learn to perform computations on Abelian groups, with modular addition being a notable instantiation of the problem that has received recent interest in the community. The primary result of this work shows that global solutions can be obtained by composing local solutions, which, in the case of neural networks, means that portions of the network's hidden layer each perform meaningful parts of a larger computation that results in the globally optimal solution. Experiments on modular addition show that the vast majority of neural networks converge to these composable solutions rather than learning monolithic solutions corresponding to memorisation.

Strengths and Weaknesses

Strengths

Quality and Significance

The experimental results seem to support the broad points of the theory, and I agree that it is important to the community that we begin to understand how neural networks solve algorithmic tasks. The fact that modular addition has been a recent task of interest to the mechanistic interpretability community speaks to this work's potential significance. Additionally, the generality of the work in considering tasks on an Abelian group further increases its applicability.

Weaknesses

Clarity

In general I find this work to be extremely unclear. Grammatically, the text itself is difficult to read and contains long sentences which end abruptly. The abstract, for example, has the following sentence: "Specifically, we show that the weight space over different numbers of hidden nodes of the 2-layer network is equipped with a semi-ring algebraic structure, and the loss function to be optimized consists of sum potentials, which are ring homomorphisms, allowing partial solutions to be composed into global ones by ring addition and multiplication". Unfortunately this problem persists throughout the writing. Secondly, almost no effort is made to limit the jargon which is very specific to the group theory tasks being considered. For example, the introduction says: "As a result, our theoretical framework, named CoGS (i.e., Composing Global Solutions), successfully constructs two distinct types of Fourier-based global solutions of per-frequency order 4 (or "2×2") and order 6 (or "2×3"), and a global solution of order d^2 that correspond to perfect memorization". At this point in the text the notion of "per-frequency order" had not been discussed at all. This is particularly problematic in this case as it is the sentence highlighting the contribution of the proposed CoGS framework. A few other examples of points in the early paragraphs that I found vague (to try to highlight my point and make things more actionable):

  • Lines 67 and 68: "Existing works on expressibility (Li et al., 2024; Liu et al., 2022) gives explicit weight construction of neural networks weights (e.g., Transformers)". I have no idea what it means to give a weights construction of neural network weights.
  • Lines 110 and 111: "We exclude $\phi_0 \equiv 1$ because the constant bias term has been filtered out by the top-down gradient from the loss function". What does it mean for the bias term to be filtered out by the gradient?
  • Figure 1: While Figure 1 is helpful to try to contextualise everything (and I found myself returning to this figure often), the numbers in the caption are not present in the image.

Considering the clarity of the mathematics: notation is not always introduced well, and there are typos which make the already dense notation even more taxing. For example, on Lines 109 and 110, $\phi_{-k}$ is used for the complex conjugate of $\phi_k$, even though notation for complex conjugates had already been introduced. Another example is on Line 101, where $P_1^\perp$ is introduced with little justification for its need. Also, Definition 5 seems to try to define concatenation twice but using different notation each time. I also find it quite unhelpful to refer to notation as nouns in text, especially when the notation is given no explicit intuition to help ground the variable. The worst example of this is the "$r$ terms", which are defined mathematically on Line 115 and then referenced as entities thereafter. This does a particular disservice to Definition 1. Lastly, the absence of proof sketches (or even references in the main text to proofs in the appendix) makes it very difficult to understand the contributions of the work, especially when coupled with the complex notation and needless definitions in some cases. For example, it is surprising that Definitions 2 and 3 were deemed necessary, but then the kernel of $r$ was left as a footnote and without formal definition at all (and it seems necessary to connect the geometry of the weights to the loss, no?). Another example here is that sum potentials are referenced four times before Definition 6 and after Theorem 1, where the definition is necessary to appreciate the point of the theorem (as is stated on Lines 45 and 46).

Quality

I have three main concerns for quality. The first two are quick to summarise: firstly, I do not feel the assumptions of the theory or work are clearly stated. This could be a consequence of the other clarity issues, but I find it surprising that no assumptions seem to be mentioned (especially when prior work had such strong assumptions as infinite width). Secondly, there is little acknowledgement of the limitations of this work. The problem setup and architecture are very specific, and the use of one-hot encodings is a big limitation (other works in this space allow for an initial embedding layer even). In addition, the experiments show that gradient descent follows the theoretical construction in most cases, but this is on one instantiation of a reasoning task on an Abelian group. There is no guarantee any other task would similarly follow the theory, but the work seems to treat the modular arithmetic example as conclusive. This leads me to my final concern: I think the claims around the theory matching experiments are too imprecise, and in some cases the text overclaims how close this fit is. For example, the abstract says the experiments fit the theory 95% of the time, but Lines 246 and 247 make it clear that "Although characterizing the full gradient dynamics is beyond the scope of this paper, we theoretically characterize some rough behaviors below". This is then followed by Lines 277 to 282, which claim exact matches between the theoretical construction and empirical solutions. None of these three statements seems consistent to me.

I hope my comments here at least serve to constructively guide the exposition of the paper. While the work itself seems promising, the paper as it stands feels rushed.

Questions

It is difficult to phrase my concerns as questions, but I would be happy to continue the discussion on my above points during the rebuttals phase. It would help if the authors could clarify any mistakes I may have made in my comments above, and answer any of the smaller questions mixed in there.

Limitations

I don't feel the limitations of this work were accurately stated and I have raised this as a weakness for quality above.

Final Justification

I have considered the authors' rebuttal closely and reread the paper with the additional information and explanations in mind. I remain sceptical that the paper is ready for publication at this time. I do think the research itself is interesting and addresses a timely setting which is of present interest to the community. However, the presentation of the work and lack of formal statements means it is difficult to determine the full extent of the authors' claims and to assess the rigour of the work. These are important considerations and to me are dealbreakers regardless of how much I like the work itself.

Formatting Issues

None

Author Response

We thank the reviewer for the comments! We are glad to hear that the reviewer acknowledges that our work addresses important problems that the community cares about (modular addition and, more broadly, Abelian groups), and that our experiments align with the theoretical construction.

In general we agree that the notation can be improved to make it easier to understand. We address the detailed questions below.

I have no idea what it means to give a weights construction of neural network weights

We want to make sure the reviewer knows that there is a large body of literature on network expressibility, which aims to explicitly construct weights for a network so that it can accomplish certain tasks, e.g., expressibility of Transformers [4][5], CoT [7][8], Graph Neural Networks [9], and general Deep Neural Networks [10]. Here we list just a few references. The goal is to find an upper bound on the capability of a given network architecture. Whether the training procedure actually reaches the constructed solution is an orthogonal question.

References:

[4] https://aclanthology.org/2024.tacl-1.30.pdf

[5] https://www.jmlr.org/papers/volume22/20-302/20-302.pdf

[6] https://arxiv.org/abs/1912.10077

[7] https://openreview.net/pdf?id=NjNGlPh8Wh

[8] https://arxiv.org/abs/2402.12875

[9] https://arxiv.org/abs/1810.00826

[10] https://arxiv.org/abs/1606.05336

What does it mean for the bias term to be filtered out by the gradient?

Since we are optimizing a projected $\ell_2$ loss $\|P^\top_1 (\mathbf{o}/2d - \mathbf{y})\|^2$, the output $\mathbf{o}$ is zero-meaned (i.e., projected by $P^\top_1$) before being sent to the loss. So if we shift $\mathbf{o}$ by a constant vector, it won't change the loss.

On the other hand, the backpropagated gradient also contains the $P^\top_1$ term, which makes the gradient zero-mean. So if there is any bias term in the weights, such a bias term will neither affect the output nor change over time. Therefore, we can just ignore it when finding global solutions (that is also why we have $k \neq 0$ in all the Fourier expansions of the weights (Eqn. 2)).
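As a minimal numerical sketch of this point (illustrative only; the projector `P`, dimensions, and variable names below are chosen for the example and are not taken from the paper's codebase), one can check that a constant shift of the output leaves the projected loss unchanged and that the projected residual entering the gradient is zero-mean:

```python
# Toy check: with a zero-mean (centering) projection, a constant shift of the
# output o does not change the projected L2 loss, and the projected residual
# that enters the gradient sums to zero, so a bias term receives no update.
import numpy as np

d = 7
rng = np.random.default_rng(0)
P = np.eye(d) - np.ones((d, d)) / d            # zero-mean projection
o = rng.normal(size=d)                          # network output (logits)
y = np.zeros(d); y[3] = 1.0                     # one-hot target

def proj_loss(o):
    return np.sum((P @ (o / (2 * d) - y)) ** 2)

shift = 5.0 * np.ones(d)                        # constant bias added to the output
print(np.isclose(proj_loss(o), proj_loss(o + shift)))   # True: loss unchanged

residual = P @ (o / (2 * d) - y)                # factor appearing in the gradient
print(np.isclose(residual.sum(), 0.0))          # True: backpropagated term is zero-mean
```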

the numbers in the caption are not present in the image

Thanks for the suggestions. We will update it in our next revision to make the story clearer.

“I find it surprising that no assumptions seem to be mentioned (especially when prior work had such strong assumptions such as infinite width)”

That’s one of the core contributions of this work. We do not assume infinite width in our analysis (and do not leverage existing frameworks such as the Neural Tangent Kernel or mean field), and we theoretically characterize the fine-grained algebraic structure of the global solutions in the setting we have specified. To the best of our knowledge, such a structure has not been studied before in the literature (other reviewers also agree on this).

the use of one-hot encodings is a big limitation (other works in this space allow for an initial embedding layer even)

Using one-hot encoding is for mathematical clarity and simplicity. We can always extend the analysis to any fixed orthogonal embedding matrix, which corresponds to a global rotation of all the weight matrices. Many peer-reviewed top-conference papers also leverage such assumptions (e.g., reversal curse [11], JoMA [12]) to facilitate theoretical study. For embedding matrices whose embeddings are not orthogonal or that are updated during training, we leave this for future work.

[11] https://arxiv.org/abs/2405.04669 (NeurIPS'24)

[12] https://arxiv.org/abs/2310.00535 (ICLR'24)
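As a small sketch of the "global rotation" argument above (our illustration under the stated assumption of a fixed orthogonal embedding; `E` and `W` are hypothetical names, not the paper's notation), rotating the first-layer weights by the embedding matrix reproduces exactly the pre-activations of the one-hot setting:

```python
# Sketch: a fixed orthogonal embedding E is equivalent to a global rotation of
# the first-layer weights, so the one-hot analysis carries over unchanged.
import numpy as np

d, hidden = 11, 5
rng = np.random.default_rng(1)
E, _ = np.linalg.qr(rng.normal(size=(d, d)))   # fixed orthogonal embedding, E^T E = I
W = rng.normal(size=(hidden, d))               # first-layer weights in the one-hot basis

x_onehot = np.zeros(d); x_onehot[4] = 1.0
pre_onehot = W @ x_onehot                      # pre-activations with one-hot input
pre_embedded = (W @ E.T) @ (E @ x_onehot)      # rotated weights + embedded input

print(np.allclose(pre_onehot, pre_embedded))   # True: identical pre-activations
```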

For example the abstract says the experiments fit the theory 95% of the time, but then Lines 246 and 247 actually make it clear that "Although characterizing the full gradient dynamics is beyond the scope of this paper, we theoretically characterize some rough behaviors below"

We want to make it clear that there are two different (and often orthogonal) aspects of theoretical study: (1) studying the solution structures (i.e., what solutions we get after training) and (2) studying the gradient dynamics (i.e., how we get there). Our paper is about (1), which is itself a nontrivial topic, and we leave (2) as future work. For (1), our work characterizes the algebraic structure of the final solutions by leveraging properties of the loss function.

For example on Lines 109 and 110, $\phi_{-k}$ is used to be the complex conjugate of $\phi_k$ but the notation for complex conjugates had already been introduced

This is for notational convenience. E.g., we can write $\sum_k c_k \phi_k$ rather than separating it into two summations, $\sum_{k=1}^{\lfloor d/2 \rfloor} c_k \phi_k$ and its complex conjugate counterpart.
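Concretely, the convention can be restated as follows (our paraphrase, assuming the Fourier basis $\phi_k(x) = e^{2\pi i k x / d}$, real-valued sums, and odd $d$ for simplicity):

$$\sum_{k \neq 0} c_k \phi_k \;=\; \sum_{k=1}^{\lfloor d/2 \rfloor} \left( c_k \phi_k + c_{-k} \phi_{-k} \right), \qquad \phi_{-k} := \overline{\phi_k}, \quad c_{-k} := \overline{c_k},$$

so each paired term is real, and a single index set $k \neq 0$ replaces "a sum plus its conjugate counterpart".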

Definition 5 seems to try to define concatenation twice but using different notation each time.

We kindly disagree. Def. 5 defines the addition and multiplication operations on the weights, which are the key operations for the semi-ring structure that supports the entire algebraic framework. The concatenation (as the addition) and the Kronecker product (as the multiplication) are two very different operations, so this cannot be regarded as "defining concatenation twice".
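To illustrate why concatenation behaves like an addition (a toy sketch under our reading; Def. 5 in the paper, including the normalization of the coefficients $z$ and the vector-valued output, is the authoritative statement, and the Kronecker-product multiplication is not sketched here), concatenating the hidden nodes of two 2-layer quadratic networks with a scalar readout simply sums their outputs:

```python
# Sketch: concatenating hidden nodes of two 2-layer quadratic networks sums
# their outputs -- the intuition behind treating concatenation as ring addition.
import numpy as np

def two_layer(W, z, x):
    """f(x) = sum_j z_j * (w_j . x)^2, a 2-layer net with quadratic activation."""
    return z @ (W @ x) ** 2

d = 6
rng = np.random.default_rng(2)
W1, z1 = rng.normal(size=(3, d)), rng.normal(size=3)   # network A (3 hidden nodes)
W2, z2 = rng.normal(size=(4, d)), rng.normal(size=4)   # network B (4 hidden nodes)

W_cat = np.vstack([W1, W2])                            # "addition": stack hidden nodes
z_cat = np.concatenate([z1, z2])

x = rng.normal(size=d)
print(np.isclose(two_layer(W_cat, z_cat, x),
                 two_layer(W1, z1, x) + two_layer(W2, z2, x)))   # True
```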

… but this is on one instantiation of a reasoning task on an Abelian group. There is no guarantee any other task would similarly follow the theory but the work seems to see the modular arithmetic example as conclusive.

A theory always has its assumptions and range of applications. We never claim that we solve all reasoning tasks on Abelian groups, only a representative example such as modular addition, and we never regard the modular arithmetic problem as conclusive for reasoning tasks.

For example, it is surprising that Definition 2 and 3 were deemed necessary but then the kernel of $r$ was left as a footnote and without formal definition at all

The paper is math heavy and we don't want to introduce additional formal definitions in the main text. On the other hand, we want to give more mathematical details for readers who have the relevant background and are interested.

Review (Rating: 5)

This paper analyzes the algebraic structure of the loss minimum for a two-layer network with quadratic activations trained to do modular arithmetic. They find the weight space has the structure of a semi-ring and the loss can be decomposed in terms of ring homomorphisms. Using this ring structure, the authors construct explicit weight values which minimize the loss and show empirically that gradient descent usually converges to these solutions, favoring lower order (fewer non-zero neurons) solutions. The authors also prove mode connectivity implying gradient flow will favor lower order solutions under L2 regularization.

Strengths and Weaknesses

Strengths

  • The theoretical analysis in this paper is inventive, rigorous, and insightful. Despite the simple setting, the authors are able to explicitly describe the minimum.
  • The algebraic structure they describe for the weight space is very interesting on its own, but it also provides insight to gradient dynamics.
  • The experiments show that GD actually converges to these minima most of the time.
  • The decomposition of the loss into SPs and the leveraging of the ring structure to create an algorithm to produce solutions is clever. Moreover, from my (incomplete) understanding of Thm 6, this decomposition is relevant to finding the full solution via gradient descent as well since as q grows, the SPs are "decoupled."
  • I found Thm 5 interesting since it provides a theoretical backing for mode connectivity (in this simple setting) and theoretical justification for preference for simpler more generalizable solutions. There is minimal theoretical progress in this area and this work provides an interesting and different perspective.
  • The limitations are clear and the future work is nicely described. I particularly appreciate the point that it's useful to put different widths in the same framework. I think the idea to develop new training algorithms based on this seems like a longshot, given that we don't know how this generalizes, but I appreciate the idea.

Weaknesses

  • The current work is limited to a fairly toy problem (modular arithmetic / abelian group composition) and a small network (2-layer) with an uncommon and simple activation \sigma(x)=x^2. The authors acknowledge this limitation and suggest taylor approximations can be used to generalize their work to other activations. While I am noting this limitation and the fact that the insights here may not scale to larger networks or other problems, I still think these choices make for a reasonable trade-off. The paper is able to reach very insightful conclusions in this restricted setting which makes these restrictions reasonable.
  • The notation and indexing could be a bit cleaner. Overall, this is not so easy to solve since there is a lot to keep track of, but I wasn't a huge fan of using {a,b,c} as index values. For example w_{aj} is a 2-dim tensor, but the group index is actually suppressed here and "a" refers to the layer but as a layer name, not an index. The index value for layer is p. \mathbb{I} is not defined, but I can infer it's the characteristic function. Also, in Thm 6 J is not defined.
  • There is some informal terminology that I would appreciate being made more precise. For example, L234 "mixed", L243 "solution solution". For the argmax on L190, it took me a bit to understand that the arg is taken over \mathbf{u}, so maybe clarify that. L174: "once we reach 0/1-sets" -- Can you spell out precisely what this means? I believe I understand L193-L194 but it could be clearer.
  • The related works section on algebraic structures in ML is pretty short. Are there no other instances?

Questions

  • Can you help me understand Thm 6? What is J and what does it mean for the dynamics of the SPs to be decoupled?
  • In some sense, cross-entropy loss would make more sense for the task. I assume L2 loss was chosen to make the analysis work?
  • Is it possible to say how much of the minimum Lemma 1 describes?
  • For Fig. 2 why do the weights for z_a z_b and z_c look so similar?

Limitations

Yes.

Final Justification

The rebuttal answered my questions. Per my comments, I do think clarity and precision can be slightly improved. I keep my positive score.

Formatting Issues

No concerns.

Author Response

We thank the reviewer for their insightful and encouraging comments. We are glad to hear that "the theoretical analysis in this paper is inventive, rigorous, and insightful". Thanks! Here are the answers to the questions.

The related works section on algebraic structures in ML is pretty short. Are there no other instances?

There are relatively few papers on algebraic structures in ML. We list a few more below and will add them to the extended literature section in the appendix. Most of the existing literature focuses on injecting algebraic structures (e.g., invariants) into the features and/or neural networks to improve their performance, while our paper focuses on the algebraic structures that emerge after training a neural network on structured data. So overall the research directions are very different.

[1] 1990, On the algebraic structure of feedforward network weight spaces

This paper focuses on “universal” invariant transformations (and the associated group) of the network weights that lead to identical outputs for any given inputs. This includes permutations of nodes, scaling up one layer while scaling down other layers, etc., which is universal to any input/output pairs. In contrast, our formulation focuses on algebraic structures induced by a family of specific (yet important) problems. Such specificity (e.g., the task to be learned has group structure) leads to rich structures in the resulting weights. Our formulation also automatically removes the invariants studied by this paper, by identifying networks with permuted hidden nodes and by normalizing the coefficients z, and focuses on algebraic structure beyond [1].

[2] 2021, Algebraic neural networks: Stability to deformations

This paper constructs a novel network architecture called AlgNN by adding algebraic layers which apply group operations to the intermediate representation of the neural network to achieve certain properties. AlgNN is mostly used as a framework to unite multiple existing data processing pipelines (e.g., convnets, discrete-time signal processing, graph/group neural networks, etc.). While our paper studies the structure of global solutions and performs experiments to verify the theoretical findings, [2] does not cover these topics.

[3] 2024, Unitary convolutions for learning on graphs and groups

This paper imposes a set of constraints on the network weights (unitary group convolution) to enable stable training (and avoid over-smoothing problems) for graph neural networks. It does not try to characterize the structure of global solutions reached by gradient descent, but focuses more on improving empirical performance.

[4] 2024, A Galois theorem for machine learning: Functions on symmetric matrices and point clouds via lightweight invariant features

This paper focuses on constructing invariant features on matrices subject to joint permutation of rows and columns using a Galois theorem, and uses such features for point cloud classification. The paper also shows how many such features are needed in order to yield a good classification. Overall it does not concern neural networks or their internal mechanisms and representations, which are the topic of our paper.

Can you help me understand Thm 6? What is J and what does it mean for the dynamics of the SPs to be decoupled?

Please check Appendix E for the definition of the Jacobian matrix $J$, which is defined as the partial derivative of a collection of sum potentials (SPs) with respect to the weights. Theorem 7 means that if we have infinite width and if the loss function is simplified, then the dynamics of each sum potential $r$ is decoupled; that is, the dynamics of $r_1$ and $r_2$ are independent and do not interfere with each other, at least at (or near) initialization.
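One way to read "decoupled" (our informal paraphrase of the description above, not the formal statement in the paper): with $r = r(\mathbf{w})$ the vector of SPs, $J = \partial r / \partial \mathbf{w}$ the Jacobian, and gradient flow $\dot{\mathbf{w}} = -\nabla_{\mathbf{w}} L$ for a loss $L = L(r)$, the chain rule gives

$$\dot r \;=\; J \dot{\mathbf{w}} \;=\; -\, J J^\top \nabla_r L,$$

so if $J J^\top$ is (approximately) diagonal at or near initialization in the overparameterized limit, each SP evolves independently of the others.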

In some sense, cross-entropy loss would make more sense for the task. I assume L2 loss was chosen to make the analysis work?

Yes. Here is an analysis when the loss function is cross-entropy. Note that in Theorem 1 we use a projected $\ell_2$ loss (i.e., $\|P_1^\perp (\mathbf{y} - \mathbf{o}/2d)\|^2$), and cross-entropy can be written in a similar form as well. Please check Lemma B.2 in [3] for details. Roughly speaking, the gradient of the cross-entropy loss $-\log\left[\exp(\mathbf{o}^\top \mathbf{y}) / \mathbf{1}^\top \exp(\mathbf{o})\right]$ is the same as the gradient of the surrogate loss $\|P_1^\perp (\mathbf{y} - \gamma \mathbf{o}/d)\|^2$, where $\gamma = \left(1 + \mathbf{o}^\top P_1^\perp \mathbf{o}/2d + o(\mathbf{o}^\top P_1^\perp \mathbf{o}/2d)\right)^{-1}$ depends on the zero-mean norm of $\mathbf{o}$, the unnormalized logits before softmax. When this norm is much smaller than $d$, we have $\gamma \approx 1$, and the exact same analysis can be applied to the cross-entropy loss, except for some changes to the constant coefficients in the loss decomposition (Eqn. 3).

In practice, the unnormalized logits $\mathbf{o}$ may grow unbounded in magnitude. This makes $\gamma$ shrink to a very small number, and thus the first loss term $\sum_{k\neq 0} r_{kkk}$ in Lemma 1 will dominate. If we only consider this term when finding our global solutions, the found solutions would be much simpler and of lower order. Experiments also support this.
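To make the small-logit ($\gamma \approx 1$) regime concrete, here is a toy numerical check (our own sketch; the overall scaling of the surrogate is chosen for the comparison, consistent with the constant-coefficient changes mentioned above): for logits of norm much smaller than $d$, the cross-entropy gradient $\mathrm{softmax}(\mathbf{o}) - \mathbf{y}$ nearly coincides with the gradient of a projected quadratic surrogate.

```python
# Toy check: for small logits, the cross-entropy gradient is close to the
# gradient of a projected quadratic surrogate (the gamma ~ 1 regime).
import numpy as np

d = 10
rng = np.random.default_rng(3)
P = np.eye(d) - np.ones((d, d)) / d             # zero-mean projection
o = 1e-2 * rng.normal(size=d)                   # small unnormalized logits
y = np.zeros(d); y[2] = 1.0                     # one-hot target

grad_ce = np.exp(o) / np.exp(o).sum() - y       # d/do of -log softmax(o)[target]

# d/do of (d/2) * || P (y - o/d) ||^2, i.e. a rescaled projected L2 surrogate
grad_surrogate = -P @ (y - o / d)

print(np.max(np.abs(grad_ce - grad_surrogate)))  # tiny: agreement up to O(||o||^2) terms
```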

For Fig. 2 why do the weights for z_a z_b and z_c look so similar?

They are so similar because they all converge to Fourier bases with the same frequency, and thus their magnitudes peak at the same frequency in the Fourier domain. Their phases are different, though (which is not plotted in the figure).

Is it possible to say how much of the minimum Lemma 1 describes?

We are not sure what the reviewer means. Lemma 1 characterizes a sufficient condition for achieving the global minimum (i.e., the loss function becomes zero, which is the smallest value an $\ell_2$ loss can achieve).

Comment

Thank you for answering my questions. Per my comments, I do think clarity can be slightly improved. I keep my positive score.

Comment

Dear reviewer,

We will revise the paper to improve the clarity and fix typos (e.g., "solution solution") in the next revision. Thanks for your encouragement!

Best, Authors.

Review (Rating: 5)

This paper introduces CoGS (Composing Global Solutions), a theoretical framework that reveals a rich algebraic structure in the weight space of 2-layer neural networks with quadratic activation trained on group multiplication tasks. By showing that the weights form a commutative semi-ring and that the L2 loss decomposes into sum potentials, CoGS enables global solutions to be analytically composed from partial ones. The authors construct low-order Fourier-based and perfect memorization solutions and demonstrate that approximately 95% of solutions found by gradient descent match these theoretical forms.

Strengths and Weaknesses

Strengths

  • The paper introduces a rigorous algebraic framework (CoGS) that reveals a previously unrecognized semi-ring structure in the weight space of 2-layer neural networks and interprets the L2 loss through ring homomorphisms. This perspective provides clear, sufficient conditions for identifying global optima and offers a systematic way to construct them analytically.
  • The theoretical predictions are validated by experiments showing that about 95% of gradient descent solutions exactly match the proposed constructions, with minimal reconstruction error. This high degree of empirical alignment reinforces the practical relevance of the theoretical insights.
  • The paper provides good explanations for why gradient descent with weight decay prefers low-order, structured solutions over high-order ones like perfect memorization. It links this preference to topological and algebraic connectivity between solutions of different orders.

Weaknesses

  • The theoretical framework is developed specifically for 2-layer networks with quadratic activations trained on group operations (e.g., modular addition), assuming full knowledge of the data's algebraic structure. This limits its applicability to real-world settings where such structure is unknown or more complex.
  • While CoGS successfully constructs a family of global solutions, it does not cover all possible global optima.

Questions

Overall, I find this to be a strong theoretical paper with a meaningful contribution to the algebraic aspects of neural network learning. However, there are a few points I would like the authors to clarify:

  • The empirical results focus primarily on synthetic group multiplication tasks. To what extent do the authors believe the key findings could generalize to real-world settings where the underlying structure is implicit or unknown? It would strengthen the work to either discuss this more explicitly or explore empirical validations beyond group-structured tasks.
  • The paper briefly touches on grokking in the conclusion. Given that this is a well-known behavior that usually occurs during training, can the authors comment more specifically on how the CoGS framework can be extended to explain this phenomenon?

Limitations

yes

Final Justification

I recommend acceptance of this paper due to its theoretical contribution, which is supported by the empirical evidence.

Formatting Issues

I do not notice any major formatting issues.

Author Response

We thank the reviewer for the insightful and encouraging feedback! We are glad to hear that the reviewer thinks the paper is theoretically rigorous, identifies previously unknown structures, performs thorough experiments that align well with the theory, and provides good explanations. We address the questions below:

Application to real-world settings

First we want to clarify that in our current setting, we don’t need to know the underlying specific group structure (e.g., for an order-6 group, whether it is $Z_6$ or $Z_2 \otimes Z_3$) before learning; after learning, solutions of the specified structure emerge according to the theory, given that it is an Abelian group.

Furthermore, in the appendix (Appendix F), we have extended this work to group action prediction, which concerns a state space $\mathcal{X}$ and a group $G$ that acts on it to yield a different state in the state space. This setting covers more practical scenarios such as learning a transition function in RL that takes a state-action pair $(s,a)$ as input and returns the next state $s'$.

Going beyond that, in order to deal with real-world data distributions (e.g., mixtures of group structures, missing data, noise, etc.), extended analysis is needed. For example, instead of quadratic activations, we may need to consider truncated quadratic activations or squared ReLU (e.g., $\sigma(x) = \max^2(x, 0)$, which is empirically used in large-model training; see [1]), and study whether different neurons can capture various local structures by masking out irrelevant structures and noise automatically. Also, self-attention is needed when the text input contains compositional and heterogeneous structures (e.g., a sequence “The modular addition of 2 and 4 mod 5 is 1” rather than “2 4 1”). The intuition is that self-attention will force the model to focus on important parts of the previous input and omit others. Ideally, the model focuses on “2” and “4” when predicting the answer “1”, and it focuses on “modular addition” when reading “mod 5”. How learning picks up such a mechanism remains a mystery.

Grokking

Our analysis gives some intuition regarding grokking. Theorem 1 shows that the loss function is a summation of a linear term ($r_{kkk}$) and a few quadratic (sum-of-squares) terms in the Fourier domain. When the weights are small, the quadratic terms are much smaller than the linear term and the weights grow at a uniform pace. This means that all the weights are similar in magnitude (in the Fourier domain) and memorization happens (check the perfect memorization solution in Eqn. 7 in Corollary 3, in which all weights in the Fourier domain have the same magnitude). However, when the weight magnitude becomes larger, the quadratic terms (as well as weight decay) catch up, which leads to specialization of hidden neurons into different frequencies, i.e., the generalizing solutions (order-4 and order-6 solutions in Corollary 2 and 5).

From this analysis, it is clear that we need a small learning rate to demonstrate the entire phase-transition process, and a fairly large weight decay to trigger node specialization and convergence to low-order solutions, as suggested in Theorem 6. This simple analysis seems to align with existing studies [2], i.e., a small learning rate and reasonably large weight decay lead to grokking, and the model stays in memorization with very small weight decay (e.g., Fig. 7(b) and 8(b) in [2]).

Note that this is a very rough qualitative analysis and many questions remain, e.g., the percentage of training samples out of all possible distinct pairs needed to enable such a transition. The current framework will lead to additional terms in Theorem 1 if the training distribution is no longer uniform across all input pairs, which makes the analysis complicated. Therefore, we leave it for future work.

Not covering all possible global optima

We agree that we do not cover all possible global optima in this paper, and there may exist solutions that do not satisfy the sufficient condition shown in Lemma 1. In fact, we have constructed one (Corollary 5, F4/6). Nevertheless, the fact that the gradient descent solutions fall into the constructed solutions shows that these constructions are useful and may connect to the gradient dynamics that govern the training process.

References

[1] Noam Shazeer. Glu variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.

[2] Towards Understanding Grokking: An Effective Theory of Representation Learning (https://arxiv.org/abs/2205.10343)

Comment

I thank the authors for their detailed response. I maintain my support for accepting this paper, as it offers a meaningful theoretical contribution.

Comment

Dear reviewer,

Thanks for your acknowledgement of our contribution!

Best, Authors.

Review (Rating: 5)

This work introduces CoGS (Composing Global Solutions), a theoretical framework that identifies and characterizes an algebraic structure that exists over global solutions to group multiplication tasks with 2-layer neural networks. The authors identify a sufficient condition for global solutions, construct examples of global solutions that do (Z_F6) and do not (Z_F4/6) satisfy the sufficient conditions, and demonstrate how partial solutions can be combined to form global solutions. The work supports its theoretical framework with empirical experiments, finding that up to 95% of neural networks trained with gradient descent can be decomposed into solutions that match the provided constructions. The authors provide a theoretical statement of Occam's Razor in the CoGS setting, proving that gradient dynamics prefer simpler solutions that have lower order under the proposed algebraic structure.

Strengths and Weaknesses

Strengths:

  • The work is original. The idea of identifying algebraic structure over neural network solutions is novel, as far as I am aware. Furthermore, the work is related to an important problem in the field, understanding grokking and how NNs solve arithmetic tasks.
  • Theoretical analysis is novel and comprehensive. The authors provide sufficient conditions for global solutions, propose constructed solutions, and analyze gradient dynamics.
  • Empirical results are comprehensive and convincing. The fact that the authors identify compositions of their constructions within neural networks trained with gradient descent strongly supports their proposed framework.

Weaknesses:

  • The work focuses on 2-layer neural networks with quadratic activations, L2 loss, and abelian group multiplication tasks. This setting is already interesting, but the work would be strengthened if the framework were extendable to further settings.

Questions

  • Lemma 1 provides sufficient conditions for global optimality, and the authors identify global solutions that do not satisfy them (e.g. Z_F4/6). Does the CoGS framework characterize the full set of global solutions, or are there also classes of zero-loss solutions that exist outside the algebraic framework?
  • Relatedly, were the authors able to identify what solutions the neural networks were converging to in the 2-5% of cases for which gradient descent did not converge to a composition of your constructions?
  • How generalizable are the work's results/framework to changes in activation (e.g. ReLU) or loss (e.g. cross-entropy)?

Limitations

Yes

Final Justification

After reading the response from the authors and the discussion from the other reviewers, I am more confident in my assessment about the work and maintain my score accordingly.

Formatting Issues

N/A

Author Response

We thank the reviewer for the insightful and encouraging comments! We are glad to hear that the reviewer thinks the work is original, relates to important problems in the field, and comes with comprehensive and convincing experiments. We reply to the questions below.

Does the CoGS framework characterize the full set of global solutions, or are there also classes of zero-loss solutions that exist outside the algebraic framework?

No, it doesn’t characterize the full set of global solutions. F4/6 (Corollary 5) is an example that does not satisfy the sufficient condition but is still globally optimal.

Relatedly, were the authors able to identify what solutions the neural networks were converging to in the 2-5% of cases for which gradient descent did not converge to a composition of your constructions?

In our empirical study, these are solutions that “almost” fit our construction but were ruled out because they did not pass the pre-defined numerical thresholds used to determine which composition a solution belongs to. If we train longer, they may fall within the thresholds. We have not yet seen solutions that are fundamentally different from our theoretical construction.

How generalizable are the work's results/framework to changes in activation (e.g. ReLU) or loss (e.g. cross-entropy)?

For cross-entropy, note that in Theorem 1 we use a projected $\ell_2$ loss (i.e., $\|P_1^\perp (\mathbf{y} - \mathbf{o}/2d)\|^2$), and cross-entropy can be written in a similar form as well. Please check Lemma B.2 in [3] for details. Roughly speaking, the gradient of the cross-entropy loss $-\log\left[\exp(\mathbf{o}^\top \mathbf{y}) / \mathbf{1}^\top \exp(\mathbf{o})\right]$ is the same as the gradient of the surrogate loss $\|P_1^\perp (\mathbf{y} - \gamma \mathbf{o}/d)\|^2$, where $\gamma = \left(1 + \mathbf{o}^\top P_1^\perp \mathbf{o}/2d + o(\mathbf{o}^\top P_1^\perp \mathbf{o}/2d)\right)^{-1}$ depends on the zero-mean norm of $\mathbf{o}$, the unnormalized logits before softmax. When this norm is much smaller than $d$, we have $\gamma \approx 1$, and the exact same analysis can be applied to the cross-entropy loss, except for some changes to the constant coefficients in the loss decomposition (Eqn. 3).

In practice, the unnormalized logits $\mathbf{o}$ may grow unbounded in magnitude. This makes $\gamma$ shrink to a very small number, and thus the first loss term $\sum_{k\neq 0} r_{kkk}$ in Lemma 1 will dominate. If we only consider this term when finding our global solutions, the found solutions would be much simpler and of lower order. Experiments also support this.

For other activations, we can do a Taylor expansion and analyze the higher-order terms that would appear in Theorem 1 (quadratic activation is just a special case). Unfortunately, this makes the terms quite complicated, so there should be a better way of doing it, which we leave for future work.

[3] J. Zhao et al., GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection, ICML'24

Comment

Thanks to the authors for the response! The fact that "We haven’t seen solutions that are fundamentally different from our theoretical construction yet" is very interesting and seems to suggest your analysis is even more complete than you imply in the work. I believe including a discussion of this point, potentially including examples, in the appendix would strengthen the paper. I will maintain my score.

Comment

Dear reviewer,

Great to hear that! We will include a discussion and a few examples for it in the next revision. Thanks!

Best, Authors

Final Decision

This work presents a theoretical approach that uncovers algebraic properties within the weight parameters of two-layer neural networks with quadratic activations when trained on group multiplication problems. Specifically, it shows that these weights exhibit the mathematical structure of a commutative semi-ring, while the L2 loss function can be decomposed in terms of ring homomorphisms. The authors provide theoretical results that further allow an analytical construction of global solutions by combining partial solution components, and also establish a mode connectivity result.

The empirical studies/simulations in this work show that gradient-based optimization with L2 regularization tends to converge toward lower-order solutions, which supports the theoretical results.

During the internal discussion, a couple of reviewers championed this paper, and it is therefore recommended for acceptance.

Nevertheless, I encourage the authors to consider the feedback from Reviewer S1jd (and others as well) when revising the paper, in order to make it accessible to a broader community. Here is a partial list:

  1. Two reviewers who gave positive reviews were not sure about the significance and implications of Theorem 6. One of them stated that they didn't fully understand Theorem 6.
  2. During the internal discussion, Reviewer S1jd wrote "The authors in the rebuttals did not seem inclined to add assumption statements or other more formal ways of presenting the limits of their theory and so I do not feel inclined to adjust my scores."