Python API Reference
This page is generated from the live ncr Python package (installed with pip install -e .).
Only functions that are currently implemented are listed.
Data Generation
generate_linear_sem
Generate data from a linear structural equation model with two environments.
import numpy as np
from ncr import generate_linear_sem

data = generate_linear_sem(
    n_samples=1000,
    p=5,
    gamma=np.array([1, 0.5, 0, 0, 0]),
    beta=np.array([2, 1, 0, 0, 0]),
    env_shift=2.0,
    seed=42,
)
X – (n_samples × p) covariate matrix
Y – target vector
env – environment label: 0 (reference) / 1 (shifted)
generate_colored_mnist
Create an IRM‑style coloured‑MNIST torch.utils.data.Dataset.
from ncr import generate_colored_mnist
train_ds = generate_colored_mnist(root="data", env_prob=0.1, train=True)
Model Training
NCRNet
Small MLP used in toy experiments.
from ncr import NCRNet
net = NCRNet(input_dim=28*28, hidden_dim=256, n_hidden=2)
train_ncr
from ncr import train_ncr, risk_variance

# loader0 and loader1 are training DataLoaders defined elsewhere, one per environment
history = train_ncr(
    net,
    loaders=[loader0, loader1],
    lambda_reg=10.0,
    penalty=risk_variance,
    n_epochs=10,
)
history is a pandas.DataFrame (one row per epoch).
Penalty helpers
risk_variance(*risks)
risk_gap(*risks)
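The two helpers are listed here without descriptions. As a minimal sketch, assuming risk_variance returns the variance of the per-environment risks (the penalty used by V-REx) and risk_gap returns the spread between the largest and smallest risk, the computation on plain floats looks like this. The _sketch functions and penalized_objective are hypothetical stand-ins, not part of ncr, and the way penalized_objective combines the terms (mean risk plus lambda_reg times the penalty) is an assumption about what train_ncr optimises.

```python
import statistics
from typing import Callable


def risk_variance_sketch(*risks: float) -> float:
    """Hypothetical stand-in: population variance of the per-environment risks."""
    return statistics.pvariance(risks)


def risk_gap_sketch(*risks: float) -> float:
    """Hypothetical stand-in: largest minus smallest per-environment risk."""
    return max(risks) - min(risks)


def penalized_objective(
    risks: tuple, lambda_reg: float, penalty: Callable[..., float]
) -> float:
    """Assumed objective: mean risk across environments plus lambda_reg * penalty."""
    return sum(risks) / len(risks) + lambda_reg * penalty(*risks)


risks = (1.0, 3.0)  # e.g. risk on environment 0 and on environment 1
print(risk_variance_sketch(*risks))                            # 1.0
print(risk_gap_sketch(*risks))                                 # 2.0
print(penalized_objective(risks, 10.0, risk_variance_sketch))  # 12.0
```

Both penalties are zero only when every environment has the same risk, which is what pushes training toward predictors that behave equally well across environments.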
Evaluation & Metrics
accuracy
Fraction of correct predictions (0–1). Multiply by 100 for a percentage.
from ncr import accuracy
acc = accuracy(net, loader) * 100
print(f"{acc:.2f}%")
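accuracy takes a model and a loader; the quantity it reports is the share of correct predictions. A minimal sketch on plain label lists shows the 0–1 fraction and the ×100 conversion. accuracy_from_predictions is a hypothetical helper for illustration, not part of ncr.

```python
from typing import Sequence


def accuracy_from_predictions(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction (0-1) of positions where the predicted label equals the true label."""
    if len(preds) != len(labels) or not labels:
        raise ValueError("preds and labels must be non-empty and the same length")
    correct = sum(int(p == t) for p, t in zip(preds, labels))
    return correct / len(labels)


acc = accuracy_from_predictions([1, 0, 1, 1], [1, 0, 0, 1]) * 100
print(f"{acc:.2f}%")  # 75.00%
```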
evaluate_model
Return the mean squared error (MSE) on each loader passed.
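The exact signature of evaluate_model is not documented here. As a minimal sketch, assuming each loader yields (inputs, targets) batches and the model maps a batch of inputs to a batch of predictions, the per-loader MSE is the total squared error divided by the number of samples in that loader. mse_per_loader and the toy model are hypothetical.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


def mse_per_loader(
    model: Callable[[Sequence[float]], List[float]],
    loaders: Iterable[Iterable[Batch]],
) -> List[float]:
    """Mean squared error on each loader, averaged over samples (not batches)."""
    results = []
    for loader in loaders:
        sq_err, n = 0.0, 0
        for inputs, targets in loader:
            preds = model(inputs)
            sq_err += sum((p - t) ** 2 for p, t in zip(preds, targets))
            n += len(targets)
        results.append(sq_err / n)
    return results


def model(xs: Sequence[float]) -> List[float]:
    """Toy model: predicts 2 * x."""
    return [2.0 * x for x in xs]


loader_a = [([1.0, 2.0], [2.0, 4.0])]        # fits perfectly
loader_b = [([1.0], [3.0]), ([2.0], [4.0])]  # squared errors 1 and 0
print(mse_per_loader(model, [loader_a, loader_b]))  # [0.0, 0.5]
```

Averaging over samples rather than over batch means keeps the result correct when the last batch is smaller than the others.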
compare_models
Concatenate a list of result series into a comparison table.
Visualisation
plot_training_history
from ncr import plot_training_history
plot_training_history(history)
plot_feature_importance
Permutation‑based feature importance for trained NCR models.
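Beyond "permutation-based", the page does not say how plot_feature_importance computes its scores. A generic sketch of permutation importance, not necessarily what ncr does: shuffle one feature column at a time and record how much the model's MSE rises above the unshuffled baseline. permutation_importance, the MSE metric, and n_repeats are illustrative assumptions.

```python
import random
from typing import Callable, List, Sequence


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def permutation_importance(
    predict: Callable[[List[List[float]]], List[float]],
    X: List[List[float]],
    y: List[float],
    n_repeats: int = 5,
    seed: int = 0,
) -> List[float]:
    """Mean increase in MSE when each feature column is shuffled independently."""
    rng = random.Random(seed)
    baseline = mse(predict(X), y)
    importances = []
    for j in range(len(X[0])):
        increases = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)
            # Copy each row, replacing feature j with its shuffled value.
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            increases.append(mse(predict(X_perm), y) - baseline)
        importances.append(sum(increases) / n_repeats)
    return importances


def predict(rows: List[List[float]]) -> List[float]:
    """Toy model that uses only feature 0."""
    return [2.0 * r[0] for r in rows]


X = [[float(i), float(i % 3)] for i in range(20)]
y = [2.0 * row[0] for row in X]
imp = permutation_importance(predict, X, y)
print(imp)  # feature 0 gets a large score, feature 1 gets 0.0
```

Shuffling a column, rather than dropping it, keeps the input shape and the feature's marginal distribution unchanged, so the trained model can be scored without retraining.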
Utilities
ncr.__all__ lists every public symbol.