Skip to content

Commit 94fe689

Browse files
committed
Add tests
1 parent a074e1f commit 94fe689

8 files changed

Lines changed: 354 additions & 1 deletion

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: Pipeline tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
pipeline-tests:
11+
runs-on: ubuntu-latest
12+
timeout-minutes: 20
13+
14+
steps:
15+
- uses: actions/checkout@v4
16+
17+
- uses: actions/setup-python@v5
18+
with:
19+
python-version: "3.11"
20+
cache: pip
21+
22+
- name: Install CPU-only torch then package
23+
run: |
24+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
25+
pip install -e ".[test]"
26+
27+
- name: Run pipeline tests
28+
run: |
29+
pytest tests/test_pretrain.py \
30+
tests/test_finetune_seg.py \
31+
tests/test_train_reg.py \
32+
tests/test_test_cls.py \
33+
tests/test_linear_probe.py \
34+
-v --timeout=300

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ extras = [
5858
]
5959
test = [
6060
"ruff>=0.14.8",
61-
"pytest>=9.0.1"
61+
"pytest>=9.0.1",
62+
"pytest-timeout>=2.3.0",
6263
]
6364
docs = [
6465
"mkdocs-shadcn",

tests/conftest.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import pickle
2+
import pytest
3+
import torch
4+
from lightning import Trainer
5+
6+
7+
@pytest.fixture
8+
def pretrain_files(tmp_path):
9+
"""Three .pt files of shape [1, 32, 32, 32] for pretraining (raw image, no label).
10+
32^3 ensures the UNet bottleneck (4 max-pool stages) stays at 2x2x2, avoiding
11+
single-element BatchNorm errors with batch_size=1.
12+
"""
13+
files = []
14+
for i in range(3):
15+
path = tmp_path / f"pre_{i:03d}.pt"
16+
torch.save(torch.randn(1, 32, 32, 32), path)
17+
files.append(str(path))
18+
return {"train": files[:2], "val": [files[2]]}
19+
20+
21+
@pytest.fixture
22+
def seg_files(tmp_path):
23+
"""Three .pt + .pkl file pairs for segmentation. Shape [2, 32, 32, 32] = [image, label].
24+
32^3 ensures the UNet bottleneck (4 max-pool stages) stays at 2x2x2.
25+
"""
26+
files = []
27+
for i in range(3):
28+
pt = tmp_path / f"seg_{i:03d}.pt"
29+
pkl = tmp_path / f"seg_{i:03d}.pkl"
30+
data = torch.zeros(2, 32, 32, 32)
31+
data[0] = torch.randn(32, 32, 32)
32+
data[1] = torch.randint(0, 2, (32, 32, 32)).float()
33+
torch.save(data, pt)
34+
with open(pkl, "wb") as f:
35+
pickle.dump({"foreground_locations": []}, f)
36+
files.append(str(pt))
37+
return {"train": files[:2], "val": [files[2]]}
38+
39+
40+
@pytest.fixture
41+
def clsreg_files(tmp_path):
42+
"""Three .pt files containing (image[1,32,32,32], label_scalar) tuples.
43+
32^3 prevents single-element BatchNorm errors in the 4-stage UNet encoder.
44+
Labels are 0-dim int tensors; ClassificationModule.on_before_batch_transfer
45+
squeezes and converts to long before the training step.
46+
"""
47+
files = []
48+
for i in range(3):
49+
path = tmp_path / f"cls_{i:03d}.pt"
50+
torch.save((torch.randn(1, 32, 32, 32), torch.tensor(i % 2)), path)
51+
files.append(str(path))
52+
return {"train": files[:2], "val": [files[2]], "test": [files[2]]}
53+
54+
55+
@pytest.fixture
56+
def reg_files(tmp_path):
57+
"""Three .pt files containing (image[1,32,32,32], label[1]) tuples.
58+
Labels are 1D float tensors so they collate to [B, 1], matching the
59+
unet_clsreg_tiny output shape [B, 1] expected by MeanSquaredError.
60+
"""
61+
files = []
62+
for i in range(3):
63+
path = tmp_path / f"reg_{i:03d}.pt"
64+
torch.save((torch.randn(1, 32, 32, 32), torch.tensor([float(i % 2)])), path)
65+
files.append(str(path))
66+
return {"train": files[:2], "val": [files[2]], "test": [files[2]]}
67+
68+
69+
@pytest.fixture
70+
def cls_probe_files(tmp_path):
71+
"""Six .pt files for classification / linear-probe tests. 0-dim integer labels.
72+
2 train + 2 val gives full batches when batch_size=2, avoiding the squeeze()-to-scalar
73+
edge case in ClassificationModule.on_before_batch_transfer with batch_size=1.
74+
2 test files (labels 0, 1) ensure both classes are present for AUROC computation.
75+
"""
76+
labels = [0, 1, 0, 1, 0, 1]
77+
files = []
78+
for i, lbl in enumerate(labels):
79+
path = tmp_path / f"clsp_{i:03d}.pt"
80+
torch.save((torch.randn(1, 32, 32, 32), torch.tensor(lbl)), path)
81+
files.append(str(path))
82+
return {"train": files[:2], "val": files[2:4], "test": files[4:6]}
83+
84+
85+
@pytest.fixture
86+
def make_trainer(tmp_path):
87+
"""Factory fixture that builds a minimal CPU Trainer for smoke tests."""
88+
89+
def _make(**kwargs):
90+
defaults = dict(
91+
accelerator="cpu",
92+
max_epochs=1,
93+
limit_train_batches=2,
94+
limit_val_batches=2,
95+
logger=False,
96+
enable_checkpointing=False,
97+
enable_progress_bar=False,
98+
num_sanity_val_steps=0,
99+
)
100+
defaults.update(kwargs)
101+
return Trainer(default_root_dir=str(tmp_path), **defaults)
102+
103+
return _make

tests/test_finetune_seg.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Integration test for pipeline/run/finetune_seg.py components.
2+
3+
Uses SegmentationModule + SegDataModule + unet_tiny on synthetic 32x32x32 seg volumes.
4+
Only tests trainer.fit() — trainer.test() is excluded because SegTestDataset._get_src_label()
5+
loads from ASPARAGUS_RAW_LABELS which is unavailable in CI.
6+
"""
7+
from asparagus.modules.data_modules.training import SegDataModule
8+
from asparagus.modules.lightning_modules import SegmentationModule
9+
from asparagus.modules.networks.unet import unet_tiny
10+
11+
12+
def test_finetune_seg_fit(seg_files, make_trainer):
13+
"""SegmentationModule fits from scratch (weights=None) on synthetic seg data."""
14+
model = unet_tiny(input_channels=1, output_channels=2, dimensions="3D")
15+
16+
data_module = SegDataModule(
17+
batch_size=1,
18+
num_workers=1,
19+
train_split=seg_files["train"],
20+
val_split=seg_files["val"],
21+
train_transforms=None,
22+
val_transforms=None,
23+
)
24+
25+
module = SegmentationModule(
26+
model=model,
27+
learning_rate=1e-3,
28+
warmup_epochs=0,
29+
weights=None,
30+
inference_patch_size=[32, 32, 32],
31+
)
32+
33+
make_trainer().fit(module, datamodule=data_module)

tests/test_linear_probe.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Integration test for pipeline/run/linear_probe.py components.
2+
3+
Mirrors linear_probe.py's three-phase flow: validate → fit → test.
4+
Uses LinearProbeModule + ClsRegDataModule + ResidualEncoderUNetCLSREG (tiny).
5+
6+
LinearProbeModule calls model._encode() internally, which is defined on
7+
ResidualEncoderUNetCLSREG but not on UNetCLSREG.
8+
9+
batch_size=2 is required: squeeze(-1) in on_before_batch_transfer collapses
10+
a [B] label tensor to 0-dim when B=1, causing CrossEntropyLoss to fail.
11+
limit_test_batches=2 ensures both test files (labels 0 and 1) are processed
12+
so MulticlassAUROC has both classes present.
13+
"""
14+
from asparagus.modules.data_modules.training import ClsRegDataModule
15+
from asparagus.modules.lightning_modules import LinearProbeModule
16+
from asparagus.modules.networks.resenc_unet import ResidualEncoderUNetCLSREG
17+
18+
19+
def test_linear_probe_validate_fit_test(cls_probe_files, tmp_path, make_trainer):
20+
"""LinearProbeModule runs all three phases: validate → fit → test."""
21+
model = ResidualEncoderUNetCLSREG(
22+
input_channels=1,
23+
output_channels=2,
24+
dimensions="3D",
25+
features_per_stage=(4, 8),
26+
stride=2,
27+
kernel_size=3,
28+
n_blocks_per_stage=(1, 1),
29+
)
30+
31+
data_module = ClsRegDataModule(
32+
batch_size=2,
33+
num_workers=2, # val_dataloader uses num_workers//2; needs >=2
34+
train_split=cls_probe_files["train"],
35+
val_split=cls_probe_files["val"],
36+
test_samples=cls_probe_files["test"],
37+
use_random_datasampler=False,
38+
)
39+
40+
module = LinearProbeModule(
41+
model=model,
42+
learning_rates=[0.1, 0.01],
43+
num_classes=2,
44+
dimensions="3D",
45+
test_output_path=str(tmp_path / "probe_preds.json"),
46+
weights=None,
47+
)
48+
49+
trainer = make_trainer(limit_test_batches=2)
50+
data_module.setup("fit")
51+
trainer.validate(module, datamodule=data_module)
52+
trainer.fit(module, datamodule=data_module)
53+
trainer.test(module, datamodule=data_module)

tests/test_pretrain.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Integration test for pipeline/run/pretrain.py components.
2+
3+
Uses SelfSupervisedModule + PretrainDataModule + unet_tiny on synthetic 32x32x32 volumes.
4+
Torch_CopyImageToLabel adds batch["label"] so the SSL reconstruction loss can run.
5+
"""
6+
from gardening_tools.modules.transforms.copy_image_to_label import Torch_CopyImageToLabel
7+
from torchvision import transforms
8+
9+
from asparagus.modules.data_modules.pretraining import PretrainDataModule
10+
from asparagus.modules.lightning_modules import SelfSupervisedModule
11+
from asparagus.modules.networks.unet import unet_tiny
12+
13+
14+
def test_pretrain_fit(pretrain_files, make_trainer):
15+
"""SelfSupervisedModule fits on synthetic pretrain data with reconstruction loss."""
16+
model = unet_tiny(input_channels=1, output_channels=1, dimensions="3D")
17+
18+
# CopyImageToLabel saves label = image before any GPU augmentation,
19+
# which is all the SSL reconstruction loss requires.
20+
copy_transform = transforms.Compose([Torch_CopyImageToLabel(copy=True)])
21+
22+
data_module = PretrainDataModule(
23+
batch_size=1,
24+
num_workers=1,
25+
train_split=pretrain_files["train"],
26+
val_split=pretrain_files["val"],
27+
train_transforms=copy_transform,
28+
val_transforms=copy_transform,
29+
)
30+
31+
module = SelfSupervisedModule(
32+
model=model,
33+
learning_rate=1e-3,
34+
warmup_epochs=0,
35+
train_transforms=None,
36+
val_transforms=None,
37+
)
38+
39+
make_trainer().fit(module, datamodule=data_module)

tests/test_test_cls.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Integration test for pipeline/run/test_cls.py components.
2+
3+
Mirrors test_cls.py's two-stage flow:
4+
1. Train briefly, save a checkpoint.
5+
2. Load that checkpoint, run test-time inference on new data.
6+
Uses ClassificationModule + ClsRegDataModule + unet_clsreg_tiny.
7+
8+
Note: batch_size=2 is required. ClassificationModule.on_before_batch_transfer
9+
uses squeeze() on labels; with batch_size=1 this collapses [B] to 0-dim,
10+
causing CrossEntropyLoss to fail with "batch_size (1) vs (0)".
11+
"""
12+
from asparagus.modules.data_modules.training import ClsRegDataModule
13+
from asparagus.modules.lightning_modules import ClassificationModule
14+
from asparagus.modules.networks.unet import unet_clsreg_tiny
15+
from asparagus.pipeline.auto_configuration.checkpoint import load_checkpoint_state_dict
16+
17+
18+
def test_test_cls_inference(cls_probe_files, tmp_path, make_trainer):
19+
"""ClassificationModule runs test-time inference from a saved checkpoint."""
20+
ckpt_path = tmp_path / "cls_checkpoint.ckpt"
21+
22+
# --- Stage 1: train and save a checkpoint ---
23+
train_model = unet_clsreg_tiny(input_channels=1, output_channels=2, dimensions="3D")
24+
train_module = ClassificationModule(
25+
model=train_model,
26+
learning_rate=1e-3,
27+
warmup_epochs=0,
28+
test_output_path=str(tmp_path / "train_preds.json"),
29+
)
30+
train_dm = ClsRegDataModule(
31+
batch_size=2,
32+
num_workers=2, # val_dataloader uses num_workers//2; needs >=2
33+
train_split=cls_probe_files["train"],
34+
val_split=cls_probe_files["val"],
35+
use_random_datasampler=False,
36+
)
37+
train_trainer = make_trainer()
38+
train_trainer.fit(train_module, datamodule=train_dm)
39+
train_trainer.save_checkpoint(str(ckpt_path))
40+
41+
# --- Stage 2: load weights and run inference (mirrors test_cls.py logic) ---
42+
weights = load_checkpoint_state_dict(str(ckpt_path))
43+
infer_model = unet_clsreg_tiny(input_channels=1, output_channels=2, dimensions="3D")
44+
infer_module = ClassificationModule(
45+
model=infer_model,
46+
weights=weights,
47+
test_output_path=str(tmp_path / "test_preds.json"),
48+
)
49+
test_dm = ClsRegDataModule(
50+
batch_size=2,
51+
num_workers=2, # val_dataloader uses num_workers//2; needs >=2
52+
train_split=None,
53+
val_split=None,
54+
test_samples=cls_probe_files["test"],
55+
use_random_datasampler=False,
56+
)
57+
make_trainer(limit_test_batches=2).test(infer_module, datamodule=test_dm)

tests/test_train_reg.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Integration test for pipeline/run/train_reg.py components.
2+
3+
Uses RegressionModule + ClsRegDataModule + unet_clsreg_tiny on synthetic (image, label) data.
4+
Runs both trainer.fit() and trainer.test() mirroring the full pipeline.
5+
"""
6+
from asparagus.modules.data_modules.training import ClsRegDataModule
7+
from asparagus.modules.lightning_modules import RegressionModule
8+
from asparagus.modules.networks.unet import unet_clsreg_tiny
9+
10+
11+
def test_train_reg_fit_and_test(reg_files, tmp_path, make_trainer):
12+
"""RegressionModule fits then runs inference on synthetic data."""
13+
model = unet_clsreg_tiny(input_channels=1, output_channels=1, dimensions="3D")
14+
15+
data_module = ClsRegDataModule(
16+
batch_size=1,
17+
num_workers=2, # val_dataloader uses num_workers//2; needs >=2
18+
train_split=reg_files["train"],
19+
val_split=reg_files["val"],
20+
test_samples=reg_files["test"],
21+
use_random_datasampler=False,
22+
)
23+
24+
module = RegressionModule(
25+
model=model,
26+
learning_rate=1e-3,
27+
warmup_epochs=0,
28+
test_output_path=str(tmp_path / "preds.json"),
29+
)
30+
31+
trainer = make_trainer(limit_test_batches=1)
32+
trainer.fit(module, datamodule=data_module)
33+
trainer.test(module, datamodule=data_module)

0 commit comments

Comments
 (0)