Basic Example: Linear Structural Equation Model

Learn the fundamentals of Neural Causal Regularization with a simple linear structural equation model

Introduction

In this example, we'll demonstrate how to use Neural Causal Regularization (NCR) with a simple linear structural equation model (SEM). We'll generate synthetic data with known causal structure, train models with and without causal regularization, and compare their performance on out-of-distribution data.

Setup

First, let's load the necessary packages:

# Load packages
library(ncausalreg)
library(torch)
library(ggplot2)
library(dplyr)

# Set random seed for reproducibility
set.seed(123)
torch_manual_seed(123)

Data Generation

We'll generate data from a linear structural equation model with 5 covariates, where only the first two are causally related to the target variable. The data will be generated for two environments: a reference environment and a shifted environment.

# Define parameters
p <- 5  # Number of features
n_samples <- 1000  # Number of samples per environment

# Define causal parameters
gamma <- c(0.8, 0.6, 0.4, 0.2, 0.1)  # Effect of shift variable on covariates
beta <- c(1.0, 0.5, 0, 0, 0)  # Only first two variables are causal

# Generate data
data <- generate_linear_sem(
  n_samples = n_samples,
  p = p,
  gamma = gamma,
  beta = beta,
  env_shift = 1.5,  # Strength of environment shift
  noise_scale = 0.1,  # Scale of noise terms
  seed = 123
)

# Examine the data structure
str(data)
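The internals of generate_linear_sem() are not shown in this example. As a rough mental model only, the base-R sketch below simulates one plausible two-environment linear SEM: a shift variable A enters every covariate through gamma, and Y depends on the covariates only through beta. The function name simulate_linear_sem(), the choice to model the shift as a larger spread of A in the shifted environment, and the unit-variance covariate noise are all assumptions made for illustration, not the package's implementation.

```r
# Hypothetical simulator: NOT the ncausalreg implementation of generate_linear_sem().
# Assumed SEM, for environment e in {0, 1}:
#   A   ~ N(0, sd_e^2), with sd_0 = 1 (reference) and sd_1 = env_shift (shifted)
#   X_j = gamma_j * A + N(0, 1)
#   Y   = X %*% beta + N(0, noise_scale^2)
simulate_linear_sem <- function(n_samples, gamma, beta, env_shift, noise_scale) {
  p <- length(gamma)
  env <- rep(c(0, 1), each = n_samples)      # reference rows first, then shifted rows
  a_sd <- ifelse(env == 0, 1, env_shift)      # the shift widens the spread of A
  A <- rnorm(2 * n_samples, sd = a_sd)
  X <- outer(A, gamma) + matrix(rnorm(2 * n_samples * p), ncol = p)
  Y <- drop(X %*% beta) + rnorm(2 * n_samples, sd = noise_scale)
  list(X = X, Y = Y, env = env, A = A)
}

set.seed(123)
sim <- simulate_linear_sem(
  n_samples = 1000,
  gamma = c(0.8, 0.6, 0.4, 0.2, 0.1),
  beta = c(1.0, 0.5, 0, 0, 0),
  env_shift = 1.5,
  noise_scale = 0.1
)
dim(sim$X)      # 2000 rows (1000 per environment) by 5 covariates
table(sim$env)
```

Because this sketch has no hidden confounding, regressing Y on all five covariates recovers coefficients close to beta in either environment; what differs between environments is the joint distribution of the covariates.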

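The data object returned by generate_linear_sem() includes, among other components, the covariate matrix data$X (one row per observation across both environments, one column per feature), the target vector data$Y, and the environment indicator data$env (0 for the reference environment, 1 for the shifted one). These are the components used in the rest of this example.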

Let's visualize the relationship between the covariates and the target variable in each environment:

# Create a data frame for visualization
df <- data.frame(
  Y = data$Y,
  X1 = data$X[, 1],
  X2 = data$X[, 2],
  X3 = data$X[, 3],
  X4 = data$X[, 4],
  X5 = data$X[, 5],
  Environment = factor(data$env, labels = c("Reference", "Shifted"))
)

# Plot the relationship between X1 and Y
ggplot(df, aes(x = X1, y = Y, color = Environment)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm", se = FALSE) +
  theme_minimal() +
  labs(title = "Relationship between X1 and Y",
       x = "X1", y = "Y")

# Plot the relationship between X3 and Y
ggplot(df, aes(x = X3, y = Y, color = Environment)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm", se = FALSE) +
  theme_minimal() +
  labs(title = "Relationship between X3 and Y",
       x = "X3", y = "Y")

Figure 1: Relationship between X1 (causal) and Y across environments. The relationship is stable across environments.

Figure 2: Relationship between X3 (non-causal) and Y across environments. The relationship changes across environments.

Notice that the relationship between X1 (a causal feature) and Y is stable across environments, while the relationship between X3 (a non-causal feature) and Y changes. Neural Causal Regularization exploits this contrast: a predictor that relies on causal features behaves consistently across environments, whereas one that relies on non-causal features does not.
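Strictly speaking, the invariance that causal methods rely on is that of Y given its causal parents (here X1 and X2). The base-R sketch below simulates a hypothetical linear SEM (an assumption for illustration, not the package's generator: a shift variable A with a larger spread in the shifted environment enters each covariate through gamma, and Y depends only on X1 and X2) and compares, per environment, the least-squares coefficients of Y on (X1, X2) with the marginal slope of Y on X3.

```r
# Hypothetical SEM (assumption, not ncausalreg's generator):
# A has sd 1 in the reference environment and sd 1.5 in the shifted one;
# X_j = gamma_j * A + N(0, 1);  Y = X %*% beta + N(0, 0.1^2)
set.seed(1)
gamma <- c(0.8, 0.6, 0.4, 0.2, 0.1)
beta <- c(1.0, 0.5, 0, 0, 0)
env <- rep(c(0, 1), each = 1000)
A <- rnorm(2000, sd = ifelse(env == 0, 1, 1.5))
X <- outer(A, gamma) + matrix(rnorm(2000 * 5), ncol = 5)
Y <- drop(X %*% beta) + rnorm(2000, sd = 0.1)

# Per environment: coefficients of Y on its causal parents, and marginal slope on X3
per_env <- sapply(c(0, 1), function(e) {
  i <- env == e
  parents <- coef(lm(Y[i] ~ X[i, 1] + X[i, 2]))[-1]   # stays near (1, 0.5)
  slope_x3 <- coef(lm(Y[i] ~ X[i, 3]))[2]             # picks up X1 and X2 through A
  c(b1 = parents[[1]], b2 = parents[[2]], slope_x3 = slope_x3[[1]])
})
colnames(per_env) <- c("Reference", "Shifted")
round(per_env, 2)
```

Under this SEM the population slope of Y on X3 alone is 0.44 Var(A) / (0.16 Var(A) + 1), roughly 0.38 with Var(A) = 1 and 0.73 with Var(A) = 2.25, whereas the coefficients on (X1, X2) stay at (1, 0.5) in both environments.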

Data Preparation

Next, we'll split the data by environment and prepare it for training:

# Split data by environment
env0_idx <- which(data$env == 0)
env1_idx <- which(data$env == 1)

# Create training data lists
x_train_list <- list(data$X[env0_idx, ], data$X[env1_idx, ])
y_train_list <- list(data$Y[env0_idx], data$Y[env1_idx])

# Convert to torch tensors
x_train_list_torch <- lapply(x_train_list, function(x) torch_tensor(x, dtype = torch_float32()))
y_train_list_torch <- lapply(y_train_list, function(y) torch_tensor(y, dtype = torch_float32())$view(c(-1, 1)))

# Create a validation set by sampling rows from both environments
# (data$X holds 2 * n_samples rows; these rows also appear in the training lists)
val_idx <- sample(seq_len(nrow(data$X)), size = 200)
x_val <- torch_tensor(data$X[val_idx, ], dtype = torch_float32())
y_val <- torch_tensor(data$Y[val_idx], dtype = torch_float32())$view(c(-1, 1))
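A quick way to confirm that a validation sample really mixes both environments is to tabulate the environment labels of the drawn rows. The sketch below assumes, hypothetically, that rows are stacked with all reference rows first (the row order returned by generate_linear_sem() is not documented here). Under that assumption, drawing only from the first n_samples indices yields reference rows exclusively, while drawing from every row index mixes both environments.

```r
# Hypothetical environment labels: 1000 reference rows followed by 1000 shifted rows
n_samples <- 1000
env <- rep(c(0, 1), each = n_samples)

set.seed(123)
idx_first_block <- sample(seq_len(n_samples), size = 200)   # first n_samples rows only
idx_all_rows <- sample(seq_along(env), size = 200)          # any row in either environment

table(factor(env[idx_first_block], levels = c(0, 1)))       # all 200 from environment 0
table(factor(env[idx_all_rows], levels = c(0, 1)))          # a mix of 0 and 1
```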

Model Training

Now, let's train two models: one with Neural Causal Regularization (NCR) and one without (standard Empirical Risk Minimization, ERM):

# Create models
model_erm <- ncr_model(input_dim = p, hidden_dims = c(10, 5))
model_ncr <- ncr_model(input_dim = p, hidden_dims = c(10, 5))

# Train ERM model (lambda = 0)
result_erm <- train_ncr(
  model = model_erm,
  x_train_list = x_train_list_torch,
  y_train_list = y_train_list_torch,
  lambda_reg = 0,  # No regularization
  lr = 0.01,
  n_epochs = 100,
  batch_size = 32,
  verbose = TRUE
)

# Train NCR model (lambda = 10)
result_ncr <- train_ncr(
  model = model_ncr,
  x_train_list = x_train_list_torch,
  y_train_list = y_train_list_torch,
  lambda_reg = 10,  # Strong regularization
  lr = 0.01,
  n_epochs = 100,
  batch_size = 32,
  verbose = TRUE
)
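This example does not spell out the exact penalty that lambda_reg weights; the training-history figures refer only to a "variance penalty". A common form of this idea, the penalty used in V-REx (variance risk extrapolation), is the variance of the per-environment risks. The hypothetical ncr_objective() below adds lambda times that variance to the mean per-environment MSE. It is a sketch of the idea under that assumption, not ncausalreg's internal loss. With lambda = 0 it reduces to the average per-environment MSE, which matches pooled ERM when the environments have equal sizes, as they do here.

```r
# Hypothetical sketch of a variance-penalized objective (V-REx-style);
# ncausalreg's actual loss may differ in its details.
ncr_objective <- function(pred_list, y_list, lambda) {
  # Mean squared error within each environment
  risks <- mapply(function(pred, y) mean((y - pred)^2), pred_list, y_list)
  # Population variance of the per-environment risks (divide by m, not m - 1)
  penalty <- mean((risks - mean(risks))^2)
  list(risks = risks, penalty = penalty, objective = mean(risks) + lambda * penalty)
}

# Toy example: predictions fit environment 1 exactly and environment 2 poorly
y_list <- list(c(1, 2, 3), c(1, 2, 3))
pred_list <- list(c(1, 2, 3), c(2, 3, 4))
erm <- ncr_objective(pred_list, y_list, lambda = 0)    # lambda = 0: plain average risk
ncr <- ncr_objective(pred_list, y_list, lambda = 10)   # lambda = 10: unequal risks cost more
erm$objective   # 0.5
ncr$objective   # 3
```

The per-environment risks here are 0 and 1, so the penalty is 0.25 and a weight of 10 adds 2.5 to the objective. That extra cost is what pushes an NCR-style model toward predictors whose error is similar in every training environment.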

Training Visualization

Let's visualize the training process for both models:

# Plot training history for ERM
plot_training_history(result_erm, plot_type = "loss")

# Plot training history for NCR
plot_training_history(result_ncr, plot_type = "both")

Figure 3: Training history for the ERM model. Only the loss is shown since there's no variance penalty.

Figure 4: Training history for the NCR model. Both the loss and variance penalty are shown.

Notice that the NCR model's variance penalty decreases during training, indicating that the model is learning to make predictions that are consistent across environments.

Out-of-Distribution Testing

To evaluate the models' ability to generalize to new environments, we'll generate test data with a stronger environment shift:

# Generate test data with stronger environment shift
test_data <- generate_linear_sem(
  n_samples = 500,
  p = p,
  gamma = gamma,
  beta = beta,
  env_shift = 3.0,  # Stronger shift for OOD testing
  noise_scale = 0.1,
  seed = 456
)

# Convert to torch tensors
x_test <- torch_tensor(test_data$X, dtype = torch_float32())
y_test <- torch_tensor(test_data$Y, dtype = torch_float32())$view(c(-1, 1))

# Evaluate models
mse_erm <- evaluate_model(result_erm$model, x_test, y_test)
mse_ncr <- evaluate_model(result_ncr$model, x_test, y_test)

# Print results
cat("ERM Test MSE:", mse_erm, "\n")
cat("NCR Test MSE:", mse_ncr, "\n")
cat("Improvement:", (mse_erm - mse_ncr) / mse_erm * 100, "%\n")

Example output:

ERM Test MSE: 0.8765
NCR Test MSE: 0.2134
Improvement: 75.65 %

In this example run, the NCR model's test MSE on the out-of-distribution data is roughly a quarter of the ERM model's, illustrating the benefit of causal regularization for generalization to new environments.
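The "Improvement" line is the relative reduction in test MSE, (MSE_ERM - MSE_NCR) / MSE_ERM x 100; with the example numbers, (0.8765 - 0.2134) / 0.8765 x 100 is about 75.65%. Because the test data presumably contains, like the training data, both a reference and a shifted environment, a single pooled MSE can hide where a model fails. The helpers relative_improvement() and mse_by_env() below are hypothetical base-R utilities, not ncausalreg functions; they operate on plain numeric vectors, so tensor outputs would need converting first.

```r
# Hypothetical helpers (not part of ncausalreg)
relative_improvement <- function(mse_baseline, mse_new) {
  (mse_baseline - mse_new) / mse_baseline * 100   # percent reduction in MSE
}

mse_by_env <- function(y, pred, env) {
  # Mean squared error within each environment label
  tapply((y - pred)^2, env, mean)
}

round(relative_improvement(0.8765, 0.2134), 2)   # 75.65

# Toy vectors: a model that is exact in environment 0 and off by 1 in environment 1
y <- c(1, 2, 3, 4)
pred <- c(1, 2, 4, 5)
env <- c(0, 0, 1, 1)
mse_by_env(y, pred, env)
```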

Feature Importance Analysis

Let's analyze the feature importance of both models to see if NCR correctly identifies the causal features:

# Extract feature importance
importance_erm <- extract_feature_importance(result_erm$model, x_test, y_test)
importance_ncr <- extract_feature_importance(result_ncr$model, x_test, y_test)

# Add model information
importance_erm$Model <- "ERM"
importance_ncr$Model <- "NCR"

# Combine data
importance_df <- rbind(importance_erm, importance_ncr)

# Plot feature importance
ggplot(importance_df, aes(x = Feature, y = Importance, fill = Model)) +
  geom_bar(stat = "identity", position = "dodge") +
  theme_minimal() +
  labs(title = "Feature Importance Comparison",
       x = "Feature", y = "Importance") +
  scale_fill_manual(values = c("ERM" = "#1f77b4", "NCR" = "#ff7f0e"))

Figure 5: Feature importance comparison between ERM and NCR models. NCR correctly assigns higher importance to the causal features (X1 and X2).

Notice that the NCR model correctly assigns higher importance to the causal features (X1 and X2) and lower importance to the non-causal features (X3, X4, and X5), while the ERM model assigns substantial importance to all five features, including the non-causal ones.
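How extract_feature_importance() scores features is not described in this example. One common model-agnostic approach is permutation importance: shuffle one feature at a time and record how much the MSE rises. The sketch below applies it to an ordinary least-squares fit on data from a hypothetical linear SEM (an assumption for illustration). The helper permutation_importance() is illustrative only and is not necessarily the package's method.

```r
# Permutation importance: increase in MSE when one column is shuffled.
# Illustrative helper; not necessarily what extract_feature_importance() does.
permutation_importance <- function(predict_fn, X, y, n_repeats = 5) {
  base_mse <- mean((y - predict_fn(X))^2)
  sapply(seq_len(ncol(X)), function(j) {
    increases <- replicate(n_repeats, {
      X_perm <- X
      X_perm[, j] <- sample(X_perm[, j])          # break the link between X_j and y
      mean((y - predict_fn(X_perm))^2) - base_mse
    })
    mean(increases)
  })
}

# Hypothetical data: X_j = gamma_j * A + noise, and Y depends on X1 and X2 only
set.seed(42)
A <- rnorm(1000)
X <- outer(A, c(0.8, 0.6, 0.4, 0.2, 0.1)) + matrix(rnorm(5000), ncol = 5)
colnames(X) <- paste0("X", 1:5)
y <- drop(X %*% c(1.0, 0.5, 0, 0, 0)) + rnorm(1000, sd = 0.1)

fit <- lm(y ~ X)
predict_fn <- function(newX) drop(cbind(1, newX) %*% coef(fit))
importance <- setNames(permutation_importance(predict_fn, X, y), colnames(X))
round(importance, 3)
```

Permutation importance can be misleading when features are strongly correlated, as X1 to X5 are here through A, so it is best read alongside the model's out-of-distribution error.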

Linear Model Comparison

For comparison, let's also train a linear model with causal regularization:

# Train linear model with causal regularization
linear_model <- train_linear_causal_reg(
  x_train_list = x_train_list,
  y_train_list = y_train_list,
  lambda_reg = 10
)

# Print model coefficients (cat() prints labels without the [1] prefix and quotes)
cat("Linear model coefficients:\n")
print(linear_model$coefficients)
cat("Linear model intercept:", linear_model$intercept, "\n")

# Compare with true beta
cat("True beta:\n")
print(beta)
cat("Correlation with true beta:", cor(linear_model$coefficients, beta), "\n")

Example output:

Linear model coefficients:
[1] 0.9823 0.4912 0.0214 0.0156 0.0089
Linear model intercept: 0.0123
True beta:
[1] 1.0 0.5 0.0 0.0 0.0
Correlation with true beta: 0.9987

The linear model with causal regularization recovers coefficients within about 0.02 of the true causal parameters, with near-zero weights on X3, X4, and X5, showing that causal regularization can help identify the true causal structure.
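This example does not state the objective that train_linear_causal_reg() optimizes. As a hypothetical sketch, assuming a V-REx-style objective (mean per-environment MSE plus lambda times the variance of those MSEs), a linear version can be fitted in base R with optim(). In the simulated SEM used here (again an assumption, with no hidden confounding), pooled least squares would also recover beta, so the sketch illustrates the mechanics of the penalty rather than its benefit.

```r
# Hypothetical linear fit with a variance-of-risks penalty (V-REx-style);
# not the implementation of train_linear_causal_reg().
fit_linear_variance_penalty <- function(x_list, y_list, lambda) {
  p <- ncol(x_list[[1]])
  objective <- function(theta) {
    intercept <- theta[1]
    b <- theta[-1]
    # Per-environment mean squared error of the linear predictor
    risks <- mapply(function(x, y) mean((y - intercept - x %*% b)^2), x_list, y_list)
    mean(risks) + lambda * mean((risks - mean(risks))^2)
  }
  opt <- optim(rep(0, p + 1), objective, method = "BFGS", control = list(maxit = 1000))
  list(intercept = opt$par[1], coefficients = opt$par[-1], convergence = opt$convergence)
}

# Hypothetical two-environment data: shift variable A has sd 1 vs 1.5
set.seed(7)
beta <- c(1.0, 0.5, 0, 0, 0)
make_env <- function(n, a_sd) {
  A <- rnorm(n, sd = a_sd)
  X <- outer(A, c(0.8, 0.6, 0.4, 0.2, 0.1)) + matrix(rnorm(n * 5), ncol = 5)
  list(X = X, Y = drop(X %*% beta) + rnorm(n, sd = 0.1))
}
envs <- list(make_env(1000, 1), make_env(1000, 1.5))

fit <- fit_linear_variance_penalty(lapply(envs, `[[`, "X"), lapply(envs, `[[`, "Y"), lambda = 10)
round(fit$coefficients, 3)
max(abs(fit$coefficients - beta))   # largest absolute coefficient error
```

Reporting the largest absolute coefficient error alongside the correlation is useful because correlation ignores scale: coefficients that were all half the true values would still correlate perfectly with beta.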

Conclusion

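In this example, we've demonstrated how Neural Causal Regularization can improve out-of-distribution generalization and help identify causal features in a simple linear structural equation model. The key insights are that the relationship between the causal features and Y stays stable across environments while that of the non-causal features shifts; that penalizing inconsistency across training environments (lambda_reg > 0) gave a markedly lower test MSE than plain ERM under a stronger shift; and that the regularized models, both neural and linear, concentrated their weight on the causal features X1 and X2.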

In the next example, we'll explore how NCR performs with nonlinear structural equation models.