Quick start

Train NCR on Colored-MNIST

import torch, torchvision
from torch.utils.data import DataLoader
from ncr import NCRNet, train_ncr, risk_gap, accuracy

# 1. two synthetic environments
def colour_fn(env):
    """Return a transform that colours a grayscale MNIST digit red or green.

    The colour is picked from the parity of the image's integer pixel sum,
    XOR-ed with ``env``. Environments 0 and 1 therefore give the same image
    opposite colours.

    Args:
        env: environment index (0 or 1).

    Returns:
        A callable that maps a grayscale image (a PIL ``"L"`` image or a
        ``uint8`` array of shape ``(H, W)``) to a ``(3, H, W)`` float32 tensor
        with values in ``[0, 1]``. Only the red or the green channel is
        non-zero.
    """
    import numpy as np

    red, green = (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)

    def f(img):
        # PIL images have no ``.sum()``. Go through a uint8 array so the parity
        # is taken over the raw integer pixel values (a float sum has no parity).
        pixels = np.array(img, dtype=np.uint8)
        parity = int(pixels.sum(dtype=np.int64)) & 1
        rgb = red if parity ^ env else green
        # Same scaling as ToTensor for a uint8 grayscale image: (1, H, W) in [0, 1].
        grey = torch.from_numpy(pixels).float().div_(255.0).unsqueeze(0)
        return torch.tensor(rgb)[:, None, None] * grey

    return f

root = "./data"
base = torchvision.datasets.MNIST(root, train=True, download=True)

# Build both environments in a single pass over MNIST so each image is decoded
# once. Both environments share the same digits and the binary label
# "digit < 5"; only the colouring differs.
colour_env0, colour_env1 = colour_fn(0), colour_fn(1)
env0, env1 = [], []
for img, y in base:
    target = torch.tensor([y < 5], dtype=torch.float32)
    env0.append((colour_env0(img), target))
    env1.append((colour_env1(img), target))
loaders = [DataLoader(env0, batch_size=64, shuffle=True),
           DataLoader(env1, batch_size=64, shuffle=True)]

# Held-out digits (MNIST test split), coloured like environment 1. Both
# training loaders are passed to train_ncr, so evaluating on loaders[1] would
# only report training accuracy.
test_base = torchvision.datasets.MNIST(root, train=False, download=True)
test_env = [(colour_env1(img), torch.tensor([y < 5], dtype=torch.float32))
            for img, y in test_base]
test_loader = DataLoader(test_env, batch_size=256)

# 2. model & loss
# NOTE(review): arguments presumably are (input_dim, hidden_sizes, output_dim);
# confirm against ncr.NCRNet.
net = NCRNet(3*28*28, (256, 256), 1)
crit = torch.nn.BCEWithLogitsLoss()

# 3. train with NCR (λ=1)
train_ncr(net, loaders, crit, epochs=5, lambda_reg=1.0)

# 4. evaluate on digits never seen during training
print("Held-out acc (env 1):", accuracy(net, test_loader))

This tiny script matches the workflow in our experiments folder. For a full benchmark, see the Examples.