Causal Representation Learning and Inference for Generalizable Cross-Domain Predictions
We propose a causal representation learning framework based on a novel SCM
Abstract
Reviews and Discussion
The authors propose a domain generalization algorithm motivated by causality. They specify two latent variables and whose marginals can shift between training and testing. Their approach aims to become invariant to the aforementioned latent variables by intervening on their common child , which closes the backdoor path between the input and target.
Strengths
The authors tackle the important problem of DG, and their approach shows promising empirical results. The paper is well-written and clearly motivated. Also, their approach is interesting in that unlike many existing DG algorithms, theirs doesn't use the environment labels.
Weaknesses
I found three technical issues with the paper. One is a major issue, and two are minor.
- (Major) The algorithm does not perform its stated purpose of being invariant to shifts in and . The predictive distribution in Eq. (2) is not invariant to and , since it involves an expectation over , which can shift across training and testing.
- (Minor) The posterior is assumed to factorize , which is at odds with the assumed causal graph.
- (Minor) The authors cite Kivva 2022 to claim that their standard normal prior is a one-component Gaussian mixture, and therefore is identifiable (along w/ the piecewise affine decoder assumption). Calling a standard normal distribution a Gaussian mixture is technically true, but this identifiability argument is a bit tenuous.
Questions
Please address my three points above in the "weaknesses" section.
Factorized variational distribution : While we acknowledge that the true distribution cannot be factored under our SCM assumption, it often becomes intractable due to the conditions that only and are observed in the SCM. The distribution is likely to be complex and non-Gaussian. As an alternative solution, we approximate by a variational distribution with Gaussian assumptions. Although using non-factorized Gaussians might slightly enhance results, it requires more time to estimate the large covariance matrix. There is always a balance between accuracy and efficiency. We believe our assumption, commonly used in many VAEs, is reasonable and effective. Moreover, our empirical results indicate that the factorized Gaussian approximation effectively leads to better OOD generalization. We appreciate your suggestion and will include an analysis of using non-factorized Gaussians in our revised paper.
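The accuracy versus efficiency trade-off mentioned above can be made concrete with a rough parameter count. The sketch below assumes that "factorized" means a diagonal-covariance Gaussian and that a non-factorized Gaussian is parameterized by a mean plus a lower-triangular Cholesky factor. Both are illustrative assumptions, and the function name is hypothetical rather than taken from the paper.

```python
def gaussian_posterior_param_counts(latent_dim: int) -> dict:
    """Count the values an encoder must output per input for two Gaussian families.

    Assumption: "factorized" = diagonal covariance; "full" = mean plus a
    lower-triangular Cholesky factor of the covariance matrix.
    """
    mean = latent_dim
    diagonal = mean + latent_dim                       # mean + one variance per dimension
    full = mean + latent_dim * (latent_dim + 1) // 2   # mean + lower-triangular factor
    return {"diagonal": diagonal, "full": full}


for d in (8, 64, 512):
    print(d, gaussian_posterior_param_counts(d))
```

Under these assumptions the diagonal family grows linearly in the latent dimension while the full-covariance family grows quadratically, which is one way to read the efficiency argument.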
Identifiability: We emphasize that our proof of identifiability, leveraging the results of [1], is mathematically correct. However, it is important to note that this proof is not our main contribution. Our focus is on designing conditional distributions in our SCM to ensure that the learned latent representations are identifiable. While Kivva et al. have proven identifiability in a variety of cases, our work requires only a specific instance to effectively learn our SCM. We model as a standard Gaussian instead of a mixture of Gaussians, chosen for its simplicity and because we lack additional prior information about . This assumption aligns with the prior distributions commonly used in many VAEs.
[1] Kivva, Bohdan, et al. "Identifiability of deep generative models without auxiliary information." Advances in Neural Information Processing Systems 35 (2022): 15687-15701.
Invariance of : To address the reviewer’s concerns, we will first emphasize and rectify the assumptions we make regarding the latent variables . Then we will explain that under these assumptions, is invariant and transportable across domains. Finally, we will explain how to obtain values for calculating and why it is a reasonable and good choice we have to approximate an invariant transformation between and .
- Assumptions: Upon careful review, we acknowledge that our assertion regarding the variation of was overstated. The confounding effects must be consistent between the source and target domains for the proposed interventional distribution to be invariant and transportable. Hence we revise our initial assumption to state that the confounding effects, encompassing and the causal mechanisms between , , and , remain consistent across training and test domains. Moreover, we assume the distribution of varies across domains. This induces variation in , and in turn in . However, we emphasize that the generative mechanism remains invariant across domains ( is independent of given ). Otherwise, it is impossible to infer from any unseen domain.
- Invariance of : Our proposed interventional distribution, by setting specific values for to mitigate the influence of the domain-specific variable on , effectively addresses the invariant confounding effects and prevents the from influencing . Hence, it is invariant and transportable across domains.
- The choice of : Under our corrected assumption, the latent confounding effect remains invariant across domains. To infer the label for a test input using our interventional distribution, it is imperative to provide the true value of for to accurately account for the confounding effects between and . The distribution utilized in Eq. (3) can be construed as a proxy distribution that enables us to derive the true values of from an input . The challenge lies in determining which learned distribution can be employed to approximate this proxy distribution. According to the SCM, we observe that is independent of given , rendering invariant and transportable. Therefore, a reasonable and good choice is to obtain the identifiable from the learned variational distribution with a high . However, as represents a variational estimation of the desired proxy distribution, we average the interventional distribution over multiple samples of from . In practice, randomly obtaining a value is likely to yield a low and contribute minimally to the calculation of the interventional distribution.
Nevertheless, it's crucial to underscore that with the corrected assumption, our CIIRL remains innovative and has proven its efficacy across various benchmark distribution shift datasets: 1) With the revised assumptions, our proposed interventional distribution maintains invariance and transportability, facilitating cross-domain inference. 2) Our training procedure does not explicitly rely on the variation of . The distinct classes of can be associated with the diverse domain variable . Our training procedure can still effectively distinguish the two types of representations. 3) Empirical results affirm the existence of invariant latent confounders, as evidenced by the overall superior OOD prediction performance of CIIRL compared to predictions solely based on .
This paper aims to solve the problem of out-of-distribution classification using a causal approach. In the problem setting, the features are caused by causal latent variables and spurious latent variables and are correlated with labels through both sets of latent variables. A typical classifier predicts , using the correlation through both and . However, under distribution shift, the distribution of unobserved variables affecting are changed, so using the spurious latent variables for classification can result in incorrect predictions out-of-distribution. Instead, the paper proposes using for classification, which severs the correlation between and through via a causal intervention, thus providing a quantity that is invariant across domains. Estimating this quantity requires learning encoders which map to and , a decoder which maps and back to , and a classifier . This is done by optimizing over a variational bound on the log-likelihood of the data. After training, predictions are obtained by computing a linear combination of predictions from weighted by a value indicating the compatibility of with (using Monte Carlo sampling to estimate expectations). Experiments demonstrate the effectiveness of the approach.
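The inference step summarized above (a compatibility-weighted combination of classifier predictions, estimated by Monte Carlo sampling) could look roughly like the following NumPy sketch. The function name, the use of log-compatibility scores, and the normalization of the weights are all assumptions made for illustration. The paper's Eq. (3) may weight the samples differently.

```python
import numpy as np


def mc_weighted_prediction(class_probs, log_compat):
    """Combine per-sample class probabilities using compatibility weights.

    class_probs: array (n_samples, n_classes); row i is the classifier output
                 for the i-th sampled latent value (hypothetical interface).
    log_compat:  array (n_samples,); log-score of how compatible the i-th
                 sample is with the input (assumed to be available).
    """
    class_probs = np.asarray(class_probs, dtype=float)
    log_compat = np.asarray(log_compat, dtype=float)
    # Subtract the maximum before exponentiating for numerical stability.
    weights = np.exp(log_compat - log_compat.max())
    weights /= weights.sum()          # assumption: weights are normalized
    return weights @ class_probs      # weighted average over samples


probs = [[0.9, 0.1], [0.2, 0.8]]
print(mc_weighted_prediction(probs, [0.0, 0.0]))      # equal weights
print(mc_weighted_prediction(probs, [0.0, -1000.0]))  # first sample dominates
```

With equal scores the result is the plain average of the rows. A sample with a very low compatibility score contributes almost nothing, which matches the authors' remark that randomly drawn values contribute minimally.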
Strengths
This paper offers a novel take on leveraging causality to solve out-of-distribution classification. To my knowledge, there are no works which consider modeling the problem as done in Fig. 1, where is used as the classifier. The problem setup has interesting implications in terms of the ways that features and label are related. The experimental results also show promise that the approach is effective in practice.
Weaknesses
I am concerned about the soundness of some of the claims:
- The path from to is not influenced by any intervention on . Hence, if , it should also be the case that . This seems to contradict what is stated at the end of Sec. 3.1.
- It is not clear how calculating the expectation of over (as done in Eq. 2) is considered marginalizing out . It is also not clear why this is preferable to just choosing some arbitrary to intervene.
- How are and modeled in Eq. 3 if they are unobserved and change between source and target?
- What justifies that the learned representations and truly follow the causal diagram in Fig. 1? Given the generative process of learning these representations (i.e. through and ), it could be argued that and are caused by rather than the other way around. Further, it is difficult to believe that a learned representation can contain more information about than , but this is what is implied by the graph (i.e. and are independent given and but is not independent of and given ?).
In addition, there are a few points that could use more elaboration:
- At the beginning of Sec. 3.1, it is explained that the consideration of and addresses two types of biases: selection bias and stereotype bias. This seems to be an interesting point and could be expanded.
- Under Alg. 1, the paper mentions the necessity of assumptions to compensate for the lack of observations of and . These should be explicitly stated, as this seems to be the crux of the reasoning behind why the model works. Further, are some of these assumptions only relevant to certain types of data (e.g. images)?
I cannot recommend acceptance while I have these doubts, but I look forward to having them clarified in the authors’ responses.
Questions
See weaknesses.
The SCM assumption: We would like to emphasize that the causal mechanisms in the proposed SCM in Figure 1 are all assumptions that are widely adopted in the area of causal representation learning [1, 2], including the three following points:
- The latent high-level factors can be separated into causal factors and spurious factors .
- The input is generated by the high-level factors .
- Causal factor is either a direct cause or an effect of target .
We cannot prove these assumptions always hold in real-world data and we admit the effectiveness and soundness of our derived theorem and algorithm is built upon these assumptions. We believe the learned representation and are the factors that satisfy the causal graph by
- parameterizing the joint distribution regarding all the variables of interest into conditional distributions adhering to the causal mechanisms in Figure 1;
- establishing the identifiability of the learned .
We appreciate the suggestion of further elaboration of these two types of biases and will revise the paper accordingly.
We acknowledge the recommendation to underscore the importance of assumptions and will incorporate this emphasis into our paper accordingly. It's worth noting that our method is not confined solely to image data. The validity of our theorem and the effectiveness of the algorithm persist as long as the assumptions we posit apply to the given data. As an illustration, our approach can be extended to encompass text data, demonstrating the versatility of our framework.
Invariance of : Upon careful review, we acknowledge that our assertion regarding the variation of was overstated. Our proposed interventional distribution, by setting specific values for to mitigate the influence of the domain-specific variable on , effectively accounts for the invariant confounding effects and prevents the from influencing . Hence, it is invariant and transportable across domains. However, it is important to note that this framework may not generalize to new domains with different and unknown confounding effects.
In light of this, we revise our initial assumption to state that the confounding effects, encompassing and the causal mechanisms between , , and , remain consistent across training and test domains. This invariant confounding assumption aligns with standard practices widely adopted in works utilizing interventional distributions [1]. We will incorporate this revision into the paper.
Nevertheless, it's crucial to underscore that with the corrected assumption, our CIIRL remains innovative and has proven its efficacy across various benchmark distribution shift datasets:
- With the revised assumptions, our proposed interventional distribution maintains invariance and transportability, facilitating cross-domain inference.
- Our training procedure does not explicitly rely on the variation of . The distinct classes of can be associated with the diverse domain variable . Our training procedure can still effectively distinguish the two types of representations.
- Empirical results affirm the existence of invariant latent confounders, as evidenced by the overall superior OOD prediction performance of CIIRL compared to predictions solely based on .
[1] Mao, Chengzhi, et al. "Causal transportability for visual recognition." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
The choice of : Under our corrected assumption, the latent confounding effect remains invariant across domains. To infer the label for a test input using our interventional distribution, it is imperative to provide the true value of for to accurately account for the confounding effects between and . The distribution utilized in Eq. (3) can be construed as a proxy distribution that enables us to derive the true values of from an input . The challenge lies in determining which learned distribution can be employed to approximate this proxy distribution. According to the SCM, we observe that is independent of given , rendering invariant and transportable. Therefore, a reasonable and good choice is to obtain the identifiable from the learned variational distribution with a high . However, as represents a variational estimation of the desired proxy distribution, we average the interventional distribution over multiple samples of from . In practice, randomly obtaining a value is likely to yield a low and contribute minimally to the calculation of the interventional distribution.
The modeling of and : Firstly, we revise our assumption by rectifying the notion that the distribution of remains invariant across domains. In our context, denotes any information specific to the domain, with a common simplification being to assume that represents the domain index [2]. During the training procedure, we utilize a clustering algorithm to estimate the domain index for each training input. The objective of the training process is to acquire the encoder distributions that produce a disentangled representation, which is identifiable and possesses an invariant generative mechanism . The interventional distribution accommodates confounding effects and mitigates the influence of the domain variable . Consequently, we can make inferences from the interventional distribution without knowledge of and for the target domain.
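The response states that a clustering algorithm estimates the domain index for each training input but does not name the algorithm. As a minimal sketch, assuming plain k-means on some feature vectors, the step might look like the code below. The choice of k-means, the function name, and the feature input are all hypothetical stand-ins rather than the authors' procedure.

```python
import numpy as np


def estimate_domain_index(features, n_domains, n_iters=50, seed=0):
    """Assign a pseudo domain index to each row of `features` with k-means.

    Assumption: k-means is used here only as an example clustering method;
    the response does not specify which algorithm the authors use.
    """
    rng = np.random.default_rng(seed)
    features = np.asarray(features, dtype=float)
    # Initialize the centers from distinct training points.
    centers = features[rng.choice(len(features), size=n_domains, replace=False)]
    for _ in range(n_iters):
        # Distance from every point to every center, shape (n_points, n_domains).
        dists = np.linalg.norm(features[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Recompute each center; keep the old one if a cluster is empty.
        new_centers = np.array([
            features[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_domains)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


toy = [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]]
print(estimate_domain_index(toy, n_domains=2))
```

The returned labels are arbitrary up to permutation, so downstream code should treat them as an unordered domain index.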
[2] Lu, Chaochao, et al. "Invariant causal representation learning for out-of-distribution generalization." International Conference on Learning Representations. 2021.
The work proposes a causal representation learning procedure for domain generalization given data from a single domain. An invariance relation is derived based on interventions on the spurious representation. The proposed procedure aims to identify the latent causal and spurious representations and then make predictions according to the invariance relation.
Strengths
- The representation learning procedure is novel and interesting, especially the interventions on .
- The method outperforms the baselines by a large margin on the CMNIST dataset.
Weaknesses
- The latent confounder is assumed to be discrete, which is restrictive. The dependency between and can be more complicated in general.
- The identifiability of the representation is a crucial result. From the discussion in Section 4.1, the identifiability results are not trivial. I think they should be written in a formal statement and proved rigorously.
- A claim is that is invariant across different distributions due to the removed arrows and . However, there is still an arrow , meaning that the marginal distribution of can change across different distributions. As a result, is not invariant in general.
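The last concern can be illustrated with a toy adjustment formula. The sketch below uses generic binary variables named Z (the intervened variable) and U (the variable that is averaged over). These names and all numbers are hypothetical and are not taken from the paper. The sketch only shows that an interventional quantity of the form sum over u of P(y | z, u) P(u) changes when the marginal P(u) shifts, even if the mechanism P(y | z, u) is fixed.

```python
import numpy as np

# Fixed mechanism shared by all domains (hypothetical numbers):
# p_y_given_zu[z, u] = P(Y=1 | Z=z, U=u)
p_y_given_zu = np.array([[0.9, 0.2],
                         [0.6, 0.1]])


def adjusted_prediction(p_u):
    """Return P(Y=1 | do(Z=z)) = sum_u P(Y=1 | z, u) P(u) for z = 0, 1."""
    return p_y_given_zu @ np.asarray(p_u, dtype=float)


train = adjusted_prediction([0.5, 0.5])  # marginal of U in the training domain
test = adjusted_prediction([0.9, 0.1])   # shifted marginal of U in the test domain
print(train)  # [0.55 0.35]
print(test)   # [0.83 0.55]
```

In this toy case the adjusted quantity is invariant only if the marginal of the averaged variable is the same in both domains, which is the assumption the authors adopt in their revised response.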
Questions
- Can the assumption of a discrete be relaxed? What are the consequences of a large ?
- Does the confounder make the invariance fail as mentioned above?
I may raise my score depending on the response. If the invariance indeed fails, I would recommend rejection.
Assumptions of : We posit and as random variables representing domain-specific and confounding information, respectively. We adhere to the standard practice of simplifying into a domain index. Importantly, we do not constrain and to be exclusively discrete or continuous. During the training process, a neural network is employed to take in and and generate parameters, including the means and variances, for the prior distributions of . The neural network accommodates inputs of any type. However, the training procedure disentangles and by employing asymmetric prior distributions. As the dimensionality increases, optimization becomes more challenging due to potential inaccuracies in estimating domain indices, a growing number of parameters in the prior distribution, and limited improvements (if any) in disentanglement.
Invariance of : Upon careful review, we acknowledge that our assertion regarding the variation of was overstated. Our proposed interventional distribution, by setting specific values for to mitigate the influence of the domain-specific variable on , effectively accounts for the invariant confounding effects and prevents the from influencing . However, it is important to note that this framework may not generalize to new domains with different and unknown confounding effects.
In light of this, we revise our initial assumption to state that the confounding effects, encompassing and the causal mechanisms between , , and , remain consistent across training and test domains. This invariant confounding assumption aligns with standard practices widely adopted in works utilizing interventional distributions [1]. We will incorporate this revision into the paper.
Nevertheless, it's crucial to underscore that with the corrected assumption, our CIIRL remains innovative and has proven its efficacy across various benchmark distribution shift datasets:
- With the revised assumptions, our proposed interventional distribution maintains invariance and transportability, facilitating cross-domain inference.
- Our training procedure does not explicitly rely on the variation of . The distinct classes of can be associated with the diverse domain variable . Our training procedure can still effectively distinguish the two types of representations.
- Empirical results affirm the existence of invariant latent confounders, as evidenced by the overall superior OOD prediction performance of CIIRL compared to predictions solely based on .
[1] Mao, Chengzhi, et al. "Causal transportability for visual recognition." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
Identifiability: We emphasize that our proof of identifiability, leveraging the results of [2], is mathematically correct. However, it is important to note that this proof is not our main contribution. That said, we appreciate the reviewer's suggestion and will consider adding a formal statement with a rigorous proof of identifiability.
[2] Kivva, Bohdan, et al. "Identifiability of deep generative models without auxiliary information." Advances in Neural Information Processing Systems 35 (2022): 15687-15701.
In this paper, the authors investigate the problem of domain generalization, where the target domain datasets are unobserved during training. To solve this problem, the authors propose a structural causal model with latent variables to model the causal mechanism. Subsequently, the authors intervene on the spurious representations to remove the spurious correlations and learn the invariant interventional distribution. The authors evaluate the proposed methods on several datasets and achieve strong performance.
Strengths
- The authors leverage the causal knowledge to address the domain generalization problem.
- The authors evaluate the proposed methods on several datasets.
Weaknesses
- One important issue is the ambiguity about the types of the variables in Figure 1. In the domain generalization task, the domain labels are usually observed. However, it is unclear whether and are observed variables.
- Moreover, the authors mentioned that according to Figure 2(b). But if is influenced by different domains, the aforementioned equation is not true.
- The proposed causal generation process is similar to that of [1]; the authors should discuss how the two relate. Moreover, it seems to be impossible to conduct do-calculus on the latent variables without identification guarantees of the latent variables.
[1] Kong, Lingjing, et al. "Partial disentanglement for domain adaptation." Proceedings of the 39th International Conference on Machine Learning, PMLR 162:11455-11472, 2022.
Questions
N.A.
Latent and : Our approach is specifically designed to enhance out-of-distribution (OOD) prediction in situations where the domain variable and confounder are not known. While domain indices are provided for specific tasks/datasets like PACS and VLCS, acquiring them for general real-world tasks poses significant challenges.
Invariance of : Upon careful review, we acknowledge that our assertion regarding the variation of was overstated. Our proposed interventional distribution, by setting specific values for to mitigate the influence of the domain-specific variable on , effectively accounts for the invariant confounding effects and prevents the from influencing . Hence, it is invariant and transportable across domains. However, it is important to note that this framework may not generalize to new domains with different and unknown confounding effects.
In light of this, we revise our initial assumption to state that the confounding effects, encompassing and the causal mechanisms between , , and , remain consistent across training and test domains. This invariant confounding assumption aligns with standard practices widely adopted in works utilizing interventional distributions [1]. We will incorporate this revision into the paper.
Nevertheless, it's crucial to underscore that with the corrected assumption, our CIIRL remains innovative and has proven its efficacy across various benchmark distribution shift datasets: 1) With the revised assumptions, our proposed interventional distribution maintains invariance and transportability, facilitating cross-domain inference. 2) Our training procedure does not explicitly rely on the variation of . The distinct classes of can be associated with the diverse domain variable . Our training procedure can still effectively distinguish the two types of representations. 3) Empirical results affirm the existence of invariant latent confounders, as evidenced by the overall superior OOD prediction performance of CIIRL compared to predictions solely based on .
[1] Mao, Chengzhi, et al. "Causal transportability for visual recognition." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
Comparison to [2]: We appreciate the reviewer's recommendation of this relevant paper. The causal graph for its data generation process closely resembles our SCM. Specifically, our causal graph shares similarities with the one presented in [2] in the following aspects:
- Both separate the latent representation into an (invariant) causal representation and a (variant) spurious representation . The input is generated by both types of representation.
- The spurious representation varies from domain to domain since it is controlled by a domain-specific variable .
- There exists a high-level invariant confounder between and ( in their graph).
However, the graph in [2] assumes that the causal features are parents of the target , whereas we use the child variables.
Algorithmically, both our method and iMSDA from [2] address the confounding issue. iMSDA utilizes domain index information to estimate the confounder, whereas we focus on constructing the interventional distribution. Furthermore, iMSDA is tailored for domain adaptation tasks and necessitates access to test domain data during training. Given these distinctions, a direct comparison between our method and iMSDA would not be equitable.
[2] Kong, Lingjing, et al. "Partial disentanglement for domain adaptation." Proceedings of the 39th International Conference on Machine Learning, PMLR 162:11455-11472, 2022.
Identifiability of representation : We strongly agree that establishing the identifiability of the latent variables is crucial for performing do-calculus on them. Therefore, we leverage the theoretical results presented by Kivva et al. (2022) [3] to demonstrate the identifiability of the obtained through our learning framework. Please refer to the paragraph below the algorithm in Section 3.2 for more details.
[3] Kivva, Bohdan, et al. "Identifiability of deep generative models without auxiliary information." Advances in Neural Information Processing Systems 35 (2022): 15687-15701.
The paper deals with domain generalization, where the target domain datasets are unobserved. This work intervenes on spurious representations to remove correlations and learn an invariant distribution. Their method performs well across various datasets, essentially focusing on learning causal and spurious representations to guide predictions based on an invariance relation.
All four reviews lean toward rejection, with ratings of 3, 3, 5, and 5 and confidences of 4, 4, 3, and 3, respectively. The issues raised by the reviewers are critical, including theoretical plausibility (RUwh, kccw, Rf1b) and clarity (ni4b).
Why not a higher score
The issues of theoretical plausibility and clarity.
Why not a lower score
N/A
Reject