Quick start
Train NCR on Colored‑MNIST
import torch, torchvision
from torch.utils.data import DataLoader
from ncr import NCRNet, train_ncr, risk_gap, accuracy
# 1. Two synthetic environments.
def colour_fn(env):
    """Return a transform that paints a grayscale MNIST digit red or green.

    The colour depends on the parity of the digit's summed raw 8-bit pixel
    intensities, XOR-ed with ``env``. As a result, environments 0 and 1 use
    opposite colour assignments for the same image.

    Args:
        env: Environment index (0 or 1).

    Returns:
        A function mapping either a PIL image (as yielded by
        ``torchvision.datasets.MNIST`` without a transform) or a uint8 tensor
        of shape (1, H, W) or (H, W) to a float tensor of shape (3, H, W).
        Values are in [0, 1], with only the chosen colour channel non-zero.
    """
    red, green = (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)

    def f(img):
        if isinstance(img, torch.Tensor):
            pixels = img
        else:
            # PIL images have no .sum(). PILToTensor keeps the raw uint8
            # values, so the parity test below operates on an integer sum.
            pixels = torchvision.transforms.PILToTensor()(img)
        if pixels.dim() == 2:
            pixels = pixels.unsqueeze(0)
        parity = int(pixels.sum().item()) & 1
        rgb = red if parity ^ env else green
        # Scale to [0, 1], matching what ToTensor does for uint8 images.
        return torch.tensor(rgb)[:, None, None] * (pixels.float() / 255)

    return f
root = "./data"
base = torchvision.datasets.MNIST(root, train=True, download=True)


def _make_env(dataset, env):
    """Colour every digit in *dataset* for environment *env*.

    Returns a list of ``(image, target)`` pairs. ``image`` is the tensor
    produced by ``colour_fn(env)``. ``target`` is a float32 tensor of shape
    (1,): 1.0 when the label is below 5, 0.0 otherwise.
    """
    colour = colour_fn(env)  # build the transform once, not once per sample
    return [(colour(img), torch.tensor([y < 5], dtype=torch.float32))
            for img, y in dataset]


# NOTE: each environment is fully materialised in memory. With 3x28x28
# float32 images that is roughly 0.56 GB per environment for the
# 60k-image MNIST training split.
env0 = _make_env(base, 0)
env1 = _make_env(base, 1)
loaders = [DataLoader(env, batch_size=64, shuffle=True) for env in (env0, env1)]
# 2. Model & loss: 3x28x28 coloured input, two hidden layers, one logit.
net = NCRNet(3 * 28 * 28, (256, 256), 1)
crit = torch.nn.BCEWithLogitsLoss()

# 3. Train with NCR (λ = 1) on both environments.
train_ncr(net, loaders, crit, epochs=5, lambda_reg=1.0)

# 4. Evaluate on digits the model never saw. loaders[1] is one of the
# training environments, so scoring it would only report training accuracy.
# Use the MNIST test split, coloured as in environment 1, instead.
# NOTE(review): this is held-out data from a training environment, not a
# shifted distribution; add a genuinely shifted environment for a true OOD score.
test_base = torchvision.datasets.MNIST(root, train=False, download=True)
colour_test = colour_fn(1)
test_env = [(colour_test(img), torch.tensor([y < 5], dtype=torch.float32))
            for img, y in test_base]
test_loader = DataLoader(test_env, batch_size=256)
print("Held-out acc (env 1):", accuracy(net, test_loader))
This short script follows the same workflow as the scripts in our experiments
folder. For a full benchmark, see the
Examples section.