Theoretical Background
This page provides the theoretical foundations and mathematical background of Neural Causal Regularization.
Problem Formulation
Consider a prediction problem where we observe data from multiple environments \(e \in \mathcal{E}\). In each environment, we have covariates \(X \in \mathbb{R}^p\) and a target variable \(Y \in \mathbb{R}\). The joint distribution \(P^e(X, Y)\) may vary across environments due to distribution shifts.
Our goal is to learn a predictor \(f: \mathbb{R}^p \rightarrow \mathbb{R}\) that performs well across all environments, including unseen ones. This requires identifying stable, invariant relationships that hold across environments.
Structural Causal Models
Neural Causal Regularization is based on the framework of Structural Causal Models (SCMs). An SCM consists of a set of variables and a set of structural equations that describe how each variable is determined by other variables in the system.
In the context of our prediction problem, we can represent the data-generating process as:
\[
X = g(A, \varepsilon_X), \qquad Y = f^*(X_S) + \varepsilon_Y
\]
where:
- \(A\) is a shift variable that affects the distribution of \(X\) across environments
- \(X_S\) is a subset of \(X\) containing the causal parents of \(Y\)
- \(f^*\) is the true causal function
- \(\varepsilon_X\) and \(\varepsilon_Y\) are noise terms
The key insight is that the relationship between \(X_S\) and \(Y\) remains invariant across environments, even as the distribution of \(X\) changes.
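To make this concrete, here is a toy NumPy sketch of such a data-generating process (the functional forms, coefficients, and two-feature layout are illustrative assumptions, not part of NCR): the causal feature drives \(Y\) with the same coefficient in every environment, while the non-causal feature's coupling to the rest of the system changes with the shift variable.

```python
import numpy as np

def sample_environment(shift, n=1000, seed=0):
    """Draw one environment from a toy version of the SCM above.

    The shift variable A (here the scalar `shift`) changes the
    distribution of X, but the X_S -> Y mechanism is fixed.
    """
    rng = np.random.default_rng(seed)
    x_s = shift + rng.normal(0.0, 1.0, n)         # causal feature: mean depends on A
    y = 2.0 * x_s + rng.normal(0.0, 0.5, n)       # Y = f*(X_S) + eps_Y, invariant
    x_ns = shift * x_s + rng.normal(0.0, 0.5, n)  # non-causal feature: coupling depends on A
    return np.column_stack([x_s, x_ns]), y

# Two environments that differ in their shift strength.
envs = [sample_environment(s, seed=i) for i, s in enumerate((0.5, 2.0))]
```

A predictor that leans on `x_ns` will see its risk change from one environment to the next, whereas a predictor using only `x_s` will not.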
Invariant Risk Minimization
Invariant Risk Minimization (IRM) is a framework for learning predictors that perform well across environments. The core idea is to find a data representation such that the optimal predictor on top of this representation is the same across all environments.
Formally, IRM aims to solve:
\[
\min_{\Phi, w} \sum_{e \in \mathcal{E}} R^e(w \circ \Phi)
\quad \text{subject to} \quad
w \in \arg\min_{w'} R^e(w' \circ \Phi) \;\; \forall e \in \mathcal{E}
\]
where \(R^e(f)\) is the risk of predictor \(f\) in environment \(e\), \(\Phi\) is the learned representation, and \(w\) is the predictor applied on top of it.
This is a challenging bi-level optimization problem. In practice, it is relaxed to a per-environment penalty term (IRMv1) that drives the gradient of the risk, taken with respect to a fixed "dummy" classifier, toward zero, so that the shared predictor is simultaneously near-optimal in every environment.
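For reference, the IRMv1 relaxation of Arjovsky et al. (2019) can be sketched in PyTorch as follows; the penalty is computed once per environment and summed into the training loss:

```python
import torch

def irmv1_penalty(outputs, targets, loss_fn):
    """Squared norm of the risk gradient w.r.t. a dummy classifier fixed at 1.0.

    A penalty near zero means the shared predictor is already stationary
    (optimal) for this environment's risk.
    """
    scale = torch.ones(1, device=outputs.device, requires_grad=True)
    risk = loss_fn(outputs * scale, targets)
    grad, = torch.autograd.grad(risk, [scale], create_graph=True)
    return grad.pow(2).sum()
```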
Neural Causal Regularization
Neural Causal Regularization (NCR) extends these ideas to deep neural networks. Instead of enforcing invariance through a bi-level optimization, NCR penalizes the variance of prediction risks across environments.
The NCR objective function is:
\[
\min_{f} \; \bar{R}(f) + \lambda \, V(f)
\]
where:
- \(\bar{R}(f) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} R^e(f)\) is the average risk across environments
- \(V(f) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} (R^e(f) - \bar{R}(f))^2\) is the variance of risks across environments
- \(\lambda\) is a regularization parameter that controls the trade-off between average performance and invariance
By minimizing the variance of risks, NCR encourages the model to focus on features that lead to consistent performance across environments, which typically correspond to causal relationships.
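A minimal PyTorch sketch of this objective, assuming each environment contributes one batch of `(x, y)` tensors and `loss_fn` returns a scalar risk (these calling conventions are illustrative assumptions):

```python
import torch

def ncr_loss(model, env_batches, loss_fn, lam):
    """NCR objective: mean per-environment risk plus lam times the
    variance of those risks (1/|E| normalization, as in the text)."""
    risks = torch.stack([loss_fn(model(x), y) for x, y in env_batches])
    mean_risk = risks.mean()
    variance = ((risks - mean_risk) ** 2).mean()
    return mean_risk + lam * variance
```

Because each \(R^e\) enters the loss as a differentiable scalar, the objective can be minimized end-to-end with standard stochastic gradient methods.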
Theoretical Guarantees
Out-of-Distribution Generalization
Under certain conditions, NCR provides guarantees on out-of-distribution generalization. Specifically, if the data is generated according to a structural causal model where the relationship between \(X_S\) and \(Y\) is invariant across environments, and if the model class is expressive enough to capture this relationship, then NCR will recover a predictor that generalizes to new environments.
Theorem 1 (Informal): Let \(f_{\lambda}\) be the solution to the NCR objective with regularization parameter \(\lambda\). As \(\lambda \rightarrow \infty\), \(f_{\lambda}\) converges to a predictor that depends only on the causal features \(X_S\) and achieves the optimal invariant risk.
Causal Feature Identification
NCR can also be used to identify causal features. By examining the feature importance of a model trained with NCR, we can distinguish between causal and non-causal features.
Theorem 2 (Informal): Under suitable conditions, as \(\lambda \rightarrow \infty\), the feature importance of non-causal features in a model trained with NCR approaches zero.
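As a sketch, one simple importance proxy is the mean input-gradient magnitude (an assumed choice for illustration; permutation importance or similar measures would serve equally well). Per Theorem 2, the scores of non-causal features should shrink toward zero for large \(\lambda\):

```python
import torch

def input_gradient_importance(model, x):
    """Mean absolute gradient of the model output w.r.t. each input feature."""
    x = x.detach().clone().requires_grad_(True)
    model(x).sum().backward()          # gradients of all outputs w.r.t. inputs
    return x.grad.abs().mean(dim=0)    # one importance score per feature
```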
Finite-Sample Bounds
In practice, we only have finite samples from each environment. NCR provides finite-sample bounds on the out-of-distribution generalization error.
Theorem 3 (Informal): With probability at least \(1-\delta\), the out-of-distribution generalization error of a model trained with NCR is bounded by \(O(\sqrt{\frac{\log(1/\delta)}{n}})\), where \(n\) is the minimum number of samples across environments.
Connections to Other Methods
Invariant Risk Minimization (IRM)
NCR is closely related to IRM but uses a different approach to enforce invariance. While IRM uses a bi-level optimization or gradient penalties, NCR directly penalizes the variance of risks across environments.
Risk Extrapolation (REx)
Risk Extrapolation (REx) is another approach that penalizes disparities among training risks across environments. NCR's penalty coincides with the V-REx variant, which uses the variance of the risks as the penalty.
Anchor Regression
Anchor regression is a method for causal inference that uses a similar objective function to NCR but is formulated in terms of linear models. NCR extends these ideas to nonlinear models using neural networks.
Practical Considerations
Choosing the Regularization Parameter
The regularization parameter \(\lambda\) controls the trade-off between average performance and invariance. A larger \(\lambda\) puts more emphasis on invariance, which can lead to better out-of-distribution generalization but potentially worse in-distribution performance.
In practice, \(\lambda\) can be chosen by validating on data from an environment held out from training, so that the selection criterion itself reflects out-of-distribution performance.
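A sketch of this procedure using leave-one-environment-out validation; `train_fn` and `eval_fn` are hypothetical placeholders for your training routine and risk estimator:

```python
import numpy as np

def select_lambda(env_data, lambdas, train_fn, eval_fn):
    """Leave-one-environment-out selection of the regularization parameter.

    train_fn(envs, lam) -> fitted model; eval_fn(model, env) -> scalar risk.
    """
    mean_risks = []
    for lam in lambdas:
        held_out_risks = []
        for i in range(len(env_data)):
            train_envs = env_data[:i] + env_data[i + 1:]
            model = train_fn(train_envs, lam)
            held_out_risks.append(eval_fn(model, env_data[i]))
        mean_risks.append(np.mean(held_out_risks))
    return lambdas[int(np.argmin(mean_risks))]
```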
Model Architecture
NCR can be applied to any neural network architecture. However, the choice of architecture can affect the model's ability to capture invariant relationships. In general, deeper networks with sufficient capacity are recommended.
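For illustration, a small fully connected regressor plugged into the variance-penalized objective from above (the architecture, widths, and learning rate are arbitrary choices, not recommendations):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),
                      nn.Linear(64, 64), nn.ReLU(),
                      nn.Linear(64, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(env_batches, lam):
    """One gradient step on the NCR objective over per-environment batches."""
    risks = torch.stack([F.mse_loss(model(x).squeeze(-1), y)
                         for x, y in env_batches])
    mean_risk = risks.mean()
    loss = mean_risk + lam * ((risks - mean_risk) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```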
Number of Environments
NCR requires data from multiple environments to identify invariant relationships. In general, having more diverse environments can lead to better identification of causal features. However, even with just two environments, NCR can provide benefits over standard empirical risk minimization.
Limitations
While NCR provides a powerful framework for learning invariant predictors, it has some limitations:
- Environment Diversity: NCR relies on having data from diverse environments that exhibit different distributions of non-causal features. If the environments are too similar, it may be difficult to distinguish between causal and non-causal features.
- Model Expressivity: NCR assumes that the model class is expressive enough to capture the true causal relationship. If this is not the case, NCR may not find the optimal invariant predictor.
- Optimization Challenges: The NCR objective can be challenging to optimize, especially with high-dimensional data and complex neural network architectures.
- Causal Sufficiency: NCR assumes that all relevant causal variables are observed. If there are hidden confounders, NCR may not identify the true causal features.
Further Reading
For more details on the theoretical foundations of Neural Causal Regularization, we recommend the following resources:
- Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization.
- Krueger, D., Caballero, E., Jacobsen, J. H., Zhang, A., Binas, J., Zhang, D., ... & Courville, A. (2021). Out-of-Distribution Generalization via Risk Extrapolation (REx).
- Rothenhäusler, D., Meinshausen, N., Bühlmann, P., & Peters, J. (2021). Anchor Regression: Heterogeneous Data Meet Causality.
- Peters, J., Bühlmann, P., & Meinshausen, N. (2016). Causal inference using invariant prediction: identification and confidence intervals.