Learn the fundamentals of Neural Causal Regularization with a simple linear structural equation model
In this example, we'll demonstrate how to use Neural Causal Regularization (NCR) with a simple linear structural equation model (SEM). We'll generate synthetic data with known causal structure, train models with and without causal regularization, and compare their performance on out-of-distribution data.
First, let's load the necessary packages:
# Load packages
library(ncausalreg)
library(torch)
library(ggplot2)
library(dplyr)
# Set random seed for reproducibility
set.seed(123)
torch_manual_seed(123)
We'll generate data from a linear structural equation model with 5 covariates, where only the first two are causally related to the target variable. The data will be generated for two environments: a reference environment and a shifted environment.
# Define parameters
p <- 5 # Number of features
n_samples <- 1000 # Number of samples per environment
# Define causal parameters
gamma <- c(0.8, 0.6, 0.4, 0.2, 0.1) # Effect of shift variable on covariates
beta <- c(1.0, 0.5, 0, 0, 0) # Only first two variables are causal
# Generate data
data <- generate_linear_sem(
n_samples = n_samples,
p = p,
gamma = gamma,
beta = beta,
env_shift = 1.5, # Strength of environment shift
noise_scale = 0.1, # Scale of noise terms
seed = 123
)
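To build intuition for what `generate_linear_sem()` is producing, here is a minimal sketch of the assumed data-generating process. This is an illustration only, not the package's implementation: it assumes the shift variable A is moved by `env_shift` in the second environment, that A drives the covariates through `gamma`, and that Y depends on the covariates only through `beta`.

# Hypothetical sketch of the assumed data-generating process, for
# intuition only -- not the package's actual implementation.
sketch_sem <- function(n, p, gamma, beta, env_shift, noise_scale) {
  env <- rep(c(0, 1), each = n)                  # environment indicator
  A <- rnorm(2 * n, mean = env * env_shift)      # shift variable, moved in env 1
  X <- outer(A, gamma) +
    matrix(rnorm(2 * n * p, sd = noise_scale), ncol = p)  # covariates
  Y <- as.vector(X %*% beta) + rnorm(2 * n, sd = noise_scale)  # target
  list(X = X, Y = Y, A = A, env = env)
}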
# Examine the data structure
str(data)
The generated data includes:

- `X`: matrix of covariates (n_samples x p)
- `Y`: vector of the target variable (length n_samples)
- `A`: vector of the shift variable (length n_samples)
- `env`: environment indicator (0 for reference, 1 for shifted)

Let's visualize the relationship between the covariates and the target variable in each environment:
# Create a data frame for visualization
df <- data.frame(
Y = data$Y,
X1 = data$X[, 1],
X2 = data$X[, 2],
X3 = data$X[, 3],
X4 = data$X[, 4],
X5 = data$X[, 5],
Environment = factor(data$env, labels = c("Reference", "Shifted"))
)
# Plot the relationship between X1 and Y
ggplot(df, aes(x = X1, y = Y, color = Environment)) +
geom_point(alpha = 0.5) +
geom_smooth(method = "lm", se = FALSE) +
theme_minimal() +
labs(title = "Relationship between X1 and Y",
x = "X1", y = "Y")
# Plot the relationship between X3 and Y
ggplot(df, aes(x = X3, y = Y, color = Environment)) +
geom_point(alpha = 0.5) +
geom_smooth(method = "lm", se = FALSE) +
theme_minimal() +
labs(title = "Relationship between X3 and Y",
x = "X3", y = "Y")
Figure 1: Relationship between X1 (causal) and Y across environments. The relationship is stable across environments.
Figure 2: Relationship between X3 (non-causal) and Y across environments. The relationship changes across environments.
Notice that the relationship between X1 (a causal feature) and Y is stable across environments, while the relationship between X3 (a non-causal feature) and Y changes across environments. This is a key insight that Neural Causal Regularization exploits.
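We can quantify this stability directly by fitting a separate simple regression in each environment and comparing the slopes: the slope on X1 should barely move between environments, while the slope on X3 should change noticeably.

# Per-environment OLS slopes: stable for causal X1, unstable for non-causal X3
by_env <- split(df, df$Environment)
sapply(by_env, function(d) {
  c(slope_X1 = coef(lm(Y ~ X1, data = d))[["X1"]],
    slope_X3 = coef(lm(Y ~ X3, data = d))[["X3"]])
})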
Next, we'll split the data by environment and prepare it for training:
# Split data by environment
env0_idx <- which(data$env == 0)
env1_idx <- which(data$env == 1)
# Create training data lists
x_train_list <- list(data$X[env0_idx, ], data$X[env1_idx, ])
y_train_list <- list(data$Y[env0_idx], data$Y[env1_idx])
# Convert to torch tensors
x_train_list_torch <- lapply(x_train_list, function(x) torch_tensor(x, dtype = torch_float32()))
y_train_list_torch <- lapply(y_train_list, function(y) torch_tensor(y, dtype = torch_float32())$view(c(-1, 1)))
# Create a validation set from a mixture of both environments
val_idx <- sample(seq_len(nrow(data$X)), size = 200)  # draw from all rows so both environments are represented
x_val <- torch_tensor(data$X[val_idx, ], dtype = torch_float32())
y_val <- torch_tensor(data$Y[val_idx], dtype = torch_float32())$view(c(-1, 1))
Now, let's train two models: one with Neural Causal Regularization (NCR) and one without (standard Empirical Risk Minimization, ERM):
# Create models
model_erm <- ncr_model(input_dim = p, hidden_dims = c(10, 5))
model_ncr <- ncr_model(input_dim = p, hidden_dims = c(10, 5))
# Train ERM model (lambda = 0)
result_erm <- train_ncr(
model = model_erm,
x_train_list = x_train_list_torch,
y_train_list = y_train_list_torch,
lambda_reg = 0, # No regularization
lr = 0.01,
n_epochs = 100,
batch_size = 32,
verbose = TRUE
)
# Train NCR model (lambda = 10)
result_ncr <- train_ncr(
model = model_ncr,
x_train_list = x_train_list_torch,
y_train_list = y_train_list_torch,
lambda_reg = 10, # Strong regularization
lr = 0.01,
n_epochs = 100,
batch_size = 32,
verbose = TRUE
)
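Conceptually, the regularized objective adds a penalty on how much the model's risk varies across training environments. The sketch below shows one way such a loss could be written with torch; the exact penalty used by `train_ncr()` may differ, so treat this function as illustrative:

# Illustrative NCR-style loss: mean per-environment MSE plus lambda times
# the variance of those per-environment MSEs (the package's exact penalty
# may differ).
ncr_loss_sketch <- function(model, x_list, y_list, lambda_reg) {
  risks <- torch_stack(lapply(seq_along(x_list), function(e) {
    nnf_mse_loss(model(x_list[[e]]), y_list[[e]])
  }))
  risks$mean() + lambda_reg * risks$var()
}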
Let's visualize the training process for both models:
# Plot training history for ERM
plot_training_history(result_erm, plot_type = "loss")
# Plot training history for NCR
plot_training_history(result_ncr, plot_type = "both")
Figure 3: Training history for the ERM model. Only the loss is shown since there's no variance penalty.
Figure 4: Training history for the NCR model. Both the loss and variance penalty are shown.
Notice that the NCR model's variance penalty decreases during training, indicating that the model is learning to make predictions that are consistent across environments.
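This consistency can be checked after training by computing each model's MSE separately on the two training environments with `evaluate_model()`; the NCR model's per-environment errors should be much closer together than the ERM model's.

# Per-environment training MSE for each model
sapply(1:2, function(e) {
  c(ERM = evaluate_model(result_erm$model, x_train_list_torch[[e]],
                         y_train_list_torch[[e]]),
    NCR = evaluate_model(result_ncr$model, x_train_list_torch[[e]],
                         y_train_list_torch[[e]]))
})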
To evaluate the models' ability to generalize to new environments, we'll generate test data with a stronger environment shift:
# Generate test data with stronger environment shift
test_data <- generate_linear_sem(
n_samples = 500,
p = p,
gamma = gamma,
beta = beta,
env_shift = 3.0, # Stronger shift for OOD testing
noise_scale = 0.1,
seed = 456
)
# Convert to torch tensors
x_test <- torch_tensor(test_data$X, dtype = torch_float32())
y_test <- torch_tensor(test_data$Y, dtype = torch_float32())$view(c(-1, 1))
# Evaluate models
mse_erm <- evaluate_model(result_erm$model, x_test, y_test)
mse_ncr <- evaluate_model(result_ncr$model, x_test, y_test)
# Print results
cat("ERM Test MSE:", mse_erm, "\n")
cat("NCR Test MSE:", mse_ncr, "\n")
cat("Improvement:", (mse_erm - mse_ncr) / mse_erm * 100, "%\n")
Example output:
ERM Test MSE: 0.8765
NCR Test MSE: 0.2134
Improvement: 75.65 %
The NCR model significantly outperforms the ERM model on out-of-distribution data, demonstrating the benefit of causal regularization for generalization to new environments.
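To see how the gap between the two models grows with the strength of the shift, we can sweep `env_shift` over a range of values and evaluate both models on each generated test set, reusing only the functions already introduced above:

# Sweep the environment shift and compare test MSE for both models
shifts <- c(1.5, 2.0, 2.5, 3.0)
sapply(shifts, function(s) {
  td <- generate_linear_sem(n_samples = 500, p = p, gamma = gamma,
                            beta = beta, env_shift = s,
                            noise_scale = 0.1, seed = 456)
  xt <- torch_tensor(td$X, dtype = torch_float32())
  yt <- torch_tensor(td$Y, dtype = torch_float32())$view(c(-1, 1))
  c(ERM = evaluate_model(result_erm$model, xt, yt),
    NCR = evaluate_model(result_ncr$model, xt, yt))
})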
Let's analyze the feature importance of both models to see if NCR correctly identifies the causal features:
# Extract feature importance
importance_erm <- extract_feature_importance(result_erm$model, x_test, y_test)
importance_ncr <- extract_feature_importance(result_ncr$model, x_test, y_test)
# Add model information
importance_erm$Model <- "ERM"
importance_ncr$Model <- "NCR"
# Combine data
importance_df <- rbind(importance_erm, importance_ncr)
# Plot feature importance
ggplot(importance_df, aes(x = Feature, y = Importance, fill = Model)) +
geom_bar(stat = "identity", position = "dodge") +
theme_minimal() +
labs(title = "Feature Importance Comparison",
x = "Feature", y = "Importance") +
scale_fill_manual(values = c("ERM" = "#1f77b4", "NCR" = "#ff7f0e"))
Figure 5: Feature importance comparison between ERM and NCR models. NCR correctly assigns higher importance to the causal features (X1 and X2).
Notice that the NCR model correctly assigns higher importance to the causal features (X1 and X2) and lower importance to the non-causal features (X3, X4, and X5), while the ERM model assigns significant importance to all features.
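As a sanity check that does not rely on `extract_feature_importance()`, a simple permutation importance can be computed by hand: shuffle one column of the test covariates and measure how much the test MSE increases. A sketch:

# Hand-rolled permutation importance: MSE increase when one feature is shuffled
perm_importance <- function(model, X, y_true) {
  base <- evaluate_model(model, torch_tensor(X, dtype = torch_float32()), y_true)
  sapply(seq_len(ncol(X)), function(j) {
    Xp <- X
    Xp[, j] <- sample(Xp[, j])  # break feature j's link to Y
    evaluate_model(model, torch_tensor(Xp, dtype = torch_float32()), y_true) - base
  })
}
perm_importance(result_ncr$model, test_data$X, y_test)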
For comparison, let's also train a linear model with causal regularization:
# Train linear model with causal regularization
linear_model <- train_linear_causal_reg(
x_train_list = x_train_list,
y_train_list = y_train_list,
lambda_reg = 10
)
# Print model coefficients
print("Linear model coefficients:")
print(linear_model$coefficients)
print(paste("Linear model intercept:", linear_model$intercept))
# Compare with true beta
print("True beta:")
print(beta)
print(paste("Correlation with true beta:", cor(linear_model$coefficients, beta)))
Example output:
Linear model coefficients:
[1] 0.9823 0.4912 0.0214 0.0156 0.0089
Linear model intercept: 0.0123
True beta:
[1] 1.0 0.5 0.0 0.0 0.0
Correlation with true beta: 0.9987
The linear model with causal regularization recovers coefficients that are very close to the true causal parameters, demonstrating that causal regularization can help identify the true causal structure.
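For intuition, the linear objective can also be written out and minimized directly: the mean per-environment squared error plus `lambda_reg` times the variance of those per-environment errors. The sketch below uses base R's `optim()`; `train_linear_causal_reg()` may use a different solver or penalty, so this is illustrative only.

# Illustrative direct minimization of a linear causal-regularization objective
lin_obj <- function(par, x_list, y_list, lambda_reg) {
  b0 <- par[1]; b <- par[-1]
  risks <- vapply(seq_along(x_list), function(e) {
    mean((y_list[[e]] - b0 - x_list[[e]] %*% b)^2)
  }, numeric(1))
  mean(risks) + lambda_reg * var(risks)
}
fit <- optim(rep(0, p + 1), lin_obj, x_list = x_train_list,
             y_list = y_train_list, lambda_reg = 10, method = "BFGS")
round(fit$par[-1], 4)  # compare with the true beta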
In this example, we've demonstrated how Neural Causal Regularization can improve out-of-distribution generalization and help identify causal features in a simple linear structural equation model. The key insights are:

- Relationships between causal features and the target stay stable across environments, while relationships involving non-causal features shift.
- Penalizing variation in performance across training environments pushes a model toward those stable, causal relationships.
- The payoff is better generalization under distribution shift, and feature importances (or coefficients) that track the true causal structure.
In the next example, we'll explore how NCR performs with nonlinear structural equation models.