Neural network learns low-dimensional polynomials with SGD near the information-theoretic limit
Abstract
Reviews and Discussion
The paper studies the classical problem of the single-index model over Gaussian inputs, i.e. $y = \sigma_*(\langle x, \theta \rangle)$ with $x \sim \mathcal{N}(0, I_d)$, for an unknown direction $\theta \in \mathbb{S}^{d-1}$. Information-theoretically, one needs $n = \Omega(d)$ samples to learn this function class. The paper shows that reusing the batches with a certain SGD-based training on a 2-layer neural network achieves vanishing error with $n = \tilde{O}(d)$ samples -- that is, this SGD algorithm with reused batches learns the function class at nearly the information-theoretic limit.
The problem has been studied extensively. The CSQ lower bound in terms of the ``information exponent'' $k^\star$, the lowest degree of a non-zero Hermite coefficient of the link $\sigma_*$, was considered the correct complexity measure for SGD; that is, $n \gtrsim d^{k^\star/2}$ samples are necessary. However, when one reuses the batch, the update can be seen as a non-correlational query on the same example, and hence the CSQ lower bound is breached, allowing us to learn at $n = \tilde{O}(d)$ sample complexity, irrespective of the information exponent.
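For reference, the information exponent can be written via the Hermite expansion of the link (standard definition, restated informally here rather than quoted from the paper):

$$\sigma_*(z) \;=\; \sum_{j \ge 0} \frac{c_j}{j!}\,\mathrm{He}_j(z), \qquad c_j = \mathbb{E}_{z\sim\mathcal{N}(0,1)}\big[\sigma_*(z)\,\mathrm{He}_j(z)\big], \qquad k^\star \;=\; \min\{\, j \ge 1 : c_j \neq 0 \,\}.$$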
Strengths
- The paper addresses an important question in our understanding of the complexity of gradient-descent-type algorithms on regular neural networks, to which considerable effort has been devoted. This paper goes beyond most (if not all) of these works by analyzing SGD with reused batches, as opposed to vanilla single-pass SGD. While a recent paper [DTA+24] provided the first evidence of the benefit of reusing batches and showed that the CSQ bound can be escaped, this paper goes beyond it in an important way, as follows.
- This paper considers strong recovery, which is more satisfying and technically much more challenging (in contrast to [DTA+24], which only considered weak recovery). To achieve this, there were important pieces to be figured out, which the authors successfully did. This paper provides a clear end-to-end analysis and establishes the learning guarantee, in contrast to [DTA+24].
Weaknesses
I do not see any major weaknesses. The training procedure is slightly non-standard, but it is completely understandable from a technical point of view. The layer-wise training and the use of projected gradients are completely standard in theoretical research. However, I could not see a clear motivation/need for the momentum in the first-layer training, and I was wondering if it can be avoided.
Questions
- Can the use of momentum be avoided in the first-layer training? If not, in what way is it important to the current analysis?
- Why cite [Gla'23] (on line 338 page 9) for the point of adversarial noise in SQ vs non-adversarial noise in GD?
- Do the authors believe what was said in lines 337-338, i.e., that it is also possible that SGD can achieve a statistical complexity beyond the SQ lower bound for generative exponent more than two? While I understand this is not ruled out (and it is indeed important to make the point of non-adversarial noise), saying that ``it is possible'' sounds slightly strong to me, as if the authors are indicating their belief about the situation. Unless there is strong experimental evidence for this, or the authors truly believe it is possible, I would encourage the authors to reword this part.
- Should the abbreviation on line 31, page 1 be CSQ instead of SQ?
Limitations
None.
We thank the reviewer for the thoughtful comment and constructive feedback. We address the technical concerns below.
"Can the use of momentum be avoided in the first layer of training?"
We use an interpolation step to improve the signal-to-noise ratio in the gradient; this is crucial in the generative exponent 2 setting where the signal and noise are almost at the same magnitude -- please see our response to Reviewer p2Vt for more details and a recap of our explanation in Section 3.2.
It is possible that other modifications or hyperparameter choices of the learning algorithm also achieve a similar effect, but we do not pursue these in the current submission.
[Gla'23] and the possibility of going beyond SQ
Thank you for the close reading. We realize that this remark on future directions is misleading. We initially thought that for the $k$-parity problem, the SQ lower bound suggests that more samples are required for rotationally invariant algorithms with polynomial compute than in the Gaussian case. However, this is not the case, and the total computation in [Glasgow 23] only matches that suggested by the SQ lower bound.
We will include appropriate references on the gap between statistical and adversarial noise, such as [Dudeja and Hsu 20] and [Abbe and Sandon 20] (albeit with a highly nonstandard architecture).
We also agree with the reviewer that the statement needs to be reworded in the absence of empirical evidence.
We would be happy to clarify any further concerns/questions in the discussion period.
Thanks for your response! I would be happy to see this paper accepted.
This paper addresses the problem of learning single-index targets with polynomial link functions under Gaussian inputs. The authors demonstrate that using SGD on a two-layer fully connected network with a specific activation function can learn such targets with O(d poly(log(d))) samples. The analysis involves the reuse of samples, which allows them to improve on previous bounds obtained for online single-pass SGD.
Strengths
S1) The paper advances the analysis of the complexity of learning Gaussian single-index models, a currently very popular model for the theoretical study of neural networks. This contribution is thus significant for the deep learning theory community.
S2) The technical contributions are novel and well presented.
S3) The paper builds on previous work that demonstrated the benefits of re-using batches for learning single-index models, providing concrete evidence of strong learnability of the target by SGD on a shallow network.
Weaknesses
W1) The authors provide minimal empirical validation, with only one experiment demonstrating their claims. They do not address more standard SGD practices, such as training both layers simultaneously, using standard activations/initializations, or employing larger learning rates.
W2) The analysis relies on several theoretical assumptions and is limited to a very structured data distribution, which is common in deep learning theory proofs. While the assumptions are well stated, there is little discussion on whether or how these assumptions could be relaxed.
Questions
Q1) Can you clarify how many hidden neurons N are needed for the main result to hold?
Q2) How does the bound depend on the degree of the target q?
Q3) Would you expect the same result to hold if the bias weights are trained?
Q4) Do you have any high-level intuition on why the even polynomials are harder than the odd ones? Do you believe that the poly(log(d)) terms are needed for the even ones?
Q5) Do you expect a similar analysis to hold for other losses? For example, L1 loss.
Q6) Do you have an intuition for what could be an optimal mini-batch re-use schedule?
Typos/suggestions: Proposition 4ii): Can you formally define 'the odd part of f'? Line 211: missing the word 'high'. Line 326: 'high' -> 'weak'. Line 175: missing the word 'be'. Line 163: typo 'not be not'.
Limitations
The authors adequately addressed the limitations of their work.
We thank the reviewer for the thoughtful comment and constructive feedback. We address the technical concerns below.
Standard SGD practices, relaxing assumptions
We agree that Algorithm 1 deviates from the most standard training procedure in practice. Note that layer-wise training and fixed bias units are fairly standard algorithmic modifications when proving end-to-end sample complexity guarantees, as seen in many prior works on feature learning [DLS22] [AAM23]. Without such modifications, stronger target assumptions are generally needed, such as matching link function [BAGJ21]. Below we discuss the possible extensions to more general settings.
- Simultaneous training. We believe that the statistical efficiency of Algorithm 1 can be achieved by simultaneous training of all parameters, under an appropriate choice of hyperparameters. For instance, the layer-wise procedure may be approximated by a two-timescale dynamics [BBPV23], and the required ``diversity'' of bias units may be met when the learning rate of the bias weights is sufficiently small. As mentioned in the Conclusion section, it is also interesting to consider a more standard multi-pass algorithm instead of the currently employed data-reuse schedule.
- Different losses. We focus on the squared loss as it is the most standard objective for regression tasks. Note that the restriction to correlational information is a feature of the squared/correlation loss -- see Section 2.2 and the displayed gradient below. If we employ a different loss such as the L1 loss, it is possible that SGD can implement non-correlational queries just from the loss function itself -- such a direction is tangential to our analysis, since our goal is to show that such a non-CSQ component naturally arises from the reuse of training data.
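To spell out why the squared loss only yields correlational queries on fresh samples (a short derivation in generic notation, not copied from the manuscript): for a predictor $f_W$,

$$\nabla_W\,\tfrac{1}{2}\big(f_W(x)-y\big)^2 \;=\; \underbrace{f_W(x)\,\nabla_W f_W(x)}_{\text{label-free}} \;-\; \underbrace{y\,\nabla_W f_W(x)}_{\text{linear in }y},$$

so each gradient step on a fresh batch accesses the labels only through correlations of the form $\mathbb{E}[y\,\phi(x)]$, i.e. CSQs. After one step on a batch, however, $W$ itself depends on the $(x, y)$ pairs in that batch, so the next gradient evaluated on the same batch contains terms involving higher powers of $y$ -- non-correlational queries.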
Number of neurons and dependence on target degree
The required width of the neural network is almost dimension-free. In particular, the width required to achieve a given population error is independent of $d$ up to polylogarithmic factors -- see line 245 for details.
On the other hand, our big-$O$ notation hides a constant that might depend exponentially on the target degree $q$. This is due to the sign condition on the Hermite coefficients required in Assumption 3, and the compactness argument used to obtain a uniform upper bound in Proposition 4. Note that a similar dependence on $q$ is also present in tailored algorithms for learning low-dimensional polynomials [CM20].
"why the even polynomials are harder than the odd ones"
Intuitively speaking, the neural network is initialized at an (approximate) saddle point when the link function is even: by symmetry, the projection of the population gradient onto the target direction vanishes at zero overlap, and a randomly initialized neuron has overlap of order $d^{-1/2}$. See [BAGJ21] for more discussion.
To elaborate, let us consider one-pass SGD and evaluate the scale of the population gradient of one neuron at random initialization. Let $\alpha = \langle w, \theta \rangle$ denote the overlap, which is of order $d^{-1/2}$ at random initialization; the projection of the population gradient onto $\theta$ scales as $\alpha^{k-1}$, where $k$ is the information exponent of the link. Therefore, when $k = 1$, the scale of the population gradient is $\Theta(1)$, while when $k \ge 2$ and the overlap is of order $d^{-1/2}$, it is $O(d^{-1/2})$ with high probability.
Similarly, for reused-batch SGD, the population gradient is a linear combination of such terms computed with (powers of) the reused labels. When the link is even, any transformation of the label still has information exponent at least 2, so one can similarly see that the scale of each term is $O(d^{-1/2})$ at initialization.
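As a minimal numerical illustration of this scaling (a self-contained Monte-Carlo sketch with a toy activation $\phi(z) = z + z^2/2$ and hypothetical link choices, not the exact setting of Algorithm 1):

```python
import numpy as np

rng = np.random.default_rng(0)

def signal_at_init(link, d, n=2_000_000):
    """Monte-Carlo estimate (up to sign) of the population gradient of one neuron
    (correlation loss, activation phi(z) = z + z^2/2), projected onto the target
    direction theta, at a typical random initialization."""
    alpha = 1.0 / np.sqrt(d)          # typical overlap <w, theta> of a random unit vector
    u = rng.standard_normal(n)        # <x, theta>
    v = alpha * u + np.sqrt(1.0 - alpha**2) * rng.standard_normal(n)  # <x, w>
    y = link(u)
    # E[ y * phi'(<x, w>) * <x, theta> ]  with  phi'(z) = 1 + z
    return np.mean(y * (1.0 + v) * u)

odd_link = lambda z: z**3         # nonzero He_1 coefficient: information exponent 1
even_link = lambda z: z**2 - 1    # He_2: information exponent 2

for d in (100, 400, 1600):
    print(d, round(signal_at_init(odd_link, d), 3), round(signal_at_init(even_link, d), 3))
# odd link: O(1) signal at initialization; even link: signal decays like d^{-1/2}
```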
"what could be an optimal mini-batch re-use schedule?"
We intuitively expect that any mini-batch size within a fairly wide range would achieve a similar sample complexity. For simplicity, we employed a choice that excludes correlations between different samples.
Investigating the benefit of more intricate mini-batch schedules is an interesting direction for future work.
We would be happy to clarify any further concerns/questions in the discussion period.
I thank the authors for their response, which addressed all my concerns. I will keep my score.
This paper studied the problem of learning single-index models under the isotropic Gaussian distribution. The target model is a polynomial link function composed with a one-dimensional projection $\langle x, \theta \rangle$, where the polynomial is of degree at most $q$ and has information exponent (i.e., the degree of the first non-zero coefficient in its Hermite expansion) $k$. The paper studied the sample and computational complexity of learning such single-index models with 2-layer neural networks, using gradient-descent-type methods. The critical aspect of the algorithm is reusing the same batch of samples every two iterations, which frees the algorithm from CSQ constraints and turns it into an SQ algorithm. A critical observation is that by reusing samples at each iteration, the algorithm induces a monomial transformation of the labels, which effectively reduces the information exponent from $k$ to at most 2. Hence, the algorithm can achieve near-optimal sample complexity $\tilde{O}(d)$.
Strengths
- This paper is clearly written with intuitions and useful explanations.
- This paper provides new perspectives on designing SQ algorithms to learn single-index models using neural networks. Though the idea of reusing samples has already appeared in prior works ([DTA+24]), this work shows that reusing samples can achieve strong recovery of the hidden direction and provides a well-rounded analysis of the sample and computational complexity. Importantly, the authors showed that by reusing the mini-batches, one can learn the target model with $\tilde{O}(d)$ samples, which is near the information-theoretic limit.
- This paper provides a very interesting intuition on reducing the information exponent of the link function using monomial transformations, which could be of independent interest for future works.
Weaknesses
- Though the authors claimed that they were using neural networks to learn the single-index model, the activation of each neuron turns out to be a combination of polynomials. Hence, the neural network is essentially a linear combination of Hermite polynomials. In this case, I am wondering what the differences are between using the 'neural network' to learn single-index models and using polynomials to learn single-index models, which is already done in [CM20]. Of course, [CM20] requires a warm-start procedure, which is not a gradient-descent-type algorithm, but I think it would be more interesting if the analysis were carried out on conventional neural networks like ReLU networks.
- The authors hide many constants in the big-O notation. However, I am skeptical that all those parameters are independent of the dimension $d$. For example, in the proof of Proposition 4, an upper bound is derived on a quantity whose only lower bound is that it is non-zero. Therefore, I am wondering whether this quantity could in fact be vanishingly small in the dimension $d$. I think the paper would be more theoretically sound if the authors could explicitly present the dependence on all such parameters in the final bounds on the sample complexity and iteration complexity.
Questions
- Do ReLUs and sigmoids satisfy Assumption 2?
- I think there are typos on line 593 to 595. What is 'x' on the right-hand side of line 593 and 595?
- Since only a small fraction of neurons satisfy Assumptions 2 and 3, does it imply that the width of the network must be at least exponential in the target degree?
- Since Theorem 2 relies on neurons that satisfy Assumptions 2 and 3, does it imply that having only a small fraction of good neurons (i.e., neurons achieving strong recovery) is enough to achieve small error? What is the intuition behind this?
Limitations
The authors addressed the limitations of the paper and provided inspiring future directions.
We thank the reviewer for the thoughtful comment and constructive feedback. We address the technical concerns below.
Polynomial activation function
We make the following clarifications.
- Note that all square-integrable activation functions (ReLU, sigmoid, etc.) can be written as a linear combination of Hermite polynomials. The key differences between our algorithm and [CM20] are as follows:
- [CM20] employed a label transformation (based on thresholding) prior to running SGD, whereas we use the squared loss without preprocessing and extract the transformation from reusing training data.
- [CM20] considered optimization jointly over the low-dimensional subspace (finding index features) and the coefficients of the polynomials. In contrast, in our setting the coefficients of the polynomial activation are fixed, and we optimize the parameters of the neural network.
- We restrict ourselves to polynomial activation because it is easy to construct coefficients that satisfy Assumption 3 (required for strong recovery). For strong recovery, ReLU or sigmoid may not suffice, and the use of well-specified or polynomial activations is common in the literature, e.g., [AGJ21][AAM22][AAM23][DNG+23].
As for weak recovery, Assumption 2 is satisfied with probability 1 by a shifted ReLU/sigmoid (e.g., see Lemma 15 in [BES+23]). Therefore, we can establish weak recovery using standard choices of activation. We will comment on this in the revised manuscript.
Dimension dependence in constants
All constants in our theorems are dimension-free. Specifically, the lower bound on the quantity in question does not depend on the dimensionality, since Proposition 4 concerns univariate functions. We notice that there is a typo in Appendix A: the expectation in A.1 should be taken with respect to the one-dimensional standard Gaussian, since the function involved is a scalar function. We apologize for the confusion this may cause.
"does it imply that the width of the network is at least ?"
Yes, the required student width is exponential in the target degree $q$, which is treated as a constant in the big-$O$ notation.
Although only a small fraction of neurons can achieve strong recovery, the entire neural network can achieve small error because these ``good'' neurons can be singled out in the second-layer training, as shown in Proposition 3.
Note that a similar dependence on $q$ is also present in tailored algorithms for learning low-dimensional polynomials [CM20].
We would be happy to clarify any further concerns/questions in the discussion period.
Please engage with the authors' response: have they addressed your concerns adequately, and how has your score been affected?
Best, your AC
I thank the authors for their detailed response. I would like to keep my score unchanged.
This manuscript studies the learning properties of two-layer networks trained with SGD reusing the batch. The authors show that this simple modification allows SGD to surpass the limits of CSQ algorithms and learn single-index functions efficiently. The submission considers both recovery of the target features and generalization properties of two-layer networks. The claims are supported by rigorously proven theorems.
Strengths
The strength of this submission resides in the strong theoretical claims. The questions addressed are of great interest to the theoretical machine learning community.
Weaknesses
This submission has no outstanding weaknesses, but the presentation could be enhanced. I detail in the section below some suggestions to improve the manuscript, which is sometimes a bit obscure for a non-expert reader.
Questions
- The idea of label transformation implemented by [CM20] could be reported. The contrast with what SGD (with batch reusing) is implementing would give an idea to the reader of the strength of the claims.
- Naively, one might think that sample complexity guarantees might come easily from weak recovery plus [BAGJ21]. Maybe the authors could comment more on the technicalities that arise.
- What is the role of the hyperparameters, e.g., the interpolation parameter, the learning rate, and all the quantities appearing in the non-CSQ transformation? Are they randomly drawn (and do the theorems hold with high probability)?
- The authors mention that interpolation is required and correctly dedicate a paragraph to it. However, it is not clear to me whether they believe it is only necessary for the technicalities of the proof or whether they believe it is needed in general.
- The authors do an amazing job in introducing the CSQ/non-CSQ parallel when reusing the batch in Section 2.2. However, to further enhance this subsection, I think it would be great to state the details of the SQ class more clearly. Although SQ is formally defined, the authors could refer in the submission to the lower bounds achievable by SQ. More precisely, it was reported that $n \asymp d$ samples are both sufficient and necessary for learning, by citing works using AMP-type algorithms. Do these algorithms belong to SQ?
- Closely related to the above question: if AMP-type algorithms belong to SQ, could the authors comment more on the link with [Gla23], non-adversarial noise, and the possibility of going beyond SQ? Of course, the information-theoretic barrier cannot be broken, but I think more clarity on these points would be more than welcome. I think the insights presented are crucial and interesting and deserve more space in the main body.
Limitations
The limitations are addressed in the manuscript.
We thank the reviewer for the thoughtful comment and constructive feedback. We address the technical concerns below.
"The idea of label transformation implemented by [CM20] could be reported"
In the current manuscript, the difference between our label transformation and that in [CM20] is discussed in the paragraphs starting at lines 157 and 250. Specifically, prior SQ algorithms are typically based on thresholding, whereas we use monomial transformations, which can be easily extracted from the SGD update.
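As a concrete (toy) illustration of why a monomial transformation helps -- a Monte-Carlo sketch with the hypothetical link $\mathrm{He}_4$, not an excerpt from the manuscript -- squaring the label moves its first non-zero Hermite coefficient from degree 4 down to degree 2:

```python
import math
import numpy as np

rng = np.random.default_rng(1)
u = rng.standard_normal(2_000_000)          # u = <x, theta>, one-dimensional projection

# probabilists' Hermite polynomials He_1, ..., He_4 evaluated at u
He = {1: u, 2: u**2 - 1, 3: u**3 - 3 * u, 4: u**4 - 6 * u**2 + 3}

def hermite_coeffs(labels):
    # c_k = E[label * He_k(u)] / k!
    return [round(float(np.mean(labels * He[k])) / math.factorial(k), 2) for k in (1, 2, 3, 4)]

y = He[4]                                    # link = He_4: information exponent 4
print("Hermite coefficients of y    :", hermite_coeffs(y))       # ~ [0, 0, 0, 1]
print("Hermite coefficients of y**2 :", hermite_coeffs(y**2))    # non-zero already at degree 2
```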
"one might think that sample complexity guarantees might come easily from weak recovery plus [BAGJ21]"
The difficulty is discussed in the paragraph starting at line 162. In particular, since the link function is unknown, we cannot directly employ the argument in [BAGJ21], which assumed a well-specified model, i.e., a student activation that matches the link function. Instead, we make use of the randomized Hermite coefficients of the student activation function to translate weak recovery into strong recovery.
"What is the role of the hyperparameters ... and all the quantities appearing in the non-CSQ transformation?"
Most hyperparameters in Algorithm 1 are deterministic, with the only exceptions being the sign of the student activation function (see Lemma 2) and the momentum parameter, which are randomized over student neurons.
This aims to guarantee that, for any target function, there exist some student neurons achieving strong recovery (Theorem 2).
We realize that there is a typo in Theorem 1: "with probability" should be replaced by "with high probability". We apologize for the confusion this may cause.
We will clarify this in the revision.
The role and necessity of interpolation
We use an interpolation step to improve the signal-to-noise ratio in the gradient; this is crucial in the generative exponent 2 setting where the signal and noise are almost at the same magnitude (see Section 3.2 and paragraph starting line 307 for details). It is possible that other modifications or hyperparameter choices of the learning algorithm also achieve a similar effect, but we do not pursue these in the current submission.
Below we recap the intuition provided in Section 3.2 on the failure of the standard online SGD analysis. When we analyze the training dynamics of learning a single-index model, we characterize the progress via the projection of the neuron $w_t$ onto the target direction $\theta$, which we refer to as the alignment $\alpha_t := \langle w_t, \theta \rangle$ (e.g., [DNG+23]). We want to show that $\alpha_t$ increases from its initial value of order $d^{-1/2}$. Given the gradient $g_t$ (assumed to be orthogonal to $w_t$ for simplicity) and step size $\eta$, the update of the alignment after normalization is roughly

$$\alpha_{t+1} \;\approx\; \alpha_t \;+\; \underbrace{\eta\,\langle g_t, \theta\rangle}_{\text{signal}} \;-\; \underbrace{\tfrac{\eta^2}{2}\,\|g_t\|^2\,\alpha_t}_{\text{noise/normalization}}.$$

One sees that for the expectation of $\alpha_{t+1}$ to be larger than $\alpha_t$, the signal term should dominate the noise term. In the standard CSQ setting this can be achieved by simply taking $\eta$ sufficiently small, because the signal term depends linearly on $\eta$ while the noise term depends on it quadratically.
However, in our case the signal comes from the non-CSQ term that arises from reusing the batch, which is itself of order $\eta^2$. Therefore, decreasing $\eta$ does not improve the SNR. The interpolation step provides a remedy by preventing the parameters from changing too fast and reducing the projection error.
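In shorthand (informal scaling, using the notation introduced above rather than the exact expressions from the paper):

$$\mathrm{SNR}_{\text{fresh-batch (CSQ) signal}} \;\asymp\; \frac{\eta\cdot\text{signal}}{\eta^2\cdot\text{noise}} \;\propto\; \frac{1}{\eta}, \qquad \mathrm{SNR}_{\text{reused-batch (non-CSQ) signal}} \;\asymp\; \frac{\eta^2\cdot\text{signal}}{\eta^2\cdot\text{noise}} \;=\; \Theta(1)\ \text{in}\ \eta,$$

so shrinking the step size no longer helps, and a different mechanism (here, interpolation) is needed to suppress the noise.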
"it would be great to state more clearly details over the SQ class"
Thank you for the valuable suggestion. We will discuss the SQ complexity and the generative exponent of [DPVLB24] in more detail. Indeed, our analysis is based on the observation that SGD with a reused batch can implement SQ, and hence $\tilde{O}(d)$ samples are sufficient to learn target functions with generative exponent at most 2; this sample complexity is also achieved by AMP algorithms (typically assuming knowledge of the link function to construct the optimal preprocessing), and is consistent with the SQ lower bound.
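For completeness, the generative exponent of [DPVLB24] can be stated informally as the smallest information exponent achievable by transforming the labels (our paraphrase, not a verbatim quote):

$$\bar{k}(\sigma_*) \;=\; \min_{T \in L^2} \; k^\star\big(T \circ \sigma_*\big),$$

so that $\bar{k} \le 2$ (which holds for polynomial links) corresponds to an SQ sample complexity of $\tilde{O}(d)$, matching the guarantee above.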
"could the author comment more on the link with [Gla23] and non-adversarial noise and the possibility of going beyond SQ?"
Thank you for the close reading. We realize that this remark on future directions is misleading. We initially thought that for the $k$-parity problem, the SQ lower bound suggests that more samples are required for rotationally invariant algorithms with polynomial compute than in the Gaussian case. However, this is not the case, and the total computation in [Glasgow 23] only matches that suggested by the SQ lower bound. We will include appropriate references on the gap between statistical and adversarial noise, such as [Dudeja and Hsu 20] and [Abbe and Sandon 20].
We would be happy to clarify any further concerns/questions in the discussion period.
I thank the authors for their rebuttal that clarified my concerns. After carefully reading it along with other reviewers’ comments I would like to keep my score as in the original review.
Thank you for the update.
We would appreciate knowing if there are any outstanding concerns that may have led to the reviewer's decision to maintain the current score.
We would be more than happy to provide further clarifications during the discussion period.
Dear reviewers,
Thank you for taking the time to review the paper. Notice that the discussion period has begun, and will last until August 13 (4 more days).
During this time your active participation and engagement with the authors is very important, and highly appreciated. Specifically, please read the responses, respond to them early on in the discussion, and discuss points of disagreement.
Thank you for your continued contributions to NeurIPS 2024.
Best, Your Area Chair.
Dear Reviewers and Area Chair,
We appreciate your continued time and effort in providing detailed feedback on our paper. As the rebuttal period comes to a close, we would like to briefly summarize our responses and revisions, which we believe have addressed all of the reviewers' concerns.
- Reviewers hk3C and 6kdk inquired about the role of the momentum step. We explained that its purpose is to improve the signal-to-noise ratio, which is crucial in the generative exponent 2 setting, where the signal and noise are nearly at the same scale.
- Reviewers Mv1X and hk3C raised questions on the constants in our big-$O$ notation. We clarified the dependence on the target degree $q$, and addressed Reviewer Mv1X's misunderstanding regarding the dimension dependence of the constants.
- We discussed the motivation behind the layer-wise training procedure, the possibility of simultaneous training, and the use of different loss functions, in response to Reviewer hk3C's concern about standard SGD training.
- Reviewers Mv1X and p2Vt asked about the differences between our analysis and that of [CM20]. We highlighted that our work focuses on SGD training of a two-layer neural network, without label preprocessing (with thresholding) or optimization of the polynomial coefficients.
To recap the main findings of our submission, we showed that with simple data reuse, a two-layer neural network can learn arbitrary polynomial single-index models at a near information-theoretically optimal sample complexity (beyond that suggested by the CSQ lower bound). We built on the insight that SGD with reused data can implement non-correlational queries (despite the squared loss), and established end-to-end learning guarantees that go beyond weak recovery.
We believe that our results will be of interest to the NeurIPS community.
Best regards,
Authors of submission 12755
This paper studies the problem of learning single-index models under the isotropic Gaussian distribution. It shows that reusing batches with SGD-based training on a 2-layer neural network achieves a sample complexity that matches the information-theoretic limit up to polylogarithmic factors, which is a significant improvement over existing theoretical results.
All reviewers appreciate the novelty and significant contributions of this paper. Therefore, I recommend accepting it.