Learning Hierarchical Polynomials with Three-Layer Neural Networks
We show that three-layer neural networks learn hierarchical polynomials of the form $h = g \circ p$, where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial, in $\widetilde O(d^k)$ samples.
Abstract
Reviews and Discussion
This paper studies a three-layer neural network for feature learning, trained by a two-stage, layer-by-layer gradient descent procedure. In the first stage, training the network is equivalent to fitting a random feature model in which $W_1$ is fixed. The second stage is also a random feature model. The overall procedure can therefore be regarded as a composition of kernel methods, and it achieves an improved sample complexity over a single kernel method.
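Concretely, the two-stage procedure described above can be mimicked by the following minimal NumPy sketch (an illustrative toy instantiation; the target function, widths, and regularization values are illustrative choices rather than the paper's settings), in which each stage is ridge regression over random features:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, m1, m2, lam = 20, 5000, 512, 512, 1e-3
relu = lambda t: np.maximum(t, 0.0)

# Toy hierarchical target h = g(p(x)); this particular p and g are illustrative choices.
def p(X):
    return (X[:, 0::2] * X[:, 1::2]).sum(axis=1) / np.sqrt(X.shape[1] / 2)

def g(z):
    return z ** 2 + z

X = rng.standard_normal((n, d))
y = g(p(X))

# Stage 1: first-layer weights W1 are random and frozen; fit the middle-layer
# coefficients v by ridge regression on the random features relu(X @ W1).
W1 = rng.standard_normal((d, m1)) / np.sqrt(d)
Phi1 = relu(X @ W1)
v = np.linalg.solve(Phi1.T @ Phi1 + lam * n * np.eye(m1), Phi1.T @ y)
feat = Phi1 @ v  # scalar learned feature, intended to approximate p(x)

# Stage 2: random features of the learned scalar feature; fit the outer layer a.
# (A fresh batch could be drawn for stage 2; the same data is reused to keep the sketch short.)
w2 = rng.standard_normal(m2)
b2 = rng.standard_normal(m2)
Phi2 = relu(np.outer(feat, w2) + b2)
a = np.linalg.solve(Phi2.T @ Phi2 + lam * n * np.eye(m2), Phi2.T @ y)

print("corr(stage-1 feature, p(x)):", np.corrcoef(feat, p(X))[0, 1])
print("train MSE after stage 2    :", np.mean((Phi2 @ a - y) ** 2))
```

The point of the sketch is structural: each stage in isolation is a kernel method, while the composition feeds the scalar feature learned in stage 1 into the random features of stage 2.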
Strengths
I appreciate the authors’ discussion in Section 5. It addresses several of my concerns.
- In the first stage, the algorithm uses multiple gradient steps to learn the feature, and thereby obtains an improved sample complexity compared to (Nichani et al. 2023).
- I like the discussion in Sec. 5.2, and the question it raises is quite interesting: why can a composition of kernel methods be learned more efficiently than a single kernel? Intuitively this question can be answered in part, but it is difficult to answer from the theoretical side.
Weaknesses
- The model and problem setting are a bit over-claimed. Though it is a three-layer neural network, only the second- and third-layer weights are trained while $W_1$ stays fixed. This is actually a two-layer neural network whose input is not the original data but the output of a fixed feature map. More importantly, the layer-by-layer training scheme makes this three-layer network degenerate into a composition of two kernel methods (random features). With weight decay, the loss function in each of the two stages is strongly convex and smooth (see the schematic objective after this list). Both the problem setting and the algorithm are far removed from true three-layer neural networks, or even from two-layer neural networks.
- The results cannot handle the over-parameterized regime, i.e., they require an upper bound on the width. The derived results set the weight-decay parameter and the width as functions of the sample size, and these quantities must have different orders; otherwise the bound becomes vacuous. This is because the convergence rate contains a width-dependent term, derived via Rademacher complexity for the random features model.
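To make the "composition of two kernel methods" point concrete, each stage solves a ridge problem of the following generic form (schematic notation only; $\Phi$ stands for that stage's random feature matrix and $\lambda$ for the weight-decay parameter):

$$
\min_{v}\; \frac{1}{n}\sum_{i=1}^{n}\bigl(\Phi(x_i)^\top v - y_i\bigr)^2 + \lambda \|v\|_2^2,
\qquad
\hat v = \bigl(\Phi^\top \Phi + \lambda n I\bigr)^{-1}\Phi^\top y,
$$

which is strongly convex in $v$ and coincides with kernel ridge regression for the kernel induced by the random features.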
Questions
- How does the first inequality in Eq. (13) hold?
Overall, the model is a composition of kernel methods (or random features). In my view, understanding the composition of kernel methods versus a single kernel is a more accurate and more significant framing than so-called “three-layer” feature learning. I'm willing to increase my score if a significant revision of the motivation and story is made.
Thank you to the reviewer for your detailed comments. We address specific concerns of yours below.
“The model and problem setting are a bit over-claim…Both problem setting and algorithm are totally far away from true three-layer neural networks or even two-layer neural networks.”
- Please refer to the global response for discussion on the architecture/algorithm.
- We’d like to emphasize a few additional points about our work. First, even though each stage of the algorithm is a kernel method, the overall algorithm is not a kernel method, and crucially it can learn hierarchical polynomials of the form $h = g \circ p$ with a much improved sample complexity. A major insight of our paper is thus that the sample complexity of learning $h = g \circ p$ is the same as that of learning $p$, by first fitting the low-degree terms of $h$ (which are approximately $p$, as proven in Lemma 2); a toy calculation illustrating this is given after this list. Furthermore, this algorithm can be implemented by a three-layer neural network, demonstrating a mechanism by which three-layer networks can perform feature learning. These insights, along with the observation, supported by empirical evidence, that a three-layer neural network can indeed implement this layerwise training procedure, have not appeared before in the literature, and we believe they form an interesting and important contribution.
- Next, we note that guarantees for feature learning in three-layer neural networks are few and far between, and all such results we are aware of do make some modification to the algorithm/architecture. We still find our ResNet architecture and layerwise training algorithm to be realistic while demonstrating a general feature learning principle. Our new experiments demonstrate that when all the parameters are trained jointly, the network still performs feature learning to learn the hierarchical target.
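To illustrate the "low-degree terms of $h$ are approximately $p$" phenomenon concretely, consider the following toy calculation (an illustrative choice of $p$ and $g$, not an example from the paper, and not necessarily satisfying Assumption 4 verbatim). Take $p(x) = \frac{1}{\sqrt m}\sum_{i=1}^m x_{2i-1}x_{2i}$ (degree $k = 2$) and $g(z) = z^2 + z$, so that $h = g \circ p = p^2 + p$. Writing $P_{\le 2}$ for the projection onto polynomials of degree at most $2$ under the Gaussian measure and $\mathrm{He}_2(t) = t^2 - 1$, a direct Hermite expansion gives

$$
P_{\le 2}\, h \;=\; p \;+\; 1 \;+\; \frac{1}{m}\sum_{i=1}^{m}\bigl(\mathrm{He}_2(x_{2i-1}) + \mathrm{He}_2(x_{2i})\bigr),
$$

where the trailing sum has $L^2$ norm $2/\sqrt{m}$ while $\|p\|_{L^2} = 1$. Hence, for large $m$, the best degree-$2$ approximation of $h$ is, up to an additive constant, essentially $p$ itself, which is exactly the feature that stage 1 needs to recover, without ever fitting the full degree-$4$ target.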
“The results cannot handle the over-parameterized regime, i.e., the results require an upper bound on the width.”
- We have updated our main theorem to handle arbitrary width. Since the first stage of the algorithm is a kernel method, the generalization bound should not depend on the width, but rather on the RKHS norm of the target; in the random feature setting, the RKHS norm is simply the norm of the corresponding coefficient vector. Our initial proof was a bit loose, but we have refined the analysis so that, by choosing the weight decay parameter appropriately, the sample complexity now scales with the RKHS norm of the target, independent of the width.
“How does the first inequality in Eq. (13) hold?”
- This inequality should actually be an equality. For each cross term, the two factors are independent and each has mean zero, so the cross terms vanish in expectation; see the schematic identity below.
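Schematically, writing $A$ and $B$ for the two independent, mean-zero terms in question (this is shorthand for the argument above, not the exact notation of Eq. (13)):

$$
\mathbb{E}\bigl[(A+B)^2\bigr] \;=\; \mathbb{E}[A^2] + 2\,\mathbb{E}[AB] + \mathbb{E}[B^2] \;=\; \mathbb{E}[A^2] + 2\,\mathbb{E}[A]\,\mathbb{E}[B] + \mathbb{E}[B^2] \;=\; \mathbb{E}[A^2] + \mathbb{E}[B^2].
$$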
We are happy to answer any more questions you may have, and we hope that if your concerns have been addressed you would consider raising your score.
This paper studies the problem of learning hierarchical polynomials with neural networks. In particular, it focuses on learning functions of the form $h = g \circ p$, where $p$ is a degree-$k$ polynomial and $g$ is a polynomial link function (the special case $k = 1$ recovers the well-studied family of single-index models).
The paper's main result is that for a subclass of degree-$k$ polynomials $p$ and standard Gaussian marginals, a 3-layer NN trained via layerwise GD on the square loss learns the target hierarchical polynomial (realizable setting) with roughly $d^k$ samples and runtime that is polynomial in the parameters of the problem.
The main conceptual message is that this sample complexity guarantee improves over kernel-based approaches, which require a number of samples scaling with the full degree of $g \circ p$ and essentially cannot exploit the special hierarchical structure of the target functions.
Strengths
I think that the result is interesting and fits well with the ICLR community. In general, the paper is easy to read: the assumptions are presented in a clear manner, comparison with prior work is well-established and the contribution seems important in the field.
At a technical level, the authors essentially show that (i) during the first training phase, the NN implements kernel regression and essentially learns the underlying polynomial $p$, and (ii) during the second stage, the NN recovers the link function $g$. The main technical tool is that, using the special structure assumed for the polynomials (Assumption 4), the paper provides an approximate version of Stein's Lemma (see Lemma 2), which can be used to show Item (i) above. I think that the paper's technical contribution is sufficient for acceptance, since it extends existing ideas and provides new insights in the area.
In general, I enjoyed reading this paper and I vote for acceptance.
Weaknesses
I believe it would be beneficial if the authors mentioned families of polynomials not captured by Assumption 4. This would make it clearer how strong and restrictive the assumption is. The assumption highly simplifies the analysis, and hence it would be nice if the authors could discuss it further (I see why the families of Remark 3 satisfy the condition, but a further discussion of how the assumption simplifies the analysis would be helpful).
Questions
(1) I think a discussion on Assumption 4 (as I mention in the above section) is an important aspect that should be expanded in the current draft.
(2) As an additional comment, I think that Theorem 1 could be written more formally (e.g., mention that Assumptions 1-6 hold, mention the runtime of the training process, etc.).
(3) The current result heavily relies on Gaussian marginals. Do you think similar results could be established for other continuous measures? Or discrete ones (in the Boolean domain, any function is essentially a polynomial in the Fourier basis)?
Thank you to the reviewer for your detailed comments and for the positive evaluation of our paper. We address specific comments of yours below:
“I believe that it would be beneficial if the authors mentioned families of polynomials not captured by Assumption 4… it would be nice if the authors could further discuss on this assumption”
- Here are a couple of examples of polynomials which violate the assumption. One is a degree-3 polynomial containing a linear term: even though it can be written as a sum of orthogonal components, it does not lie in the span of degree-3 Hermite polynomials. Another example is a polynomial whose components are not orthogonal to one another, violating the orthogonality required by the assumption.
- Assumption 4 simplifies the analysis, in particular the proof of Lemma 2, since orthogonality of the components implies independence of each term and thus makes the projection easier to compute. In general, we expect that by the universality principle general features will be close to Gaussian, so a similar result should hold, but this is challenging to make rigorous. We will be sure to add discussion of this point in a revision of our paper.
“Do you think similar results could be established for other continuous measures?”
- We believe the high-level insights should be the same; however, we note that in the literature essentially all feature-learning results require strong assumptions on the data distribution, such as being Gaussian, uniform on the hypersphere, or uniform on the Boolean hypercube. We expect our results to translate to these simple distributions, but for more general distributions this question is still open even for much simpler models such as single-index models (see, for example, (Bruna et al., 2023) for preliminary work in this direction).
Joan Bruna, Loucas Pillaud-Vivien, Aaron Zweig. On Single Index Models beyond Gaussian Data. NeurIPS 2023. arxiv.org/abs/2307.15804
The authors examine the problem of learning hierarchical polynomials of Gaussian data with three-layer networks trained using layerwise gradient descent. They compare the performance of three-layer networks to kernel approaches and two-layer networks, demonstrating a clear sample-complexity improvement without requiring any low-rank structure, as is typically assumed for the sparse multi-index targets that two-layer neural networks learn efficiently. The theoretical analysis is reinforced by extensive discussion.
Strengths
The paper is pleasant to read and the mathematical results are supported with extensive discussion.
Weaknesses
The main weakness of the paper is the close relationship with previous works. Although a fair comparison is given, the works by [Allen-Zhu & Li, 2019;2020] and [Nichani et al. 2023] contain many of the key ideas in the manuscript.
Questions
- It would help the inexperienced reader to include a small comment on [Ben Arous et al. 2021] in the related works section, to give a complete overview of learning single-index models.
- In Remark 2 it is written that the result "can be extended to any $L$". What is the original condition on $L$?
- Please state after Theorem 1 or Corollary 1 that there is a clear sample-complexity improvement for three-layer networks over two-layer networks. At the moment, this is done only with respect to kernel methods.
- Could the authors comment, with references, on what characterizes the failure of two-layer networks in learning these high-rank target functions? Are there any provable guarantees?
- Before Lemma 1, mention the requirement that is used, and repeat the references that explain why you fit the best degree-$k$ polynomial.
- Could the authors comment on the requirements for the gradient step size in the first phase? It is mentioned that [Nichani et al. 2023] considered one large gradient step in this phase. How does this relate to your learning-rate requirement in Thm 1?
Thank you to the reviewer for your detailed comments. We address specific concerns of yours below.
Please refer to the global response for discussion on the algorithm/architecture, and our new revision.
“The main weakness of the paper is the close relationship with previous works. Although a fair comparison is given, the works by [Allen-Zhu & Li, 2019;2020] and [Nichani et al. 2023] contain many of the key ideas in the manuscript.”
We would like to emphasize that the major original insight of our paper is that the sample complexity of learning $h = g \circ p$ is the same as that of learning $p$, by first fitting the low-degree terms of $h$ (which are approximately $p$, as proven in Lemma 2). Furthermore, this algorithm can be implemented by a three-layer neural network, demonstrating a new mechanism by which three-layer networks can perform feature learning. We believe that this insight is both novel and interesting.
To compare to the specific prior works:
- (Allen-Zhu & Li, 2019; 2020) consider a more restricted class of neural networks, in which all activations are quadratic. Furthermore, they consider target functions whose link function is almost linear, so that the learned feature already achieves vanishing test loss. Our setting is much more challenging, in that the learner must extract $p$ while only having access to samples of $h = g \circ p$; the learner must also fit $g$ in order to obtain vanishing test error. We emphasize that our choice of link function $g$ is very general, while Allen-Zhu & Li (2020) can only handle link functions that are almost linear. Our insight that a three-layer neural network can implement both of these steps while only having access to samples of $h$ is indeed a novel contribution compared to this prior work.
- Likewise, (Nichani et al., 2023) can only handle quadratic features, and for that setting obtains a suboptimal sample complexity. Our key contribution is that this general principle of hierarchical feature learning -- that the sample complexity of learning $h = g \circ p$ via a three-layer neural network is the same as that of learning $p$ -- holds for a broad class of hierarchical functions. A major technical innovation of our work is an extension of the “approximate Stein’s lemma” argument to handle degree-$k$ polynomials.
To address the questions:
- “What is the original condition on L?” It’s .
- “Could the author comment with references characterizing the failure of two-layer networks in learning these high-rank target functions?” Existing guarantees for two-layer neural networks fall into two rough categories. One set of results uses NTK-style arguments; while these can handle high-rank target functions, their sample complexity is no better than that of a kernel method. On the other hand, feature-learning guarantees such as [Damian et al., 2022; Abbe et al., 2022; 2023], which do improve over kernel methods, require the target to be low rank. While no general lower bound against two-layer networks learning high-rank functions exists, approximation lower bounds have been proven against specific high-rank targets: see for instance [Daniely, 2017; Safran et al. 2019; Safran & Lee, 2022; Nichani et al., 2023]. We will update the exposition to make this point clearer.
- “Could the author comment on the requirements for the gradient step size in the first phase?” We need the learning rate to be small enough to mimic gradient flow, i.e., scaling inversely with the smoothness of the loss (see the descent-lemma sketch after this list). This is a different mechanism from [Nichani et al., 2023], which requires a very large gradient step that scales inversely with the initialization size. We believe our algorithm better captures practice, where smaller gradient steps that mimic gradient flow are more common. This distinction is another point of novelty of our work compared to [Nichani et al., 2023].
- We will be sure to add additional discussion on [Ben Arous et al., 2021], compare to two-layer networks after Corollary 1, and mention the requirement before Lemma 1.
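For completeness, here is the standard descent-lemma calculation behind the step-size requirement mentioned above (textbook material, written in generic notation with $\beta$ the smoothness constant of the loss $L$ and $\eta$ the step size):

$$
L(\theta_{t+1}) \;\le\; L(\theta_t) - \eta\,\|\nabla L(\theta_t)\|^2 + \frac{\beta\eta^2}{2}\,\|\nabla L(\theta_t)\|^2 \;\le\; L(\theta_t) - \frac{\eta}{2}\,\|\nabla L(\theta_t)\|^2 \quad \text{whenever } \eta \le \tfrac{1}{\beta},
$$

so choosing $\eta \le 1/\beta$ guarantees monotone decrease of the loss and keeps gradient descent close to the gradient flow.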
We are happy to answer any more questions you may have, and we hope that if your concerns have been addressed you would consider raising your score.
Amit Daniely. Depth separation for neural networks. COLT, 2017.
Itay Safran, Ronen Eldan, and Ohad Shamir. Depth separations in neural networks: What is actually being separated? COLT, 2019.
In this paper, the authors consider learning target functions of the form $h = g \circ p$, where $p$ is a degree-$k$ polynomial and $g$ is a polynomial link function. They consider a specific 3-layer neural network with a skip connection and a bottleneck layer of size 1. They show, using layer-wise training, that $\widetilde O(d^k)$ samples are sufficient to train the second-layer weights, with a further polynomial number of samples sufficient to train the last-layer weights. In particular, the total sample complexity is much smaller than that required for inner-product kernels. Note also that $d^k$ samples are information-theoretically minimal for learning degree-$k$ polynomials; however, the authors only consider a particular subset of degree-$k$ polynomials, which has a much smaller functional-space dimension (except in special cases).
Strengths
- The paper considers the problem of understanding the benefits of feature learning for multi-layer neural networks. The one-hidden-layer case has attracted a lot of attention, while the multi-layer case is so far much less understood.
- They are able to show a large separation from kernel methods. The reason is that the first layer is able to extract a good representation of the data (the polynomial $p$) from only seeing samples of $h = g \circ p$.
- This class of target functions naturally generalizes the single-index model ($k = 1$), which is the natural class of functions for one-hidden-layer neural networks (with one hidden layer, the network has direct access to linear functions of the data).
- The paper is easy to follow and the assumptions and proof techniques are accurately presented and discussed.
Weaknesses
- The architecture and algorithm are chosen specifically to succeed for this specific class of target functions (compositions of degree-$k$ multivariate polynomials with univariate functions). Several previous works have considered such layer-wise training on non-standard architectures for specific hierarchical classes of functions, including [Allen-Zhu, Li, 2020] and [The staircase property, Abbe, Boix, Brennan, Bresler, Nagaraj, 2020]. It is unclear how this paper contributes novel ideas in that direction.
- Overall, the architecture choice and algorithm make the analysis straightforward. It reduces the problem to sequentially fitting two linear random feature models, which is by now quite well understood. See for example [Generalization error of random feature and kernel methods: Hypercontractivity and kernel matrix concentration, Mei et al., 2022], which gives the sample size and network width needed to learn any polynomial of a given degree (for spherical data). The main technical innovation is Lemma 2, which relies on the specific construction of Assumption 4.
- For these reasons, I wonder how much this analysis can be generalized to other settings. Especially in more realistic settings, we expect a non-linear evolution of the parameters, which this paper avoids via its specific architecture. In contrast, while [Nichani et al., 23] only consider a single large gradient step (so ultimately also a linear-model step), they use a more standard architecture (but get a worse sample-complexity dependence).
Questions
- Why not directly consider a general link function $g$ instead of a polynomial? The random feature analysis shouldn’t change much in that case (it simply requires showing an approximate certificate using one-dimensional random features).
Thank you to the reviewer for your detailed comments. We address specific concerns of yours below:
“The architecture and algorithm are chosen specifically to succeed for this specific class of target functions (composition of degree-k multivariate polynomials with univariate functions).”
- Please refer to the global response for discussion on the architecture/algorithm.
- We would additionally like to emphasize that a major insight of our paper is that the sample complexity of learning $h = g \circ p$ is the same as that of learning $p$, by first fitting the low-degree terms of $h$ (which are approximately $p$, as proven in Lemma 2). Furthermore, this algorithm can be implemented by a three-layer neural network, demonstrating a mechanism by which three-layer networks can perform feature learning. We believe that this insight is both novel and interesting.
“Several previous works have considered such layer-wise training on non-regular architectures for specific hierarchical classes of functions”
- (Allen-Zhu & Li, 2020) consider a more restricted class of neural networks, in which all activations are quadratic. Furthermore, their target functions of interest have an almost-linear link function, so that the learned feature already achieves vanishing test loss. Our setting is much more challenging, in that the learner must extract $p$ while only having access to samples of $h = g \circ p$; the learner must also fit $g$ in order to obtain vanishing test error. We emphasize that our choice of link function $g$ is very general, while Allen-Zhu & Li (2020) can only handle link functions that are almost linear. Our insight that a three-layer neural network can implement both of these steps while only having access to samples of $h$ is indeed a novel contribution compared to this prior work.
- (Abbe et al., 2020) also consider a network with quadratic activations, together with a non-standard sparse architecture and a training procedure that trains each neuron one by one. Furthermore, the “staircase” functions they consider are learnable by GD on a two-layer network with a sample/time complexity of O(d) (Abbe et al., 2022). Our work demonstrates that three-layer neural networks can learn a new class of hierarchical functions that is not known to be efficiently learnable by a two-layer network.
“Overall, the architecture choice and algorithm makes the analysis straightforward. It reduces the problem to sequentially fitting two linear random feature models, which is by now quite well understood.”
- While each stage of our algorithm indeed fits a random feature model, the overall algorithm of sequentially fitting random feature models is not a kernel method, and crucially it can learn hierarchical polynomials of the form $h = g \circ p$ with a much improved sample complexity. This fact, along with the observation that a three-layer neural network can indeed implement this layerwise training procedure, has not appeared before in the literature, and we believe it is an interesting and important contribution.
“For these reasons, I wonder how much this analysis can be generalized to other settings. Especially, in more realistic settings, we expect a non-linear evolution of the parameters”
- Please refer to the global comment; we would like to emphasize that prior works on two- and three-layer networks also rely on layer-wise training procedures. It is straightforward to generalize our results to certain forms of non-linear parameter evolution, such as when the first stage of training is in the NTK regime.
“Why not directly consider a general function g instead of a polynomial? The random feature analysis shouldn’t change much in that case.”
- Currently, our proof of Lemma 2 requires expanding the target into a sum of products of monomial terms and computing the projection of each individual term. As such, the current proof only applies to polynomial $g$. However, we expect the lemma to still hold for more general $g$, albeit with some additional technical challenges to get the proof to go through. Overall, the high-level insight that three-layer neural networks perform hierarchical feature learning should still hold for general $g$.
We are happy to answer any more questions you may have, and we hope that if your concerns have been addressed you would consider raising your score.
Thank you to all the reviewers for your detailed comments and feedback.
We’d first like to address comments that were shared by a few of the reviewers.
- Reviewers oysL and RCBe found our architecture and algorithm to be too simplistic, and tailored for the task of learning hierarchical polynomials. Firstly, our architecture is reminiscent of the ResNet architecture, and is similar to that studied in prior work on three-layer neural networks (Ren et al., 2023; Allen-Zhu & Li, 2019; 2020).
- Additionally, we’d like to argue that analyzing the gradient descent dynamics when all parameters are trained jointly is extremely challenging, even in far simpler problem settings. In fact, essentially all prior work on feature learning in two-layer neural networks (Damian et al., 2022; Abbe et al., 2022; 2023; Bietti et al., 2022, etc.) relies on layerwise training procedures. Existing guarantees for three-layer networks beyond the kernel regime, such as (Allen-Zhu & Li, 2019; 2020; Nichani et al., 2023) also make nonstandard modifications to the algorithm/architecture, such as quadratic activations, modified variants of GD or layerwise training.
- It is straightforward to extend our training procedure in stage 1 to certain scenarios where the first two layers are trained jointly. For example, one can utilize NTK-style analyses to show that training these layers jointly still learns $p$ in $\widetilde O(d^k)$ samples (see, for instance, Montanari & Zhong, 2022). Since this results in the same overall sample complexity and demonstrates the same high-level insight as the random feature setting, for simplicity we opted for the setting in which only the middle-layer weights are trained.
We’ve also uploaded a revised version of our paper, with the following changes:
- We’ve added experiments to Appendix A, which demonstrate that a three-layer residual network trained via standard training procedures (all layers trained jointly via Adam) indeed learns the hierarchical target by learning the feature $p$; a schematic of this setup is sketched after this list. This experiment demonstrates that the task of learning hierarchical polynomials is still solved by standard training procedures, and we hope that this helps alleviate the concerns of reviewers oySL and RCBe.
- We’ve relaxed Assumption 6 to allow an arbitrary polynomial of the given degree. The network now requires biases in the first layer as well.
- We’ve modified the proof to work in the over-parameterized regime, which was a concern of reviewer RCBe. Our original generalization analysis was a bit loose, and we’ve edited the paper with a more refined analysis that holds for arbitrary width.
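For concreteness, here is a minimal sketch of the kind of experiment described above (illustrative PyTorch code of our own, with widths, learning rate, and the specific $p$ and $g$ chosen by us; it is not the exact script from Appendix A): a network with a width-1 bottleneck feature and a skip connection, all layers trained jointly with Adam on a hierarchical polynomial target.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, n, m1, m2 = 20, 8192, 256, 256

# Toy hierarchical target h = g(p(x)) (our choice of p and g).
def p(X):
    return (X[:, 0::2] * X[:, 1::2]).sum(dim=1) / (X.shape[1] / 2) ** 0.5

def g(z):
    return z ** 2 + z

X = torch.randn(n, d)
y = g(p(X))

class BottleneckResNet(nn.Module):
    """One simple realization of 'width-1 bottleneck + skip connection';
    the paper's exact architecture may differ from this stand-in."""
    def __init__(self):
        super().__init__()
        self.first = nn.Sequential(nn.Linear(d, m1), nn.ReLU(), nn.Linear(m1, 1))
        self.head = nn.Sequential(nn.Linear(1, m2), nn.ReLU(), nn.Linear(m2, 1))

    def forward(self, x):
        feat = self.first(x)                         # scalar learned feature, ideally ~ p(x)
        return (feat + self.head(feat)).squeeze(-1)  # residual/skip connection around the head

model = BottleneckResNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(2000):
    opt.zero_grad()
    loss = ((model(X) - y) ** 2).mean()  # square loss, all layers trained jointly
    loss.backward()
    opt.step()

with torch.no_grad():
    feat = model.first(X).squeeze(-1)
    corr = torch.corrcoef(torch.stack([feat, p(X)]))[0, 1]
print(f"final train MSE {loss.item():.4f}, corr(feature, p) {corr.item():.3f}")
```

The printed correlation between the learned bottleneck feature and $p(x)$ is the kind of diagnostic used to check that feature learning has occurred under joint training.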
We've also improved some of the exposition to further emphasize the key points from the rebuttal. We are happy to answer any more questions the reviewers may have.
Andrea Montanari and Yiqiao Zhong. The interpolation phase transition in neural networks: Memorization and generalization under lazy training. Annals of Statistics, 2022.
Thank you again to all the reviewers for your detailed feedback on our submission.
Given that the discussion period ends tomorrow, we would like to know whether our rebuttal has addressed your concerns, and whether any more questions of yours remain. Please do let us know, we are happy to answer any additional questions.
Best, The Authors
The paper studies a three-layer neural network for feature learning trained by layer-by-layer gradient descent. The authors show that in the first stage of training, the network is equivalent to a random feature model $f_1(x)$ in which $W_1$ is fixed. In the second stage of training, the network is again a random feature model, but it can be shown to be more expressive than the first-stage model. The authors provide theoretical guarantees on the generalization error of the network.
The reviewers agree that the paper is well written. Some reviewers thought that the results may improve our understanding of feature learning in gradient-descent-trained neural networks. On the other hand, the reviewers raise several concerns: the assumptions made by the authors are quite strong; the network studied is quite simple and essentially collapses to fitting two random feature models, which is quite well understood; and the paper does not provide enough experimental evidence to support its claims.
Why Not a Higher Score
The studied model is quite simplistic, and the results obtained closely relate to other results for related random feature models.
Why Not a Lower Score
Understanding feature learning in neural networks trained by gradient descent is a tough problem, thus justifying the simplistic setting studied in the paper. Some of the concerns raised by the reviewers were addressed during the rebuttal stage (e.g., further empirical validation of the theoretical results).
Accept (poster)