On the Out-of-Distribution Generalization of Self-Supervised Learning
Reviews and Discussion
This paper focuses on the out-of-distribution (OOD) generalization of self-supervised learning (SSL). The authors first give a plausible explanation for why SSL exhibits OOD generalization, then analyze the training process from the perspectives of generation and causal inference and conclude that SSL learns spurious correlations during training. To address this issue, they further propose a post-intervention distribution (PID) grounded in the Structural Causal Model. Experiments verify the advantages of their method.
Update After Rebuttal
Thank you for addressing my concerns. I maintain my positive score.
Questions For Authors
- Could you provide more detailed insights into how the Post-Intervention Distribution (PID) is specifically integrated into the self-supervised learning process? It would be beneficial to understand the operational steps or algorithms used to enforce PID constraints during mini-batch preparation.
- You mentioned that the proposed method minimizes spurious correlations. Can you discuss any specific metrics or evaluation criteria used to measure the extent of spurious correlations before and after applying your method?
- Could you discuss the scalability of your proposed method, particularly in terms of computational resources and time required as the dataset size increases? Is the method feasible for large-scale real-world applications where computational efficiency is critical?
Claims And Evidence
It seems convincing.
Methods And Evaluation Criteria
The method and evaluation criteria make sense.
Theoretical Claims
I have checked the proofs of theoretical claims, and the entire theoretical derivation and claims seem appropriate.
Experimental Designs Or Analyses
I have checked the experimental designs and analyses, and the experimental results are impressive.
Supplementary Material
I have reviewed the supplementary.
Relation To Broader Scientific Literature
By analyzing the mini-batch construction during the SSL training phase, this paper gives a plausible explanation for why self-supervised learning (SSL) exhibits OOD generalization. Moreover, the paper analyzes and concludes that SSL learns spurious correlations during the training process, which reduces OOD generalization.
Essential References Not Discussed
The essential references seem sufficient.
Other Strengths And Weaknesses
Strengths:
- This article provides a wealth of theoretical analysis, making the entire work more solid.
- The experimental results are impressive.
Weaknesses:
- While the theoretical aspects are robust, the practical implementation of these concepts, especially the integration of causal inference in SSL, might be complex and computationally intensive. This could limit its applicability in environments with constrained computational resources.
- Some of the causal assumptions made may not hold in all real-world scenarios, which could affect the generalizability of the findings. A deeper exploration of these assumptions, including conditions under which they may not be valid, would provide a more comprehensive view of the method’s applicability.
- Some key terms and variables used throughout the paper could be defined more clearly to avoid ambiguity, enhancing the paper’s accessibility to a broader audience.
Other Comments Or Suggestions
See weaknesses (above) and questions (below).
Response to Weaknesses 1 & Questions 3:
Thank you for pointing these out. The proposed method has two main phases with the following complexity analysis per mini-batch (batch size B, dataset size N):
Step 1: Latent Variable Model Training
- Encoder pass: each sample requires a forward pass with cost C_1, totaling O(B·C_1).
- Decoder pass: each sample incurs a cost C_2, totaling O(B·C_2).
- Conditional prior term: computed once per mini-batch with cost C_3.
- KL-divergence: involves operations over the latent dimension d_z and the sufficient-statistic dimension d_T, contributing O(B·(d_z + d_T)).
- Orthogonality regularization: requires O(d_z·d_T), which is constant when d_z and d_T are small.
Thus, the training phase complexity is approximately O(B·(C_1 + C_2) + C_3 + B·(d_z + d_T)).
Step 2: Algorithm 1
- Propensity score calculation: for each sample, computing scores across the N candidates costs O(N), leading to a total of O(B·N) for the mini-batch.
- Matching operation: a brute-force matching over the candidates yields an additional O(B·N).
Therefore, the sampling phase has an overall complexity of approximately O(B·N).
Step 3: Overall Complexity
The combined complexity per mini-batch is approximately O(B·(C_1 + C_2) + C_3 + B·(d_z + d_T) + B·N).
Here, B is the batch size, N is the dataset size, d_z and d_T are the latent and sufficient-statistic dimensions, and C_1, C_2, and C_3 denote the computational cost of a single forward pass (or operation) of each respective network module. For specific computational resources and time, please refer to Response to Weaknesses 2 of the Rebuttal for Reviewer 1Q1V.
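For intuition, below is a small cost-model sketch of how these terms combine per mini-batch; the symbol names follow the notation above, and the default dimensions are illustrative assumptions, not measured values.

```python
def per_minibatch_cost(B, N, C1, C2, C3, d_z=32, d_T=32):
    """Rough per-mini-batch cost model combining the terms above (abstract op counts)."""
    training = B * (C1 + C2) + C3 + B * (d_z + d_T)   # latent variable model training
    sampling = B * N                                  # propensity scores + matching
    return training + sampling

# For large datasets, the B * N propensity-score/matching term dominates the
# per-mini-batch cost, so the dataset size N drives the scalability question.
```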
Response to Weaknesses 2:
Thank you for pointing this out. We provide a deeper exploration of Assumption 3.3 and Assumption 4.1 in our response.
For Assumption 3.3, we implicitly assume that the noise is independent of both the latent variables and the observed features. However, in many practical scenarios, noise may be correlated with either the latent variables or the observed features (for example, sensor noise that correlates with lighting conditions in image data), which can interfere with the separation between causal and non-causal factors.
Regarding Assumption 4.1, the main concern lies in the potential mismatch between the true conditional distribution in real-world data and the assumed exponential family form. If the actual distribution is more complex or exhibits behaviors beyond this family, such as multi-label or multi-instance characteristics, then the applicability of our method may be compromised.
Response to Weaknesses 3:
Thank you for pointing this out. In the final version, we will add a table to illustrate all terms and variables related to our method.
Response to Questions 1:
Thank you for pointing this out. We explain this issue through the following steps:
Step 1: How we implement PID
According to Definition 4.4 in the original submission, the label variable and the spurious variable are conditionally independent given the balancing score. Based on this, if all pairs in a mini-batch share the same balancing score, then within this mini-batch the label variable and the spurious variable can be considered independent. Consequently, such a mini-batch can be viewed as being sampled from a PID.
Step 2: How is this integrated into SSL
In the training phase of SSL, a mini-batch is typically sampled from the training data prior to each iteration. In standard SSL, this mini-batch is sampled at random. In contrast, our method constructs the mini-batch using Algorithm 1 from the original submission. That is, our approach embeds the PID constraint into SSL by replacing the mini-batch sampling process with Algorithm 1, without altering any other part of the SSL training procedure.
According to Algorithm 1, the core criterion for selecting samples is to ensure that the balancing-score values of the selected pairs are as similar as possible. This makes the balancing score consistent across all samples in the resulting mini-batch, thereby forming a PID.
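For intuition, here is a minimal sketch of what such a PID-constrained mini-batch sampler could look like; the greedy nearest-score selection and all names are illustrative simplifications, not the exact procedure of Algorithm 1.

```python
import numpy as np

def sample_pid_minibatch(pair_indices, balancing_scores, batch_size, rng=None):
    """Select a mini-batch of augmentation pairs whose balancing scores are
    as close to each other as possible, approximating a sample from a PID.

    pair_indices:     array of candidate pair indices, shape (N,)
    balancing_scores: array of balancing-score vectors, shape (N, d)
    """
    rng = rng or np.random.default_rng()
    # Pick a random anchor pair, then greedily take the batch_size candidates
    # whose balancing scores are nearest to the anchor's score (incl. the anchor).
    anchor = rng.integers(len(pair_indices))
    dists = np.linalg.norm(balancing_scores - balancing_scores[anchor], axis=1)
    nearest = np.argsort(dists)[:batch_size]
    return pair_indices[nearest]

# Usage: replace the random mini-batch sampler of a standard SSL training loop
# with this function; the SSL objective and architecture stay unchanged.
```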
Response to Questions 2:
Thank you for pointing this out. Instead of proposing specific metrics or evaluation criteria, we run a toy experiment on the COCO dataset [1] with two experimental settings: 1) training and testing the SSL model on full images; 2) training and testing the SSL model on foreground-only images. Setting 2) can be thought of as free of background semantic confounding. In terms of Top-1 classification accuracy, SimCLR achieves 39.66 and 50.19 in the two settings, while SimCLR + Ours achieves 45.25 and 51.48. Our method yields results that are closer across the two settings and significantly outperforms SimCLR. Thus, it can be concluded that our method learns fewer spurious correlations.
[1] Microsoft coco: Common objects in context. ECCV, 2014.
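For completeness, the full-image vs. foreground-only comparison above can be read as an accuracy-gap proxy for spurious correlation; a minimal sketch follows, where the evaluate function and the two data loaders are assumed to be available.

```python
def spurious_correlation_gap(evaluate, model, full_loader, foreground_loader):
    """Proxy for reliance on background (spurious) cues.

    evaluate(model, loader) -> top-1 accuracy, assumed to be provided.
    A smaller gap means the representation depends less on the background.
    """
    acc_full = evaluate(model, full_loader)        # full images (background present)
    acc_fg = evaluate(model, foreground_loader)    # foreground-only images
    return acc_fg - acc_full

# With the numbers reported above:
# SimCLR:        50.19 - 39.66 = 10.53
# SimCLR + Ours: 51.48 - 45.25 =  6.23  -> smaller gap, fewer spurious correlations
```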
The paper explores the out-of-distribution (OOD) generalization of self-supervised learning (SSL). It analyzes how mini-batch construction in SSL training influences OOD generalization and argues that SSL models often learn spurious correlations, which hinder their ability to generalize to unseen distributions. To address this issue, the paper introduces a post-intervention distribution (PID) based on Structural Causal Models (SCMs). This ensures that spurious variables and label variables remain independent, improving OOD generalization.
Furthermore, the authors propose a mini-batch sampling strategy that enforces PID constraints through a latent variable model. They provide theoretical proof of the identifiability of their method and validate it with empirical results. Experiments on various downstream OOD tasks demonstrate that their approach significantly enhances SSL’s generalization performance.
Strengths:
- Novel Causal Perspective on SSL OOD Generalization: The paper offers a compelling causal analysis of why SSL struggles with OOD generalization and how spurious correlations arise.
- Innovative Mini-Batch Sampling Strategy: Unlike traditional batch sampling, the method ensures spurious correlations are minimized, leading to better OOD generalization.
- Strong Empirical Performance: The proposed method consistently improves performance across diverse benchmarks, including unsupervised, semi-supervised, transfer learning, and few-shot learning tasks.
Questions For Authors
Please refer to the Other Strengths And Weaknesses part.
Claims And Evidence
The claims in the paper are partially supported by the evidence in the experiment section. However, the reviewer is concerned about the lack of evaluation on masked-autoencoder-based pre-training methods like [1, 2] in the main paper. Although some results are provided in the supplementary material, it would be worthwhile to add the comparison to the main-paper experiments alongside the contrastive-based methods, since in the analysis and proof parts the authors formulate discriminative and generative methods within a unified framework.
[1] He, K., Chen, X., Xie, S., Li, Y., Dollár, P., and Girshick, R. Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16000-16009, 2022.
[2] Tong, Z., Song, Y., Wang, J., and Wang, L. VideoMAE: Masked autoencoders are data-efficient learners for self-supervised video pre-training. Advances in Neural Information Processing Systems, 35:10078-10093, 2022.
Methods And Evaluation Criteria
The reviewer thinks the evaluation setting and metrics make sense for the claims.
Theoretical Claims
The reviewer checked part of the proofs. Specifically, the theoretical claims in Sections 3 and 4.1 were checked.
Experimental Designs Or Analyses
The reviewer thinks that experimental designs are valid.
Supplementary Material
Yes. Sections B, C, and E were checked, and part of Section A was checked.
Relation To Broader Scientific Literature
The paper makes a theoretical and practical contribution by introducing a causal approach to improving OOD generalization in SSL. While the method is empirically validated, its assumptions, computational cost, and feasibility in large-scale applications could be explored further.
Essential References Not Discussed
The reviewer does not come up with essential related literature that is not discussed.
Other Strengths And Weaknesses
- Starting from Line 162, left column, the authors propose the assumption that "the semantic information within x^+ is related only to x^label, that is, s does not contain any causal semantics related to the task." They provide two examples for this assumption. However, since natural images are not restricted to the numbers/styles discussed in these two examples, the reviewer is concerned about the assumption. More examples from ImageNet should be provided to examine it.
- As additional training is required, the reviewer is concerned about the training efficiency of the proposed method. The authors should provide an evaluation of it.
- As more parameters (the VAE) are introduced by the proposed method, the reviewer is concerned about the fairness of the comparison. Some discussion should be provided.
Other Comments Or Suggestions
Please refer to Other Strengths And Weaknesses part.
Response to Claims And Evidence:
Thank you for pointing this out. In Appendix C.1, we report the results of MAE. Now, we present the results of VideoMAE.
We transfer VideoMAE + Ours pre-trained on Kinetics-400 [1] to the downstream action detection dataset AVA [2]. Following the standard setting [3], we evaluate on the top 60 common classes with mean Average Precision (mAP) as the metric under an IoU threshold of 0.5. The results are shown as follows:
| Method | Backbone | Pre-train Dataset | Extra Labels | GFLOPs | Param (M) | mAP |
|---|---|---|---|---|---|---|
| VideoMAE | ViT-L | Kinetics-700 | No | 597 | 305 | 36.1 |
| VideoMAE | ViT-L | Kinetics-700 | Yes | 597 | 305 | 39.3 |
| VideoMAE + Ours | ViT-L | Kinetics-700 | No | 597 | 305 | 38.7 |
| VideoMAE + Ours | ViT-L | Kinetics-700 | Yes | 597 | 305 | 42.1 |
In the above table, "Extra Labels" denotes whether the pre-trained models are additionally fine-tuned on the pre-training dataset with labels before being transferred to AVA. In the final version, we will add these results to the main body of our submission.
[1] The kinetics human action video dataset. arXiv preprint, 2017.
[2] Ava: A video dataset of spatio-temporally localized atomic visual actions. CVPR, 2018.
Response to Weaknesses 1:
Thank you for pointing this out. We provide some additional ImageNet-inspired examples and explanations to address the concern:
Object vs. Background: Consider an ImageNet class like "Golden Retriever." The causal semantics for recognizing a Golden Retriever primarily reside in the dog's shape, fur texture, and facial features. Although many images may have different backgrounds, such as parks, beaches, or urban settings, these background elements (which would be captured by s) are not causally responsible for the image being classified as a Golden Retriever. In this case, x^label would capture the object-specific features, while s would account for non-causal variations like the background.
Intra-Class Variability: Take another class, such as "Volcano." A volcano can be pictured under different weather conditions, from different angles, and with various surrounding landscapes. While these environmental or stylistic factors vary widely, the key causal semantics, such as the volcano's structure, cone shape, and crater, remain consistent. Again, s may vary (capturing changes in lighting, weather, or background) without affecting the causal information needed to identify a volcano.
Response to Weaknesses 2:
Thank you for pointing this out. We provide the model efficiency and memory footprint results of the proposed method trained on 8 NVIDIA Tesla V100 GPUs.
| Method | Training Time (Hours), CIFAR-10 | Training Time (Hours), ImageNet | Memory Footprint (GB), CIFAR-10 | Memory Footprint (GB), ImageNet |
|---|---|---|---|---|
| SimCLR | 10.4 | 101.9 | 23.3 | 221.6 |
| SimCLR+Ours | 12.7 | 106.2 | 29.7 | 230.7 |
| MAE | 13.8 | 115.5 | 26.9 | 244.9 |
| MAE+Ours | 16.4 | 122.2 | 31.2 | 252.2 |
For the detailed computational complexity, please refer to Response to Weaknesses 1 & Questions 3 of the Rebuttal for Reviewer fDza.
Response to Weaknesses 3:
Thank you for pointing this out. To illustrate without loss of generality, we take SimCLR as a representative SSL method. First, our proposed Algorithm 1 only modifies the mini-batch construction process during the training phase of SimCLR. Even though we train a VAE, it does not affect other components of SimCLR’s training pipeline, including the training objective, network architecture, and optimization algorithm. Second, training a VAE independently on ImageNet and using its feature extractor for evaluation yields an accuracy of 35.44%, in contrast to SimCLR's 70.15%. We then use the parameters of the VAE's feature extractor to initialize the feature encoder of SimCLR and retrain SimCLR from this initialization. The resulting accuracy is 68.71%, which is 1.45% lower than that of SimCLR trained from scratch. In comparison, SimCLR combined with our method achieves an accuracy of 73.32%. These results demonstrate the fairness of our evaluation and confirm that the performance gain is not due to the additional VAE.
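For clarity, a minimal sketch of the initialization ablation described above, using a torchvision ResNet-50 as a stand-in backbone; the checkpoint step and the assumption that the VAE encoder and the SimCLR backbone share the same architecture are illustrative.

```python
from torchvision.models import resnet50

# Backbone trained inside the VAE (weights assumed to come from VAE pre-training).
vae_backbone = resnet50(weights=None)
# (load VAE-pretrained weights into vae_backbone here)

# Fresh SimCLR backbone, initialized from the VAE encoder instead of randomly.
simclr_backbone = resnet50(weights=None)
simclr_backbone.load_state_dict(vae_backbone.state_dict(), strict=False)

# SimCLR is then retrained from this initialization. In the ablation above, this
# setup reaches 68.71% accuracy, below SimCLR trained from scratch (70.15%) and
# well below SimCLR + Ours (73.32%), so the gain is not from the extra VAE parameters.
```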
This work proposes a mini-batch sampling strategy that selects pairs of samples for the mini-batch to enhance the OOD generalization ability of SSL methods. By taking a causal perspective based on the constructed SCM, the method proposes a Post-Intervention Distribution, which can be realized via a balancing score.
Questions For Authors
- Why use an exponential family distribution to model the conditional distribution of the latent variables? Why not use an invertible (reversible) neural network, which may achieve higher expressivity and simplify the design?
Claims And Evidence
I do not find any evident errors in the claims.
Methods And Evaluation Criteria
The method seems convincing: essentially, when the balancing condition holds, there is no change in the distribution of the spurious features, so the SSL method will focus on the invariant features.
Theoretical Claims
The result of Theorem 3.4, that the loss minimizes the worst-case risk, should be correct; this is a well-defined objective in the invariant learning literature. I am not sure whether Theorem 4.3 is correct, as I am not familiar with identifiability theory.
Experimental Designs Or Analyses
This work primarily addresses the OOD generalizability of SSL methods; however, the experiments do not include any OOD datasets, such as Waterbirds and CMNIST. Conducting experiments directly on OOD datasets would help evaluate the effectiveness of the proposed sampling strategy.
Supplementary Material
I read Appendix E and F.
Relation To Broader Scientific Literature
This work relates to the self-supervised learning and domain generalization literature, as well as the causal inference literature.
Essential References Not Discussed
No.
Other Strengths And Weaknesses
- My biggest concern is the experimental setting: it does not involve any OOD dataset for evaluation, even though the main goal of this work is to enhance the OOD generalization ability of SSL methods.
Other Comments Or Suggestions
N/A
Response to Weaknesses 1 & Experimental Designs Or Analyses:
Thank you for pointing this out. We clarify this issue through the following steps:
Step 1: How the original submission constructs the OOD task
The transfer learning task and the few-shot learning task can be regarded as OOD (out-of-distribution) tasks, as the training and test datasets in these tasks follow different data distributions. Meanwhile, in Appendix C.3 of the original submission, we also provide evaluation results on two OOD datasets, namely the Colored-MNIST dataset and the PACS dataset.
Step 2: Results on Waterbirds dataset and CMNIST dataset
For Waterbirds, we follow the implementation in Zare et al. (2023), "Evaluating and Improving Domain Invariance in Contrastive Self-Supervised Learning by Extrapolating the Loss Function". During training, waterbirds (landbirds) predominantly appear on water (land) backgrounds; however, this distribution is altered at test time. We report both the average and worst-group performance. The results in the table below demonstrate that our method yields consistent improvements, particularly enhancing performance for the worst-performing group.
| Method | Test Accuracy | Worst Group |
|---|---|---|
| SimCLR | 76.2 | 19.2 |
| SimCLR+Ours | 78.0 | 24.9 |
| MAE | 74.9 | 17.6 |
| MAE+Ours | 77.2 | 22.1 |
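For reference, the average and worst-group accuracies above can be computed as in the following minimal sketch, assuming per-sample group ids (bird type and background combination) are available.

```python
import numpy as np

def average_and_worst_group_accuracy(preds, labels, groups):
    """Compute overall accuracy and the accuracy of the worst-performing group.

    preds, labels: arrays of shape (n,)
    groups:        array of shape (n,) with a group id per sample,
                   e.g. (bird type, background) combinations in Waterbirds.
    """
    correct = (preds == labels)
    avg_acc = correct.mean()
    worst_acc = min(correct[groups == g].mean() for g in np.unique(groups))
    return avg_acc, worst_acc
```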
For Colored-MNIST, we follow the implementation in Huang et al. (2024) "On the Comparison between Multi-modal and Single-modal Contrastive Learning". The task is a 10-class digit classification, with 10% of the labels randomly reassigned. During training, images belonging to class ‘0’ (or ‘1’) are colored red (or green) with a probability of 77.5%, and another random color with a probability of 22.5%. For the test set, the coloring scheme is reversed relative to training, which allows us to evaluate the extent to which the model relies on color cues for classification. The results in the table below show that our method improves OOD test accuracy by nearly 10%.
| Method | Test Accuracy |
|---|---|
| SimCLR | 12.7 |
| SimCLR+Ours | 23.5 |
| MAE | 15.1 |
| MAE+Ours | 24.9 |
These results further demonstrate that our proposed method effectively improves the OOD generalization performance of SSL.
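For concreteness, a minimal sketch of the coloring protocol described above; the treatment of classes other than '0' and '1' and the omission of the 10% label-noise step are simplifying assumptions here.

```python
import numpy as np

def colorize(img, label, p_correlated=0.775, train=True, rng=None):
    """Color a grayscale digit so that color correlates with the label at train
    time (class 0 -> red, class 1 -> green) and is reversed at test time."""
    rng = rng or np.random.default_rng()
    rgb = np.stack([img, img, img], axis=-1).astype(np.float32)
    if label in (0, 1) and rng.random() < p_correlated:
        channel = label if train else 1 - label   # reversed color-label pairing at test time
    else:
        channel = rng.integers(3)                 # a random color otherwise (assumption)
    mask = np.zeros(3)
    mask[channel] = 1.0
    return rgb * mask
```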
Response to Question 2:
Thank you for pointing this out. According to Assumption 3.3 in the original submission, both the causal and the non-causal latent variables can be obtained through an invertible neural network. Training such an invertible neural network typically requires training data in the form of observation-latent pairs. However, we did not adopt this mechanism directly because we did not have access to such training data; in particular, we were unable to provide the corresponding latent variables for each pair. Therefore, we opted to use the latent variable model approach instead.
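For illustration, modeling the conditional latent distribution with an exponential family member can look like the sketch below, which uses a factorized Gaussian whose parameters are produced by a small network; this is a generic example of the modeling choice, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class GaussianConditionalPrior(nn.Module):
    """p(z | x_plus) as a factorized Gaussian (an exponential family member),
    with mean and log-variance produced by a small network."""
    def __init__(self, feat_dim: int, latent_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(feat_dim, 256), nn.ReLU(),
                                 nn.Linear(256, 2 * latent_dim))

    def forward(self, x_plus_feat):
        mean, log_var = self.net(x_plus_feat).chunk(2, dim=-1)
        return torch.distributions.Normal(mean, torch.exp(0.5 * log_var))

# No (x, z) pairs are needed: such a prior can be trained jointly with a VAE
# encoder/decoder by maximizing a variational lower bound.
```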
This paper explores whether self-supervised learning possesses out-of-distribution (OOD) generalization capabilities and investigates the reasons behind its potential failure. To address this, the authors propose a Post-Intervention Distribution (PID), grounded in the Structural Causal Model. PID enables accurate OOD generalization by disentangling spurious correlations between features and labels. The authors introduce a simple yet seemingly effective mini-batch resampling technique and provide a substantial number of supporting theorems. However, I find the number of theorems somewhat excessive. I suggest that the authors consolidate the most essential ones into key theorems and present them in the manuscript for better clarity and impact.
Questions For Authors
See weaknesses.
Claims And Evidence
yes
Methods And Evaluation Criteria
yes
Theoretical Claims
yes
Experimental Designs Or Analyses
yes
Supplementary Material
Yes, all parts.
Relation To Broader Scientific Literature
This paper explores the relationship between self-supervised learning and OOD generalization.
Essential References Not Discussed
yes
Other Strengths And Weaknesses
Strengths:
- Comprehensive experiments and theoretical justifications.
- The proposed method is simple yet appears to be effective.
Weaknesses:
- The authors should include additional OOD benchmark datasets in the main experiments, such as Colored-MNIST and PACS.
- I find the concept of the anchor mentioned in line 90 somewhat confusing, particularly regarding why it can be directly transformed into the label used in Equation (1).
- How is the balancing score function specifically implemented? Is it learnable? Does it produce a scalar output?
Other Comments Or Suggestions
See weaknesses.
Response to Weaknesses 1:
Thank you for pointing this out. Due to space limitations, we reported the experimental results of Colored-MNIST and PACS in Tables 9 and 10 in Appendix C.3 of the original submission. In the final version, we will move these results to the main body of the paper.
Response to Weaknesses 2:
Thank you for pointing this out. We explain this issue through the following steps:
Step 1: How are augmented data pairs formed in SSL
In D-SSL, each sample x_i in a mini-batch undergoes stochastic data augmentation to generate two augmented views; for x_i, the augmented samples can be represented as x_i^label and x_i^+. In G-SSL, x_i is first divided into multiple small blocks, some blocks are masked, and the remaining blocks are reassembled into a new sample, denoted as x_i^+; the original sample is then referred to as x_i^label. Thus, the augmented dataset in SSL (whether D-SSL or G-SSL) is represented as the collection of pairs (x_i^label, x_i^+), where (x_i^label, x_i^+) forms the i-th pair.
The above statement can be found in the first paragraph of Section 2 in the original submission.
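As a minimal illustration of this pairing, the sketch below uses generic augment / mask_and_reassemble functions as stand-ins for the actual augmentation and masking pipelines (both are assumptions, not the paper's code).

```python
def make_pair(x, mode="d-ssl", augment=None, mask_and_reassemble=None):
    """Build the (x_label, x_plus) pair used in D-SSL or G-SSL.

    augment:             stochastic augmentation function (D-SSL), assumed given
    mask_and_reassemble: masks blocks of x and keeps the rest (G-SSL), assumed given
    """
    if mode == "d-ssl":
        x_label, x_plus = augment(x), augment(x)     # two stochastic views of x
    else:
        x_label, x_plus = x, mask_and_reassemble(x)  # original vs. masked/reassembled
    return x_label, x_plus
```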
Step 2: How is the anchor formed in SSL
The objective of D-SSL methods typically consists of two components: alignment and regularization. The alignment part maximizes the similarity, in the embedding space, between the samples that form the same pair, and the regularization part constrains the learning behavior via inductive bias. It is noteworthy that "alignment" in D-SSL is often implemented with respect to anchor points: viewing one sample in a pair as an anchor, the training process of such SSL methods can be seen as gradually pulling the other sample in this pair (a pair consists of two augmented samples) towards the anchor. Meanwhile, G-SSL can be regarded as implementing alignment of the samples within a pair through an encoding-decoding structure: sample x_i^+ is fed into this structure to generate a sample, which is made as consistent as possible with sample x_i^label. The concept of the anchor is also applicable to G-SSL, where x_i^label is viewed as the anchor, and thus the training process of such SSL methods can be viewed as gradually constraining the reconstruction of x_i^+ to approach x_i^label.
The above statement can be found in the second paragraph of Section 2 in the original submission.
Step 3: How does the anchor become a label in SSL
Based on Step 2, regardless of whether it is G-SSL or D-SSL, the anchor can be regarded as a learning target. Specifically, SSL can be interpreted as follows: In a data augmentation pair, one sample (the anchor) is designated as the target. By constraining the other augmented sample in the feature space to move toward this anchor, consistency in feature representations is achieved. This dynamic adjustment causes samples within the same pair to become tightly clustered, thereby forming an effect similar to a local cluster center.
In traditional classification problems, the common approach is to first project samples into a label space and then constrain them to move toward their corresponding one-hot labels to achieve supervision. In contrast, SSL directly applies constraints in the feature space, which means that the anchor effectively takes on the role of a “label.” In this unsupervised setting, the anchor provides a supervisory signal similar to that of a label. Therefore, it can be argued that labels manifest differently across various spaces—in the feature space, the anchor represents this “implicit label.”
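To make the analogy concrete, below is a minimal sketch contrasting supervision in label space with anchor-based alignment in feature space; cosine similarity is used purely for illustration and is not claimed to be any specific SSL method's loss.

```python
import torch.nn.functional as F

# Supervised learning: pull logits toward a one-hot label in label space.
def supervised_loss(logits, label_idx):
    return F.cross_entropy(logits, label_idx)

# SSL: pull the other view's feature toward the anchor feature in feature space,
# so the anchor plays the role of an "implicit label".
def alignment_loss(feat, anchor_feat):
    return 1.0 - F.cosine_similarity(feat, anchor_feat.detach(), dim=-1).mean()
```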
Response to Weaknesses 3:
Thank you for pointing this out. We clarify this issue through the following steps:
Step 1: How we obtain s for a given pair
s is a vector sampled from the learned conditional distribution of the spurious variable given x^+; in other words, once x^+ is given, the distribution of s is also determined. The learning process of this distribution is described in Section 4.1 of the original submission. For a given pair, e.g., (x_i^label, x_i^+), we sample once from this conditional distribution to obtain s_i.
Step 2: The computation of the balancing score b(s) in Algorithm 1
Based on Equation (5) in the original submission, we compute b(s). Specifically, this computation is performed with respect to the entire dataset: given the full training dataset, the pair (x_i^label, x_i^+), and s_i, the terms involved in Equation (5) are traversed across the entire dataset. Then, according to Definition 4.5, we obtain b(s_i), which is a vector.
Step 3: High-level explanation of Algorithm 1 and the identifiability of the spurious variable
A high-level explanation of Algorithm 1 is provided in Appendix F, while a high-level explanation regarding the identifiability of the spurious variable is provided in Appendix G.
The reviewers generally agree that this paper should be accepted. I concur, albeit on the margins. The reviewers found the theory provided to be solid but raised some questions on both the range of the empirical validation and the generalizability of the proposed method, given notable additional computational needs.
Strengths. The theoretical motivations of the work are robust and comprehensive.
Weaknesses. The proposed method requires some unverifiable assumptions that may not fully hold in practice. The empirical results give some cover for this, but could be expanded. The proposed method also requires notably more computation resources, which introduces questions about tradeoffs between compute and added performance.
The authors provided clarifications, additional OOD dataset experiments, and computational cost estimates. The set of OOD datasets could be expanded further, e.g., to DomainBed/WILDS datasets. However, the strengths of the paper outweigh the weaknesses.