Theoretical Background

This page provides the theoretical foundations and mathematical background of Neural Causal Regularization.

Problem Formulation

Consider a prediction problem where we observe data from multiple environments \(e \in \mathcal{E}\). In each environment, we have covariates \(X \in \mathbb{R}^p\) and a target variable \(Y \in \mathbb{R}\). The joint distribution \(P^e(X, Y)\) may vary across environments due to distribution shifts.

Our goal is to learn a predictor \(f: \mathbb{R}^p \rightarrow \mathbb{R}\) that performs well across all environments, including unseen ones. This requires identifying stable, invariant relationships that hold across environments.
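Concretely, writing the risk in environment \(e\) with the squared error (the loss assumed throughout this page for illustration; other losses can be used) as

R^e(f) = \mathbb{E}_{(X, Y) \sim P^e}\left[ (f(X) - Y)^2 \right]

the aim is a predictor whose risk stays low not only in the observed environments but also under the distribution shifts that new environments induce.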

Structural Causal Models

Neural Causal Regularization is based on the framework of Structural Causal Models (SCMs). An SCM consists of a set of variables and a set of structural equations that describe how each variable is determined by other variables in the system.

In the context of our prediction problem, we can represent the data-generating process as:

X = g(A, \varepsilon_X)
Y = f^*(X_S) + \varepsilon_Y

where:

- \(X_S\) denotes the subset of covariates that are causal parents of \(Y\),
- \(f^*\) is the invariant structural function linking \(X_S\) to \(Y\),
- \(A\) represents environment-specific (anchor) variables through which the distribution of \(X\) varies,
- \(\varepsilon_X\) and \(\varepsilon_Y\) are exogenous noise terms.

The key insight is that the relationship between \(X_S\) and \(Y\) remains invariant across environments, even as the distribution of \(X\) changes.
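As a concrete illustration, the following sketch simulates two environments in which the mechanism from the causal feature to \(Y\) is fixed while a spurious feature's relation to \(Y\) changes. The function name and coefficients are illustrative, not part of the package.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_environment(shift, n=1000):
    """Simulate one environment: the mechanism from the causal feature to Y is fixed,
    while the spurious feature's coupling to Y changes with the environment."""
    x_s = rng.normal(0.0, 1.0, n)               # causal feature X_S
    y = 2.0 * x_s + rng.normal(0.0, 0.1, n)     # Y = f*(X_S) + eps_Y, identical in every environment
    x_n = shift * y + rng.normal(0.0, 1.0, n)   # spurious feature whose link to Y varies
    return np.column_stack([x_s, x_n]), y

# Two environments that differ only in how the spurious feature is generated.
X_e1, y_e1 = sample_environment(shift=1.0)
X_e2, y_e2 = sample_environment(shift=-1.0)
```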

Invariant Risk Minimization

Invariant Risk Minimization (IRM) is a framework for learning predictors that perform well across environments. The core idea is to find a data representation such that the optimal predictor on top of this representation is the same across all environments.

Formally, IRM aims to solve:

min_{\Phi, w} \sum_{e \in \mathcal{E}} R^e(w \circ \Phi)
subject to w \in \arg\min_{w'} R^e(w' \circ \Phi) \quad \forall e \in \mathcal{E}

where \(R^e(\cdot)\) is the risk in environment \(e\), \(\Phi\) is the learned data representation, and \(w\) is the predictor applied on top of it.

This is a challenging bi-level optimization problem. In practice, it is often relaxed to a single-level objective with a penalty term (as in IRMv1) that drives the gradient of each environment's risk, taken with respect to a fixed "dummy" classifier, toward zero.
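A minimal sketch of that relaxed penalty for a regression model, assuming PyTorch and squared-error risk; the names here are illustrative rather than the package's API.

```python
import torch
import torch.nn.functional as F

def irmv1_penalty(preds, targets):
    """Squared norm of the gradient of the environment risk with respect to
    a fixed scalar 'dummy' classifier multiplying the predictions."""
    dummy = torch.tensor(1.0, requires_grad=True)
    risk = F.mse_loss(preds * dummy, targets)
    grad, = torch.autograd.grad(risk, dummy, create_graph=True)
    return grad.pow(2)

# Sketch of the relaxed objective summed over environments:
# total = sum(F.mse_loss(model(x_e), y_e) + lam * irmv1_penalty(model(x_e), y_e)
#             for x_e, y_e in env_batches)
```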

Neural Causal Regularization

Neural Causal Regularization (NCR) extends these ideas to deep neural networks. Instead of enforcing invariance through a bi-level optimization, NCR penalizes the variance of prediction risks across environments.

The NCR objective function is:

min_{f} \bar{R}(f) + \lambda \cdot V(f)

where:

- \(\bar{R}(f) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} R^e(f)\) is the average risk across the training environments,
- \(V(f) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} \left( R^e(f) - \bar{R}(f) \right)^2\) is the variance of the risks across environments,
- \(\lambda \geq 0\) is the regularization parameter controlling the strength of the invariance penalty.

By minimizing the variance of risks, NCR encourages the model to focus on features that lead to consistent performance across environments, which typically correspond to causal relationships.
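The following sketch shows one way this objective could be computed for a PyTorch regression model; the helper name and the use of `mse_loss` are assumptions made for illustration, not the package's actual API.

```python
import torch
import torch.nn.functional as F

def ncr_loss(model, env_batches, lam):
    """NCR objective: average risk across environments plus lambda times the
    variance of the per-environment risks."""
    risks = torch.stack([F.mse_loss(model(x_e), y_e) for x_e, y_e in env_batches])
    mean_risk = risks.mean()                    # \bar{R}(f)
    risk_variance = risks.var(unbiased=False)   # V(f)
    return mean_risk + lam * risk_variance
```

With `lam = 0` this reduces to pooled empirical risk minimization; increasing `lam` trades average fit for cross-environment consistency.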

Theoretical Guarantees

Out-of-Distribution Generalization

Under certain conditions, NCR provides guarantees on out-of-distribution generalization. Specifically, if the data is generated according to a structural causal model where the relationship between \(X_S\) and \(Y\) is invariant across environments, and if the model class is expressive enough to capture this relationship, then NCR will recover a predictor that generalizes to new environments.

Theorem 1 (Informal): Let \(f_{\lambda}\) be the solution to the NCR objective with regularization parameter \(\lambda\). As \(\lambda \rightarrow \infty\), \(f_{\lambda}\) converges to a predictor that depends only on the causal features \(X_S\) and achieves the optimal invariant risk.

Causal Feature Identification

NCR can also be used to identify causal features. By examining the feature importance of a model trained with NCR, we can distinguish between causal and non-causal features.

Theorem 2 (Informal): Under suitable conditions, as \(\lambda \rightarrow \infty\), the feature importance of non-causal features in a model trained with NCR approaches zero.

Finite-Sample Bounds

In practice, we only have finite samples from each environment. NCR provides finite-sample bounds on the out-of-distribution generalization error.

Theorem 3 (Informal): With probability at least \(1-\delta\), the out-of-distribution generalization error of a model trained with NCR is bounded by \(O(\sqrt{\frac{\log(1/\delta)}{n}})\), where \(n\) is the minimum number of samples across environments.

Connections to Other Methods

Invariant Risk Minimization (IRM)

NCR is closely related to IRM but uses a different approach to enforce invariance. While IRM uses a bi-level optimization or gradient penalties, NCR directly penalizes the variance of risks across environments.

Risk Extrapolation (REx)

REx is another approach that penalizes the difference between training risks across environments. NCR can be seen as a specific instance of REx where the penalty is the variance of risks.

Anchor Regression

Anchor regression is a method for causal inference that uses a similar objective function to NCR but is formulated in terms of linear models. NCR extends these ideas to nonlinear models using neural networks.
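For comparison, the population objective of anchor regression for a linear predictor \(b\) and anchor variables \(A\) (as introduced by Rothenhäusler et al.) can be written as:

min_{b} \mathbb{E}\left[ \left( (\mathrm{Id} - P_A)(Y - X^\top b) \right)^2 \right] + \gamma \cdot \mathbb{E}\left[ \left( P_A (Y - X^\top b) \right)^2 \right]

where \(P_A\) denotes the \(L^2\) projection onto the linear span of \(A\), and \(\gamma\) plays a role analogous to \(\lambda\) in NCR; \(\gamma = 1\) recovers ordinary least squares.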

Practical Considerations

Choosing the Regularization Parameter

The regularization parameter \(\lambda\) controls the trade-off between average performance and invariance. A larger \(\lambda\) puts more emphasis on invariance, which can lead to better out-of-distribution generalization but potentially worse in-distribution performance.

In practice, \(\lambda\) can be chosen using cross-validation on a validation set from a different environment than the training environments.
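A minimal sketch of that selection procedure, assuming a training routine `fit_fn` (for example, one that minimizes the NCR objective sketched above) and a held-out validation environment; all names are illustrative.

```python
import torch
import torch.nn.functional as F

def select_lambda(train_envs, val_env, lambdas, fit_fn):
    """Pick the regularization strength with the lowest risk on a held-out
    validation environment not used for training."""
    x_val, y_val = val_env
    best_lam, best_risk = None, float("inf")
    for lam in lambdas:
        model = fit_fn(train_envs, lam)   # train with the NCR objective at this lambda
        with torch.no_grad():
            risk = F.mse_loss(model(x_val), y_val).item()
        if risk < best_risk:
            best_lam, best_risk = lam, risk
    return best_lam
```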

Model Architecture

NCR can be applied to any neural network architecture. However, the choice of architecture can affect the model's ability to capture invariant relationships. In general, deeper networks with sufficient capacity are recommended.

Number of Environments

NCR requires data from multiple environments to identify invariant relationships. In general, having more diverse environments can lead to better identification of causal features. However, even with just two environments, NCR can provide benefits over standard empirical risk minimization.

Limitations

While NCR provides a powerful framework for learning invariant predictors, it has some limitations:

- It requires data from multiple, sufficiently diverse environments; with a single environment the variance penalty is uninformative.
- Its guarantees rest on the assumption that the relationship between \(X_S\) and \(Y\) is truly invariant across environments and that the model class is expressive enough to capture it.
- The regularization parameter \(\lambda\) must be tuned, typically using a validation environment that differs from the training environments.

Further Reading

For more details on the theoretical foundations of Neural Causal Regularization, we recommend the following resources: