PaperHub
5.8/10
Poster · 4 reviewers
Min 4 · Max 7 · Std. dev. 1.1
Ratings: 6, 6, 7, 4
Confidence: 3.5
Soundness: 3.0
Contribution: 2.5
Presentation: 3.0
NeurIPS 2024

The Implicit Bias of Gradient Descent on Separable Multiclass Data

Submitted: 2024-05-16 · Updated: 2024-11-06
TL;DR

We prove implicit bias of gradient descent for linearly separable multiclass problems.

Abstract

Keywords
gradient descent, multiclass classification, hard-margin SVM, implicit bias

Reviews and Discussion

Official Review
Rating: 6

The authors study the problem of multiclass SVM in the realizable (i.e., separable) case. They show that, under several assumptions on the loss function, gradient descent converges in direction toward the (unique) solution of the hard-margin problem. To do so, they use the notion of Permutation Equivariant and Relative Margin-based (PERM) losses (introduced in prior work) to extend the exponential tail property from binary classification to the multiclass setting. They also show that three loss functions, including cross-entropy, satisfy the assumptions of the theorem.

Strengths

This is a strong technical paper (even though I would like to ask several questions to the authors to be convinced of the soundness of the proof; see "Questions" section).

Weaknesses

The main weakness of the paper, in my opinion, is that several technical steps of the proof are hard to follow due to a lack of details and explanations. This is a general comment, but I list several specific questions in the "Questions" section about steps that I did not understand and that made me question the overall soundness of the proof. I would be happy to receive answers from the authors to these questions.

Other comments:

  • Line 193: Lemma 12 in the other paper does not directly give the result; one needs to go to Eq. 18 to complete the argument.
  • Eq. 20: $\eta$ is missing somewhere.

Questions

  • Equation after line 200, first line: the equality seems wrong to me, because I would expect a factor $\log(1+t^{-1})$ multiplying $\lVert\hat{w}\rVert$, but it is not there. Furthermore, I do not understand how the inequality follows simply from showing that the third term is negative, since there is an extra $t^{-2}$ factor multiplying $\lVert\hat{w}\rVert$.
  • Line 262: Why is $u_{\pm}=0$, while in Def. 2.2 it can be anything?
  • Lines 609 to 611: I understand from the argument that at least one component of $u$ must converge to $\infty$, because indeed $\nabla\phi(u)\prec 0$ for all finite $u$, but why do all the components of $u$ have to converge to $\infty$?

Limitations

The authors correctly acknowledge the limitations of their work. However, a short discussion somewhere in the paper of how restrictive the assumptions on the loss function are would be welcome.

Author Rebuttal

Line 262: Why is $u_{\pm} = 0$, while in Def. 2.2 it can be anything?

We could have worded this better; thank you for the question! All the relative margins (which comprise $\mathbf{u}$, the vector argument to the template) go to infinity, which lets us pick any finite value for $u_{\pm}$. The value 0 in particular works for us because the 3 losses we analyze (cross-entropy, exponential loss, PairLogLoss) satisfy the exponential tail property with this parameter setting, as proven in Appendix C. We will reword this in the paper to avoid any confusion for readers.

Lines 609 to 611: I understand from the argument that...

Thank you for the careful reading! We note that one can easily verify this property for the 3 losses analyzed in our paper. Nevertheless, we are able to fill this gap for general losses using a new structural result on convex, symmetric, differentiable functions that we proved. The caveat is that our results now require an additional assumption that the loss template be convex. It is easy to see that all 3 losses analyzed are convex by computing the second derivatives and verifying they are non-negative. We note that this condition is not required in the binary case in Soudry et al. 2018. This is an interesting difference between the binary and multiclass cases, which we hope to address in future work.

Our new structural result is stated below. First we need an additional piece of notation: given a vector $\mathbf{x} \in \mathbb{R}^{n}$ and a real number $C \in \mathbb{R}$, define $\mathbf{x} \vee C \in \mathbb{R}^{n}$ to be the vector whose $i$-th component is $\max(x_{i}, C)$, for all $i \in [n]$.

Theorem. Suppose that $f : \mathbb{R}^{n} \to \mathbb{R}$ is a symmetric, convex, and differentiable function. Then for any real number $C \in \mathbb{R}$ and any $\mathbf{x} \in \mathbb{R}^{n}$, we have $\tfrac{\partial f}{\partial x_{i}}(\mathbf{x}) \le \tfrac{\partial f}{\partial x_{i}}(\mathbf{x} \vee C)$ for any $i \in \operatorname{argmin}(\mathbf{x})$.
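The theorem can be sanity-checked numerically (this is not a proof). The sketch below assumes the cross-entropy template is $\psi(\mathbf{u}) = \log(1 + \sum_j e^{-u_j})$ in terms of the relative margins, which is symmetric, convex, and differentiable. The helper names `vee`, `grad_ce_template`, and `theorem_holds` are hypothetical and do not come from the paper.

```python
import math
import random


def vee(x, c):
    """Componentwise max(x_i, C), i.e. the vector x ∨ C."""
    return [max(xi, c) for xi in x]


def grad_ce_template(u):
    """Gradient of psi(u) = log(1 + sum_j exp(-u_j)) (assumed cross-entropy template)."""
    exps = [math.exp(-uj) for uj in u]
    denom = 1.0 + sum(exps)
    return [-e / denom for e in exps]


def theorem_holds(x, c, tol=1e-12):
    """Check d psi/d x_i (x) <= d psi/d x_i (x ∨ C) for every i in argmin(x)."""
    gx = grad_ce_template(x)
    gv = grad_ce_template(vee(x, c))
    m = min(x)
    return all(gx[i] <= gv[i] + tol for i, xi in enumerate(x) if xi == m)


# Random vectors and thresholds C; every trial should satisfy the inequality.
random.seed(0)
trials = [([random.uniform(-3, 3) for _ in range(4)], random.uniform(-3, 3))
          for _ in range(1000)]
print(all(theorem_holds(x, c) for x, c in trials))  # True
```

When $\min(\mathbf{x}) \ge C$ the vector is unchanged and the two sides coincide, so the check is only informative on the trials where clipping actually happens.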

Using this theorem, we show that if not all components of $\mathbf{u}$ go to infinity, then the exponential tail property yields a contradiction to $\nabla \psi(\mathbf{u}) \to \mathbf{0}$. We sketch the details below.

  1. Suppose that $\mathbf{u}^t$ is a sequence such that $\nabla \psi(\mathbf{u}^t) \to \mathbf{0}$, but there exists a component $j$ such that $u^t_j$ does not go to infinity.
  2. There must be a finite number $M$ such that $u^t_j \le M$ for infinitely many $t$. We pass to a subsequence so that $u^t_j \le M$ for all $t$.
  3. Define $C := \max\left(2|u_{-}|,\ M,\ -\log\left(\tfrac{1}{2(K-1)}\right)\right)$.
  4. Define $\mathbf{v}^t := \mathbf{u}^t \vee C$. Then $\min(\mathbf{v}^t) \geq C$. Using the lower exponential tail bound combined with $C \geq -\log(\tfrac{1}{2(K-1)})$, this implies that $-\tfrac{\partial \psi(\mathbf{v}^t)}{\partial v_i} \geq \tfrac{1}{2}c\exp(-a v_i)$.
  5. Applying the theorem for all $i \in \operatorname{argmin}(\mathbf{u}^t)$, we get $\tfrac{\partial \psi(\mathbf{u}^t)}{\partial u_i} \leq \tfrac{\partial \psi(\mathbf{v}^t)}{\partial v_i} \leq -\tfrac{1}{2}c\exp(-a v_i) < 0$. Since $u^t_i \le u^t_j \le M \le C$, we have $v^t_i = C$, so this upper bound equals $-\tfrac{1}{2}c\exp(-aC)$ and does not depend on $t$. This contradicts what we proved on line 608, i.e., $\nabla \psi(\mathbf{u}^t) \to \mathbf{0}$.
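Steps 3 to 5 can be traced for a single vector with a short sketch. It again assumes the cross-entropy template $\psi(\mathbf{u}) = \log(1 + \sum_j e^{-u_j})$, and it assumes tail constants $c = a = 1$ and $u_{-} = 0$ for that template; these are illustrative choices, not values taken from the paper. The function `contradiction_bound` is hypothetical.

```python
import math


def grad_ce_template(u):
    """Gradient of the assumed cross-entropy template psi(u) = log(1 + sum_j exp(-u_j))."""
    exps = [math.exp(-uj) for uj in u]
    denom = 1.0 + sum(exps)
    return [-e / denom for e in exps]


def contradiction_bound(u, M, K, u_minus=0.0, c=1.0, a=1.0):
    """Steps 3-5 of the sketch for one vector u of K-1 relative margins.

    Returns (C, bound, tail_ok): bound = -c/2 * exp(-a*C) should upper-bound
    d psi/d u_i for every i in argmin(u) when min(u) <= M, and tail_ok reports
    whether the step-4 lower tail bound holds at v = u ∨ C.
    """
    C = max(2 * abs(u_minus), M, math.log(2 * (K - 1)))  # log(2(K-1)) = -log(1/(2(K-1)))
    v = [max(ui, C) for ui in u]                          # step 4: v = u ∨ C
    gv = grad_ce_template(v)
    tail_ok = all(-g >= 0.5 * c * math.exp(-a * vi) for g, vi in zip(gv, v))
    return C, -0.5 * c * math.exp(-a * C), tail_ok


u = [0.5, 10.0, 7.0]  # K = 4 classes -> K - 1 = 3 relative margins
C, bound, tail_ok = contradiction_bound(u, M=1.0, K=4)
g = grad_ce_template(u)
i = u.index(min(u))
print(tail_ok and g[i] <= bound < 0)  # True: the gradient stays bounded away from 0
```

For this template, $\min(\mathbf{v}) \ge C \ge \log(2(K-1))$ gives $\sum_j e^{-v_j} \le \tfrac12$, so $-\partial\psi/\partial v_i \ge \tfrac23 e^{-v_i}$, which is why the step-4 check succeeds with $c = 1$.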

We are happy to follow up with more details of this sketch if you are interested.

Comment

Could the reviewer kindly let us know if there are any remaining concerns we should address? We would love an opportunity to address any further questions and comments before the author discussion period expires.

Official Review
Rating: 6

This paper leverages the PERM (Permutation Equivariant and Relative Margin-based losses) framework proposed in [Wang and Scott, 2024], and extends the implicit bias result of binary classification to multiclass classification.

Specifically, the authors extend the exponential tail property to the multiclass setting and prove that, for almost all linearly separable datasets with a suitable PERM loss and a sufficiently small step size, the gradient descent iterates converge in direction toward the hard-margin multiclass SVM solution.

Strengths

This paper is strongly motivated, well structured and pleasant to read. It studies the problem of implicit bias for multiclass classification with losses not limited to cross-entropy, bridging the binary-multiclass gap.

The PERM framework provides an elegant tool for analyzing multiclass classification problems, and in my opinion, this paper is technically solid. I have not gone through the proofs of lemmas in the appendix, but the analysis in the main paper looks correct to me.

Weaknesses

As the authors discuss in the Limitations section, it would also be interesting to go beyond the asymptotic setting and the exponential tail property.

Some typos:

  • Section 4.3 shows that $\mathcal{R}$ is $\beta\sigma_{\max}^2(X)$-smooth, so the learning rate should satisfy $\eta<\frac{2}{\beta\sigma_{\max}^2(X)}$, but the upper bound is stated as $2\beta^{-1}\sigma_{\max}^2(X)$ in Theorem 3.4.
  • On page 3, line 85, the meaning of $[\mathbf{v}]\sigma(j)$ is unclear to me; is it $[\mathbf{v}]_{\sigma(j)}$?
  • Line 116: by direct calculation, $\boldsymbol{\Upsilon}_y\mathbf{D}\mathbf{v}$ is not $(v_y-v_1,v_y-v_2,\dots,v_y-v_K)^\top$ but a permutation of it. Although this difference should not affect the proofs under the PERM loss assumptions, it could still be a bit misleading.
  • Line 200: on the RHS of the equation, $\|\hat{\mathbf{w}}\|$ should be $\|\hat{\mathbf{w}}\|^2t^{-2}$.
  • Line 219: $\hat{\mathbf{w}}^\top\mathcal{R}(\mathbf{w})$ should be $\hat{\mathbf{w}}^\top\nabla\mathcal{R}(\mathbf{w})$.

Questions

  • As the authors mention in Section 5, [Shamir, 2021] proved that for gradient-based methods in binary classification, both the empirical risk and the generalization error decrease at an optimal rate throughout the entire training process. I am curious whether a similar phenomenon has been observed in numerical experiments in the multiclass setting.

  • It seems to me that the learning-rate condition $\eta<\frac{2}{\beta\sigma_{\max}^2(X)}$ could be restrictive for a dataset $X$ with a large spectral norm. I am curious whether it is mainly a necessity of the proof, or whether numerical experiments also require this condition to achieve directional convergence.
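The step-size condition depends on the data only through $\sigma_{\max}(X)$. A minimal sketch of computing the bound $2/(\beta\sigma_{\max}^2(X))$, assuming $X$ is given as a list of rows and the smoothness constant $\beta$ of the loss template is already known. Power iteration on $X^\top X$ is a standard substitute for a full SVD, not something taken from the paper, and `spectral_norm_sq` and `max_step_size` are hypothetical names.

```python
import math


def spectral_norm_sq(X, iters=200):
    """Largest eigenvalue of X^T X (= sigma_max(X)^2) via power iteration.

    Assumes the all-ones start vector is not orthogonal to the top
    right singular vector of X.
    """
    d = len(X[0])
    v = [1.0] * d
    lam = 0.0
    for _ in range(iters):
        Xv = [sum(row[j] * v[j] for j in range(d)) for row in X]              # X v
        w = [sum(X[i][j] * Xv[i] for i in range(len(X))) for j in range(d)]  # X^T X v
        norm = math.sqrt(sum(wj * wj for wj in w))
        if norm == 0.0:
            return 0.0
        v = [wj / norm for wj in w]
        lam = norm  # ||X^T X v|| for unit v converges to the top eigenvalue
    return lam


def max_step_size(X, beta):
    """Upper bound 2 / (beta * sigma_max(X)^2) on the step size eta."""
    return 2.0 / (beta * spectral_norm_sq(X))


X = [[3.0, 0.0], [0.0, 1.0]]
print(max_step_size(X, beta=1.0))  # ≈ 0.2222 (= 2/9)
```

Scaling $X$ by a factor $s$ shrinks the admissible step size by $s^2$, which is the dependence on the spectral norm the reviewer points to.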

Limitations

N/A

Author Rebuttal

Thank you for the careful reading and catching these typos! We will update our paper to reflect these changes. We also clarify line 116.

In page 3 line 85, the meaning of $[\mathbf{v}]\sigma(j)$ doesn't seem clear to me, is it $[\mathbf{v}]_{\sigma(j)}$?

Yes. This was a typo; we will remove $[\mathbf{v}]\sigma(j)$ from the paper.

Comment

I thank the authors for the clarification; my rating remains unchanged.

Official Review
Rating: 7

This paper investigates the implicit bias of gradient descent on separable multiclass data using a broad class of losses termed Permutation Equivariant and Relative Margin-based (PERM) losses, which include cross-entropy loss, multiclass exponential loss, and PairLogLoss. The main contribution is the extension of the concept of the exponential tail property, commonly used in the analysis of implicit bias in binary classification, to multiclass classification with PERM losses. The proof techniques (and main results) are analogous to those used in the binary case, suggesting that PERM losses can bridge the theoretical gap between binary and multiclass classification.
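In the PERM view, each loss depends on the class scores $s$ only through the relative margins $u_j = s_y - s_j$ ($j \ne y$), via a template. A minimal sketch, assuming the multiclass exponential loss and PairLogLoss templates take the forms $\sum_j e^{-u_j}$ and $\sum_j \log(1 + e^{-u_j})$ (these forms are not stated on this page); the cross-entropy template is checked against the usual softmax formula. All function names are hypothetical.

```python
import math


def relative_margins(scores, y):
    """u_j = s_y - s_j for every class j != y (y is a 0-based label here)."""
    return [scores[y] - s for j, s in enumerate(scores) if j != y]


def ce_template(u):
    """Cross-entropy written as a function of the relative margins."""
    return math.log1p(sum(math.exp(-uj) for uj in u))


def exp_template(u):
    """Assumed multiclass exponential loss template: sum_j exp(-u_j)."""
    return sum(math.exp(-uj) for uj in u)


def pairlog_template(u):
    """Assumed PairLogLoss template: sum_j log(1 + exp(-u_j))."""
    return sum(math.log1p(math.exp(-uj)) for uj in u)


def cross_entropy(scores, y):
    """Standard softmax cross-entropy, -log softmax(scores)[y], for comparison."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return lse - scores[y]


scores, y = [2.0, 0.5, -1.0], 0
u = relative_margins(scores, y)
print(abs(ce_template(u) - cross_entropy(scores, y)) < 1e-12)  # True
```

Each template is a symmetric function of $\mathbf{u}$, which is what makes the losses permutation equivariant in the labels other than $y$.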

Strengths

  1. The paper is very well-written and well-organized, making it easy to follow.
  2. While the implicit bias of gradient descent has been widely studied for binary classification, its extension to multiclass classification is relatively sparse and mainly restricted to cross-entropy loss. This paper addresses a broad class of PERM losses, extending the concept of the exponential tail through the "template" of the loss. The result is both interesting and significant.
  3. The theoretical results are convincing. While the proof largely follows Soudry et al. [2018], the analysis is nontrivial. Additionally, the proof idea is laid out in a clear manner.

Weaknesses

  1. The main weakness is the lack of numerical results to verify the theoretical claims. While this is understandable given that the convergence is exponentially slow, some simulations could be presented to illustrate the gap between theory and practice (where networks are not trained to achieve zero loss).

Questions

Maybe a minor typo: in Assumption 3.3, line 171, $k$ should not be there.

Limitations

Yes.

Author Rebuttal

Thank you for your review and for catching the typo. We appreciate your time and supportive feedback. We have added numerical simulations demonstrating implicit regularization toward the hard-margin SVM when using the PairLogLoss, in line with our theory's prediction. They are attached to the global rebuttal.
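A simulation of this kind can be sketched in a few lines. This is not the authors' experiment: it uses cross-entropy rather than PairLogLoss, a made-up separable 3-class dataset in $\mathbb{R}^2$, and a small step size $\eta = 0.5$. It reports the training loss and the normalized margin $\min_i \min_{k \ne y_i} (\mathbf{w}_{y_i} - \mathbf{w}_k)^\top \mathbf{x}_i / \|\mathbf{W}\|_F$; assuming the hard-margin SVM is posed with the Frobenius norm, directional convergence to it means this quantity should slowly approach the SVM margin. Solving the SVM itself is omitted.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]


def scores(W, x):
    return [sum(w[j] * x[j] for j in range(len(x))) for w in W]


def train_gd(X, Y, K, eta=0.5, steps=2000):
    """Full-batch gradient descent on mean softmax cross-entropy, starting from W = 0."""
    d, n = len(X[0]), len(X)
    W = [[0.0] * d for _ in range(K)]  # one weight row per class
    for _ in range(steps):
        G = [[0.0] * d for _ in range(K)]
        for x, y in zip(X, Y):
            p = softmax(scores(W, x))
            for k in range(K):
                coef = (p[k] - (1.0 if k == y else 0.0)) / n
                for j in range(d):
                    G[k][j] += coef * x[j]
        W = [[W[k][j] - eta * G[k][j] for j in range(d)] for k in range(K)]
    return W


def mean_loss(W, X, Y):
    """Mean softmax cross-entropy over the dataset."""
    total = 0.0
    for x, y in zip(X, Y):
        s = scores(W, x)
        m = max(s)
        total += m + math.log(sum(math.exp(si - m) for si in s)) - s[y]
    return total / len(X)


def normalized_margin(W, X, Y):
    """min_i min_{k != y_i} (w_{y_i} - w_k)^T x_i / ||W||_F."""
    fro = math.sqrt(sum(v * v for row in W for v in row))
    margins = []
    for x, y in zip(X, Y):
        s = scores(W, x)
        margins.extend(s[y] - s[k] for k in range(len(W)) if k != y)
    return min(margins) / fro


# Hypothetical separable data: three clusters roughly 120 degrees apart.
X = [[1.0, 0.0], [0.9, 0.2], [-0.5, 0.87], [-0.6, 0.7], [-0.5, -0.87], [-0.4, -0.8]]
Y = [0, 0, 1, 1, 2, 2]
for steps in (10, 100, 1000):
    W = train_gd(X, Y, K=3, steps=steps)
    print(steps, round(mean_loss(W, X, Y), 4), round(normalized_margin(W, X, Y), 4))
```

Because directional convergence is only logarithmic in $t$, the normalized margin changes very slowly after the loss is already small, which is consistent with the reviewer's remark about slow convergence.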

Comment

I thank the authors for the response. The additional experiments are convincing. My rating remains.

Official Review
Rating: 4

This paper uses the framework of permutation equivariant and relative margin-based losses of (Wang and Scott, '24) to extend the implicit bias result of (Soudry et al.,'18) to multinomial classification. Namely, the authors prove that when the loss satisfies a multiclass generalisation of the exponential tail property and the dataset is linearly separable, gradient descent with a sufficiently small learning rate converges to the multiclass max-margin solution.

Strengths

The paper's goal is clearly stated and framed within the relevant literature. The definitions and statements of the main results are easy to follow and, except for a few inaccuracies, so is the proof sketch.

Weaknesses

From a conceptual point of view, the results of the present paper do not provide additional insight into the implicit bias of multiclass classification compared with previous works limited to the cross-entropy loss.

Also at the technical level, as the authors themselves state, using the PERM framework allows for a simple generalisation of the results of (Soudry et al., '18) to the multiclass case, to the point that a large fraction of the proof is identical or almost identical.

Questions

  • Definition of $\tilde{x}$ in Eq. (9)?
  • There must be some typo in the first equation of Section 4.1.
  • The definition of coordinate projection after Eq. (19) in the proof of Lemma 4.5 is unclear.
  • The explanation preceding Eq. (20) in the same proof could mention that Eq. (11) is also used.

Limitations

The authors have addressed some of the paper's limitations, but not its limited significance relative to existing results on the implicit bias of multiclass classification.

Author Rebuttal


From a conceptual point of view, the results of the present paper do not provide additional insight into the implicit bias of multiclass classification compared with previous works limited to the cross-entropy loss.

Our novel insight is a characterization of conditions on a loss that suffice to endow it with an implicit bias toward the max-margin solution. This is significant because, while cross-entropy is the most popular loss, many new losses with competitive performance are being proposed (PairLogLoss is one example). Our work fills a gap in the loss-design literature regarding how to design losses that share the implicit regularization property of cross-entropy. Our technique may also be of wider interest because it treats binary and multiclass classification in a unified way, which is the subject of your next comment, addressed below.

Also at the technical level, as the authors themselves state, using the PERM framework allows for a simple generalisation of the results of (Soudry et al., '18) to the multiclass case, to the point that a large fraction of the proof is identical or almost identical.

Our work offers a unified treatment of binary and multiclass classification and shows that the proof strategy from the binary case carries over; we view this as a strength of our analysis. Although some of our proofs simply mirror the established binary case, a large portion of our analysis is novel and nontrivial. For example, our new definition of exponential tail is itself a novel contribution: it captures the specific property needed for the implicit bias to hold for a broad class of losses beyond cross-entropy. Additionally, verifying that existing multiclass losses satisfy our exponential tail definition (as well as the $\beta$-smoothness condition) is nontrivial; the proofs are in Appendix C. Finally, we develop new tools and techniques that lay the groundwork for future binary-to-multiclass generalizations in the PERM framework. These are captured in Lemmas 4.1 and 4.2, as well as in Appendices B and D. In total, this amounts to about 12 pages of novel analysis not present in (Soudry et al., '18).

Definition of $\tilde{\mathbf{x}}$ in Eq. (9)?

$\tilde{\mathbf{x}}$ is defined on line 187.

The definition of coordinate projection after Eq. (19) in the proof of Lemma 4.5 is unclear;

This is an understandable confusion, since we use a lot of different notation. Additionally, thank you for helping us notice a typo. On line 254, instead of "the $\tilde{\mathbf{x}}_{i,i}$ 0-entry is omitted", it should say "the $\tilde{\mathbf{x}}_{i,y_i}$ 0-entry is omitted". This entry is 0 because, by the definition on line 187, the $\mathbf{A}$ matrices cancel out.

What we mean by our coordinate projection notation is essentially the following:

$$
[\boldsymbol{\Upsilon}_{y_{i}} \mathbf{D} \mathbf{W}^{\top} \mathbf{x}_{i}]_{k} =
\begin{cases}
\tilde{\mathbf{x}}_{i,k}^{\top} \mathbf{w} & \text{if } k < y_i, \\
\tilde{\mathbf{x}}_{i,k+1}^{\top} \mathbf{w} & \text{if } k \geq y_i.
\end{cases}
$$
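The index shift in this notation can be sketched as follows. The list `full` is a hypothetical stand-in whose entry for class $c$ (1-based) plays the role of $\tilde{\mathbf{x}}_{i,c}^{\top}\mathbf{w}$; the helper `drop_label_entry` is also hypothetical. The entry of the label $y_i$, which is 0, is the one skipped.

```python
def drop_label_entry(full, y):
    """Build the (K-1)-vector whose k-th entry (1-based) is full_k if k < y, else full_{k+1}.

    `full` holds K values, with full[c - 1] standing for class c (1-based);
    the entry of the label y itself is skipped.
    """
    K = len(full)
    out = []
    for k in range(1, K):                # k = 1, ..., K - 1
        src = k if k < y else k + 1      # source class index in 1..K
        out.append(full[src - 1])
    return out


print(drop_label_entry(["a1", "a2", "a3", "a4"], y=2))  # ['a1', 'a3', 'a4']
```

In other words, the projection is just "delete the $y_i$-th coordinate and keep the rest in order".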

Please also see the beginning of Appendix D for an intuitive explanation. Finally, we have added a remark in the paper to clarify this.

The explanation preceding Eq. (20) in the same proof could mention that Eq. (11) is also used.

Thank you for your suggestion; we will make this fix.

Comment

I thank the authors for answering my questions and appreciate their effort. Nevertheless, my concerns about the relevance of the contribution, listed in the Weaknesses section, still stand, so I keep my score at 4.

The extension of (Soudry et al., '18) and other previous works to more general multiclass losses is certainly interesting, but in my opinion it would be better suited as a technical paper in a dedicated journal. For example, I find it difficult to appreciate the technical contributions of the lemmas mentioned in the authors' reply (Lemma 4.1, which seems to follow from a simple derivation of the gradient, and Lemma 4.2, a simple consequence of the definition of the trace), and I suspect the broad NeurIPS audience would find them difficult to appreciate too.

Author Rebuttal

Thank you all for the fantastic reviews! Please find attached the figures from our numerical simulations.

One typo brought to our attention was in the first equation on line 200. Here is the corrected derivation:

$$
\begin{aligned}
\|\mathbf{r}(t+1) - \mathbf{r}(t)\|^2 &= \|\mathbf{w}(t+1) - \mathbf{w}(t) + \hat{\mathbf{w}}\log(t+1) - \hat{\mathbf{w}}\log(t) + \tilde{\mathbf{w}} - \tilde{\mathbf{w}}\|^2 \\
&= \|-\eta\nabla\mathcal{R}(\mathbf{w}(t)) + \hat{\mathbf{w}}\log(1+t^{-1})\|^2 \\
&= \eta^2\|\nabla\mathcal{R}(\mathbf{w}(t))\|^2 + \|\hat{\mathbf{w}}\|^2\log(1+t^{-1})^{2} - 2\eta\hat{\mathbf{w}}^{\top}\nabla\mathcal{R}(\mathbf{w}(t))\log(1+t^{-1}) \\
&\leq \eta^2\|\nabla\mathcal{R}(\mathbf{w}(t))\|^2 + \|\hat{\mathbf{w}}\|^2 t^{-2} - 2\eta\hat{\mathbf{w}}^{\top}\nabla\mathcal{R}(\mathbf{w}(t))\log(1+t^{-1})
\end{aligned}
$$

In the last step we used $\log(1+x) \leq x$ for all $x > 0$, with $x = t^{-1}$; since both sides are positive, the inequality is preserved after squaring.
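A quick numerical check of the expansion and the final bound, with random vectors standing in for $\nabla\mathcal{R}(\mathbf{w}(t))$ and $\hat{\mathbf{w}}$ (a sketch; the function name `expansion_and_bound` is hypothetical). The gap between the two sides is exactly $\|\hat{\mathbf{w}}\|^2\left(t^{-2} - \log(1+t^{-1})^2\right) \ge 0$, so only floating-point rounding could make the check fail, and a small relative tolerance covers that.

```python
import math
import random


def expansion_and_bound(g, w_hat, eta, t):
    """Return (||-eta*g + w_hat*log(1+1/t)||^2, right-hand side of the final inequality)."""
    ell = math.log1p(1.0 / t)
    lhs = sum((-eta * gi + wi * ell) ** 2 for gi, wi in zip(g, w_hat))
    g_sq = sum(gi * gi for gi in g)
    w_sq = sum(wi * wi for wi in w_hat)
    dot = sum(wi * gi for wi, gi in zip(w_hat, g))
    rhs = eta ** 2 * g_sq + w_sq / t ** 2 - 2 * eta * dot * ell
    return lhs, rhs


random.seed(1)
ok = True
for t in range(1, 1001):
    g = [random.gauss(0, 1) for _ in range(5)]
    w_hat = [random.gauss(0, 1) for _ in range(5)]
    lhs, rhs = expansion_and_bound(g, w_hat, eta=0.1, t=t)
    ok = ok and lhs <= rhs + 1e-12 * (1 + abs(rhs))
print(ok)  # True
```

Note that the bound holds whatever the sign of $\hat{\mathbf{w}}^{\top}\nabla\mathcal{R}(\mathbf{w}(t))$, since the cross term appears identically on both sides.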

Final Decision

This paper provides a variety of refinements of max margin analyses to the multiclass case, using the proof of Soudry et al. Reviewers are uniformly positive and I recommend acceptance.