PaperHub
评分: 7.5/10 · Spotlight · 4 位审稿人
评分明细: 8 / 6 / 8 / 8 (最低 6, 最高 8, 标准差 0.9)
置信度: 2.8
ICLR 2024

Optimal Sample Complexity of Contrastive Learning

OpenReview · PDF
提交: 2023-09-23 · 更新: 2024-03-15
TL;DR

We provide tight bounds on the sample complexity of contrastive learning in various settings, without any assumptions on the data distribution

摘要

Contrastive learning is a highly successful technique for learning representations of data from labeled tuples, specifying the distance relations within the tuple. We study the sample complexity of contrastive learning, i.e. the minimum number of labeled tuples sufficient for getting high generalization accuracy. We give tight bounds on the sample complexity in a variety of settings, focusing on arbitrary distance functions, $\ell_p$-distances, and tree metrics. Our main result is an (almost) optimal bound on the sample complexity of learning $\ell_p$-distances for integer $p$. For any $p \ge 1$, we show that $\tilde \Theta(nd)$ labeled tuples are necessary and sufficient for learning $d$-dimensional representations of $n$-point datasets. Our results hold for an arbitrary distribution of the input samples and are based on giving the corresponding bounds on the Vapnik-Chervonenkis/Natarajan dimension of the associated problems. We further show that the theoretical bounds on sample complexity obtained via VC/Natarajan dimension can have strong predictive power for experimental results, in contrast with the folklore belief about a substantial gap between the statistical learning theory and the practice of deep learning.
关键词
learning theory, sample complexity, vc dimension, contrastive learning, metric learning

评审与讨论

审稿意见
8

Suppose $V$ is a set of $n$ data points, each embedded into $\mathbb{R}^d$. Suppose we observe $m$ triplets $(x,y,z) \in V^3$ and their labels in $\{-1, 1\}$, where the label is $1$ if $x,y$ are closer than $x,z$ in the embedding space in $\ell_p$ distance and $-1$ otherwise. The paper gives the optimal order for the sample complexity $m$ required to get to a misclassification error of $\epsilon$.
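To make the setup concrete, here is a minimal illustrative sketch of how such triplet labels are generated from a ground-truth embedding (the random data and parameter choices are assumptions for illustration, not the paper's setup):

```python
import numpy as np

def label_triplet(embedding, x, y, z, p=2):
    """Return +1 if x is closer to y than to z in l_p distance, and -1 otherwise."""
    d_xy = np.linalg.norm(embedding[x] - embedding[y], ord=p)
    d_xz = np.linalg.norm(embedding[x] - embedding[z], ord=p)
    return 1 if d_xy < d_xz else -1

# n points of V embedded in R^d; m labeled triplets drawn from V^3.
rng = np.random.default_rng(0)
n, d, m = 100, 8, 500                          # illustrative sizes
embedding = rng.normal(size=(n, d))            # stand-in for the ground-truth embedding
triplets = rng.integers(0, n, size=(m, 3))     # observed triplets (x, y, z)
labels = [label_triplet(embedding, x, y, z) for x, y, z in triplets]
```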

The high-level technique is to derive the VC dimension (and the Natarajan dimension for larger tuples, but we ignore results in the larger-tuple case for now). The authors derive a nearly tight order for the VC dimension of such triplet classifiers. For upper bounding the VC dimension, they formulate the classification function as the sign of a polynomial in $nd$ dimensions. The key ingredient here is a fact from Warren (1968) that there are at most $(4epm/nd)^{nd}$ connected components in $\mathbb{R}^{nd}$ such that in each connected component the signs of the $m$ polynomials are fixed. For proving the lower bound on the VC dimension, the authors give a clever construction of a set of triplets which can be shattered.
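The counting step behind the upper bound can be summarized as follows (a sketch of the standard shattering argument using the constants quoted above, not a verbatim excerpt from the paper):

```latex
% If a set of m triplets is shattered, all 2^m sign patterns must be realized by
% the corresponding m polynomials in the nd embedding coordinates, while Warren's
% bound caps the number of realizable patterns, hence
\[
  2^m \;\le\; \left(\frac{4epm}{nd}\right)^{nd}
  \quad\Longrightarrow\quad
  m \;=\; O(nd \log p) \;=\; \tilde{O}(nd).
\]
```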

The authors give results for both the realizable and agnostic cases. They extend their results to several distance functions and to tuples of size more than 3. They also study the well-separated case, where the labeled triplets $(x,y^+,z^-)$ satisfy $\rho(x,z) \geq (1+\alpha) \rho(x,y)$, etc., for some $\alpha > 0$.

优点

  • The derivations are interesting, short and non-trivial.
  • The results are relevant because contrastive learning is practical.

缺点

Nothing significant.

问题

Typos/Minor comments:

  • page 4: Outline of techniques: P is a polynomial of degree $2$, not $2d$.
  • Reducing dependence on $n$: It may be good to state bounds with the assumptions mentioned ($k$ latent classes etc.)
  • Theorem 3.3 proof, first line: Should it be $d < n$ here?
  • In Definition 2.1, the symbol $S_3$ seems to be used before definition.
  • Please search for "Kulis Kulis" and "Warren Warren" in the paper and remove such duplicates.
  • page 8, second line: Should $<$ be replaced by $\leq$?
评论

We thank the reviewer for their comments. We updated the PDF, which now includes all of the suggested changes.

Q: Typos:

  • page 4: Outline of techniques: P is a polynomial of degree $2$, not $2d$.
  • Theorem 3.3 proof, first line: Should it be $d < n$ here?
  • Please search for "Kulis Kulis" and "Warren Warren" in the paper and remove such duplicates.
  • page 8, second line: Should $<$ be replaced by $\le$?

A: Thank you, fixed!

Q: Reducing dependence on $n$: It may be good to state bounds with the assumptions mentioned ($k$ latent classes etc.)

A: In case there are $k$ latent classes, all the results would be as in Table 1, with $n$ replaced by $k$.

Q: In Definition 2.1, the symbol $S_3$ seems to be used before definition.

A: Thank you for pointing it out. Definition 2.1 introduces the notion of sample complexity, and also introduces the notation $S_3$ for the sample complexity. We clarified the wording: “the sample complexity of contrastive learning, denoted as $S_3(\epsilon, \delta)$”.

审稿意见
6

The paper explores the efficacy of contrastive learning, a method for learning data representations based on labeled tuples that detail distance relationships within the tuples. The main focus is on understanding the sample complexity of this method, which refers to the minimum number of labeled tuples needed to achieve accurate generalization.

This work provides specific bounds for the sample complexity across various settings, especially for arbitrary distance functions, $\ell_p$-distances, and tree metrics. A central finding is that for learning $\ell_p$-distances, $\Theta(\min(nd,n^2))$ labeled tuples are necessary and sufficient for learning $d$-dimensional representations of $n$-point datasets. These results hold regardless of the input samples' distribution and derive from bounds on the Vapnik-Chervonenkis/Natarajan dimension of the related problems.

This paper also demonstrates that theoretical boundaries derived from the VC/Natarajan dimension correlate strongly with experimental outcomes.

优点

  1. This paper primarily studies the sample complexity of contrastive learning, and provides tight bounds in some settings, including arbitrary distance functions and $\ell_p$-distances for even $p$, and almost tight bounds for odd $p$. For constant $d$, this work also provides a matching upper and lower bound.
  2. This paper studies both realizable and agnostic settings, as is standardly considered in PAC learning. The sample complexity bounds in terms of $\epsilon$ coincide with the standard results in PAC learning in both realizable and agnostic settings.
  3. The proposed proof idea extends to various settings, including the cases where $k > 1$ (multiple negative examples in one tuple) and quadruple samples.

缺点

  1. It would be good to provide a thorough comparison with the known sample complexity bounds that appeared in the existing literature.
  2. While I understand the page limit of the main body, there seems to be relatively less than enough content on the main results of this work in the main body. Perhaps consider moving more technical parts from the appendix into the main body.
  3. The structure of the paper could be reorganized a bit: e.g., the paragraph "Reducing dependence on $n$" could be part of the discussions after presenting the full main results.

问题

  1. What do you think is the primary reason/insight that the upper bounds for $\ell_p$-distances are different between odd and even $p$?
  2. Do you think it is possible to characterize the sample complexity bound when the cardinality of $V$ is infinite?
  3. Minor:
  • In the paragraph "Outline of the techniques" on Page 4, why "P is some polynomial of degree $2d$"?
  • In the paragraph above "Outline of the techniques" on Page 4, $(x_1^-, x_2^-)$ -> $(x_3^-, x_4^-)$?
评论

We thank the reviewer for their comments. We updated the PDF, which now includes all of the suggested changes. We would like to highlight that:

  • as you correctly mentioned in the summary, our work provides distribution-independent bounds, which separates it from other related works that make various assumptions about the data, model, or training process;
  • our bounds are achieved using PAC-learning techniques, which, contrary to the widespread belief, demonstrates that PAC-learning techniques are applicable to the contrastive learning setting.

We also would be very grateful if you could clarify why soundness was estimated as “2. Fair”. Please let us know if you have any specific questions about the proof of our results.

Q: A thorough comparison with the known sample complexity bounds.

A: We first would like to emphasize that our setting is substantially different from all the previous work we are aware of: namely, we estimate the sample complexity without any assumptions on the data distribution, classifier, or training algorithm. Moreover, multiple works measure the classifier quality using some continuous loss function instead of prediction accuracy (and hence $\epsilon$ is formulated in terms of the loss function), which again makes the results incomparable. Below we outline the works on sample complexity which are most related to ours:

  • “Sample complexity of learning Mahalanobis distance metrics” by N. Verma and K. Branson shows that learning an $n \times n$ Mahalanobis matrix requires $O(\frac{n}{\epsilon^2})$ samples.
  • “Tree Learning: Optimal Sample Complexity and Algorithms” by Avdiukhin et al. (AAAI 2023) is an example of a complexity bound for a specific metric case, showing that the VC dimension of tree learning is $\Theta(n)$. The goal of tree learning is to build a hierarchy on $n$ points such that constraints of the form “$c$ is separated from $a$ and $b$ first” are satisfied. This is a metric case where the distance between $a$ and $b$ can be defined as the number of leaves under the least common ancestor of $a$ and $b$.
  • Saunshi et al. (2019) provide sample complexity bounds for transfer learning settings.
  • A line of work considers the sample complexity of recovering metric embeddings using noisy queries in various settings. In particular, “Landmark Ordinal Embedding” by Ghosh et al. shows that $O(d^8 n \log n)$ noisy triplet queries suffice to achieve a constant additive divergence.

Q: While I understand the page limit of the main body, there seems to be relatively less than enough content on the main results of this work in the main body. Perhaps consider moving more technical parts from the appendix into the main body.

A: Thank you for the suggestion, we moved the proof of Theorem B.2 (constant $d$ case) to the main body. Given the page limit, the other technical results can’t be included in the main body, but we give an outline of the proofs in the “Outline of the techniques” paragraph.

Q: The structure of the paper could be reorganized a bit: e.g., the paragraph "Reducing dependence on n" could be part of the discussions after presenting the full main results.

A: Thank you, we moved it to the “Extensions” section.

Q: What do you think is the primary reason/insight that the upper bounds for $\ell_p$-distances are different between odd and even $p$?

A: The main reason that the upper bounds for odd $p$ are different from the upper bounds for even $p$ is that, for even $p$, $\|x - y\|_p^p$ is a polynomial in $x_1,\ldots,x_d,y_1,\ldots,y_d$, while for odd $p$ it is not, and it is not even $p$-times continuously differentiable. Intuitively, $\|x - y\|_p^p$ is a worse-behaved function for odd $p$ than for even $p$, so it is not surprising that the upper bound is similarly worse.
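A small illustration of the distinction (not taken from the paper):

```latex
% For p = 2 the powered distance is a degree-2 polynomial in the coordinates,
\[
  \|x - y\|_2^2 \;=\; \sum_{i=1}^{d} (x_i - y_i)^2 ,
\]
% while for p = 1 it involves absolute values, so it is not a polynomial and is
% not differentiable wherever some coordinate of x - y vanishes:
\[
  \|x - y\|_1 \;=\; \sum_{i=1}^{d} |x_i - y_i| .
\]
```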

Q: Is it possible to characterize the sample complexity bound when the cardinality of $V$ is infinite?

A: We briefly outline the idea in the paragraph “Reducing dependence on $n$”. The following are natural scenarios.

  • In practice, in order to reduce $n$ one can use any unsupervised method which allows one to perform deduplication without substantially affecting the metric structure. For example, an extremely common assumption in the literature is that the data comes from $k \ll n$ different classes. In this case, one can use an unsupervised clustering algorithm (e.g. a pretrained neural network) to partition the points into $k$ clusters, effectively replacing $n$ with $k$ in all bounds (see the sketch after this list).
  • Another practical scenario is that the distribution is supported on a large domain but contains a large number of low-probability outliers. In this case, one can replace $n$ with the size of the support of 99% of the distribution at the cost of a small increase in error.
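A minimal sketch of the first scenario (an assumed workflow, not the authors' pipeline; scikit-learn's KMeans stands in for any unsupervised clustering step):

```python
import numpy as np
from sklearn.cluster import KMeans

# Collapse the n raw points into k latent classes so that the bounds apply with
# n replaced by k. The feature matrix here is random and purely illustrative;
# in practice it would come from a pretrained network.
rng = np.random.default_rng(0)
n, d, k = 10_000, 64, 10
features = rng.normal(size=(n, d))

kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(features)
cluster_of = kmeans.labels_                 # maps each of the n points to one of k classes
representatives = kmeans.cluster_centers_   # k points standing in for the n originals
```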

Q: "Outline of the techniques": why "P is some polynomial of degree 2d"?

A: Thank you, this is a typo - P is a polynomial of degree 2.

Q: $(x_1^-, x_2^-) \to (x_3^-, x_4^-)$

A: Thank you, fixed!

评论

Dear reviewer,

As the discussion period is concluding tomorrow, we would like to ask if our reply has addressed all of your questions about our work.

评论

Thanks for the responses. Yes, I think my questions about this paper are addressed, and I keep the positive score for this work.

审稿意见
8

This paper studies the sample complexity of contrastive learning, which learns the similarity (usually the distance in a metric space) between domain points, given tuples each labeling a most-similar input point to a given point (anchor).

This paper proves matching (or almost matching) bounds on the sample complexity for contrastive learning of different metrics (arbitrary distance, cosine similarity, and tree metric), with generalizations to learning with hard negatives (separated $\ell_2$ distance), quadruplet learning, and learning with $k$ negatives.

The results are based on an output-based assumption: there is an embedding into a $d$-dimensional vector space. This enables reasoning about the VC/Natarajan dimension within the PAC-learning framework, both for the realizable and the agnostic cases, to get non-vacuous PAC-learning bounds with predictive power.

The proof is based on representing the decision boundary of contrastive learning under such metrics as a low-degree polynomial and upper bounding its number of possible satisfiable sign changes (Lemma 3.5, proved in Section B), hence bounding the largest shattered set of tuples and the VC/Natarajan dimension.

The theoretical result is also experimentally verified on popular image datasets (CIFAR-10/100 and MNIST/Fashion-MNIST), by learning the representations with a ResNet18 trained from scratch with different contrastive losses.

优点

The proof works by transforming the problem into an equivalent representation: the decision boundary of contrastive learning under common metrics is a low-degree polynomial, whose number of possible satisfiable sign changes can be bounded. There is no loss until invoking the algebraic-geometric bound on the number of connected components by Warren and the sample complexity bounds via the Natarajan dimension by Ben-David et al.

And at a high level, the paper shows that when the learned representation is expressive enough (such as ResNet18), PAC-learning bounds (e.g., by VC/Natarajan dimension) can have predictive powers.

The theoretical result is also experimentally verified on popular image datasets (CIFAR-10/100 and MNIST/Fashion-MNIST), by learning the representations with a ResNet18 trained from scratch with different contrastive losses (Appendix F).

缺点

The proof arguments are somewhat non-constructive, due to using a counting argument/pigeonhole principle, and hence, while the theory may explain certain observations (e.g., the experimental results in Appendix F), it is unlikely to give effective learning algorithms, and the constants in the resulting bounds are unlikely to be sharpened. This may be nitpicking, but these are weaknesses nonetheless.

问题

While the experimental results (in Appendix F) verify the growth of error rates as predicted by the theory for ResNet18 on certain parameter ranges, there is not much explanation regarding, e.g., what representations (e.g., deep learning architectures) are expressive enough to achieve the sample bounds predicted by the theory. That is, the representations (given by the theory) are somehow non-explicit/ineffective; is that the case?

Typos?

Statement of Theorem 1.3 (Page 4): The sample complexity of contrastive learning for contrastive learning under cosine similarity is...

评论

First, we would like to thank the reviewer for their helpful comments. We would like to add that one of the main points of our work lies in the fact that we provide sample complexity bounds without any assumptions on the data distribution or the learning algorithm. We also would like to emphasize that the focus of our work is sample complexity - namely, how many labeled examples are necessary to produce a good classifier, regardless of the model and the training algorithm. This question is important by itself since it allows one to determine whether the amount of data available is sufficient to train a good classifier.

Q: The proof arguments are somewhat non-constructive, due to using a counting argument/pigeonhole principle, and hence, while the theory may explain certain observations (e.g., the experimental results in Appendix F), it is unlikely to give effective learning algorithms, and the constants in the resulting bounds are unlikely to be sharpened. This may be nitpicking, but these are weaknesses nonetheless.

A: The learning algorithm suggested by our work is Empirical Risk Minimization (ERM), i.e. minimizing the error on the sample data (which is optimal for statistical learning problems in general). While theoretically the ERM problem is computationally hard, any algorithm which performs well in practice can be used. Empirical success of deep learning architectures makes them a suitable choice for the problems we consider. The focus of our work is to quantify the amount of data required for fitting such architectures to the training data while providing generalization guarantees.
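As a concrete but purely illustrative instance of such an ERM-style learner, one could fit an embedding network by minimizing a triplet loss on the labeled tuples; the architecture and hyperparameters below are assumptions, not the paper's experimental setup:

```python
import torch
import torch.nn as nn

d = 32                                                      # embedding dimension
encoder = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, d))
criterion = nn.TripletMarginLoss(margin=1.0, p=2)           # l_2 triplet loss
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

def erm_step(anchor, positive, negative):
    """One empirical-risk-minimization step on a batch of labeled triplets."""
    optimizer.zero_grad()
    loss = criterion(encoder(anchor), encoder(positive), encoder(negative))
    loss.backward()
    optimizer.step()
    return loss.item()
```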

Q: While the experimental results (in Appendix F) verify the growth of error rates as predicted by the theory for ResNet18 on certain parameter ranges, there is not much explanation regarding, e.g., what representations (e.g., deep learning architectures) are expressive enough to achieve the sample bounds predicted by the theory. That is, the representations (given by the theory) are somehow non-explicit/ineffective; is that the case?

A: One property of the neural network architecture that we can analyze is the dimension of the embedding layer, denoted $d$. Using our results, given the number of available samples, one can estimate whether they are sufficient for achieving low generalization error. In particular, our work shows that for networks with a Euclidean embedding layer of dimension $d$, approximately $nd$ samples are necessary and sufficient, where $n$ is the number of different points in the domain. Furthermore, in the well-separated case we show that the number of samples required is independent of the dimension of the embedding layer.
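A purely illustrative back-of-the-envelope calculation (the numbers are assumptions, not figures from the paper): for $n = 50{,}000$ distinct points and an embedding layer of dimension $d = 512$, the bound suggests on the order of

```latex
\[
  nd \;=\; 50\,000 \times 512 \;\approx\; 2.6 \times 10^{7}
\]
```

labeled triplets, up to logarithmic and $1/\epsilon$ factors.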

审稿意见
8

Given labeled samples $(x_1,y^+_1,z^-_1),\ldots,(x_n,y^+_n,z^-_n)$, the goal of contrastive learning is to create a distance function $\rho$ such that $\rho(x,y) < \rho(x,z)$. This study is theoretical in flavour, providing lower and upper bounds for the sample complexity of contrastive learning via the PAC-learning framework. The main ingredients of the proof of the bounds are the Natarajan dimension (which is a generalization of the VC dimension) and the results from Ben-David et al. (1995).

Reference:
S. Ben-David, N. Cesa-Bianchi, D. Haussler, P. M. Long, Characterizations of Learnability for Classes of {0, ..., n}-Valued Functions, Journal of Computer and System Sciences, Volume 50, Issue 1, 1995, Pages 74-86.

优点

  • The study covers a wide range of distance functions, both in lower bounds and upper bounds.
  • Interesting use of an algebraic geometry result to prove the upper bound.
  • The theory is well-supported by the empirical results.
  • The authors discuss possible directions for future work.

缺点

Personally, I find it hard to follow the results by chapter. For example, Section 3 should be all about the bounds in the $\ell_p$ norm, and so Theorem 3.1 (Arbitrary distance) should come before this section. And I think it would be easier to follow if Section 3 is only for $k=1$ and Section 4 is only for $k>1$.

I have some comments for Table 1:

  • I would put "$(1+\alpha)$-separable $\ell_2$" in the same category as $\ell_p$.
  • The "Same" labels for quadruplet learning and $k$ negatives have different meanings; while the former refers to all distance functions above, the latter only refers to the $\ell_p$ distances. Could the authors modify the table so that this distinction becomes clearer?

问题

See Weaknesses.

评论

We thank the reviewer for their comments. We updated the PDF, which now includes all of the suggested changes. In addition to the strengths of our work you identified, we would like to highlight the following important aspects of our work:

  • our work shows that, contrary to the widespread belief, PAC-learning can provide useful bounds for sample complexity in deep learning, specifically in the contrastive learning setting;
  • our work provides bounds that are independent of any assumptions on the input or the learning algorithm.

Q: Personally, I find it hard to follow the results by chapter. For example, Section 3 should be all about the bounds in the $\ell_p$ norm, and so Theorem 3.1 (Arbitrary distance) should come before this section. And I think it would be easier to follow if Section 3 is only for $k=1$ and Section 4 is only for $k>1$.

A: Thank you, we moved Theorem 3.1 to the previous section. We moved the results from subsection 3.1 to a separate section.

Q: I would put "$(1+\alpha)$-separable $\ell_2$" in the same category as $\ell_p$.

A: Fixed

Q: The "Same" labels for quadruplet learning and k negatives have different meaning; while the former refer to all distance functions above, the latter only refers to the ℓp distances. Could the authors modify the table so that this distinction becomes clearer?

A: Thank you for pointing this out. We’ve clarified in Table 1 that our results for $k$ negatives in fact cover all distances and not just $\ell_p$ (the results are in Appendix C.1).

AC 元评审

This paper was uniformly appreciated by all reviewers, and provides novel insights into the important problem of contrastive learning. Clear accept.

为何不给更高分

This could be an oral.

为何不给更低分

All reviewers clearly supported the paper.

最终决定

Accept (spotlight)