Skip to content

Commit 7cfa758

Browse files
authored
Fix per-atlas overwrite in single_session_metrics + atlas dtype (#344)
`compute_timeseries` derives its output filename from the BOLD stem, so every atlas iteration in `single_session_metrics` was writing to the same file. Every `MetricsOutputs.timeseries[label]` then pointed at the last-processed atlas's data, silently corrupting the regular RBC pipeline's atlas outputs (timeseries + Pearson correlations). Give each atlas its own `out_dir`. Also switch `compute_timeseries` from `get_fdata().astype(int)` to `np.asarray(atlas_img.dataobj).astype(int)` so integer atlas labels survive verbatim. `get_fdata` would apply `scl_slope`/`scl_inter` and scale small labels into garbage floats if an atlas mistakenly ships with non-trivial scaling. Regression test in `tests/unit/workflows/test_metrics.py` builds two atlases with different ROI counts (3 and 5) and asserts each is preserved in `MetricsOutputs`, with distinct file paths.
1 parent bdf0ab9 commit 7cfa758

4 files changed

Lines changed: 121 additions & 3 deletions

File tree

src/rbc/core/metrics/timeseries.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,10 @@ def compute_timeseries(
148148
if atlas_img.shape[:3] != img.shape[:3]:
149149
atlas_img = resample_from_to(atlas_img, (img.shape[:3], img.affine), order=0)
150150

151-
atlas_data = atlas_img.get_fdata().astype(int)
151+
# Read via ``dataobj`` so the on-disk integer labels survive verbatim;
152+
# ``get_fdata`` would apply ``scl_slope``/``scl_inter`` and scale small
153+
# labels into garbage floats if the atlas ships with non-trivial scaling.
154+
atlas_data = np.asarray(atlas_img.dataobj).astype(int)
152155

153156
data = img.get_fdata()
154157
labels = np.unique(atlas_data)

src/rbc/workflows/metrics.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,15 @@ def single_session_metrics(
108108
reho_zscored_path = compute_zscore(reho_smooth_path, template_brain_mask)
109109

110110
# 5. Atlas timeseries + correlation matrix from nuisance-regressed,
111-
# bandpass-filtered BOLD
111+
# bandpass-filtered BOLD. Each atlas needs its own ``out_dir`` so the
112+
# BOLD-stem-derived output filename doesn't collide across atlases.
112113
ts_outputs = {}
113114
for label, atlas_path in atlas_files.items():
114115
_logger.info("Extracting atlas timeseries (%s)", label)
116+
atlas_dir = work_dir / f"atlas-{label}"
117+
atlas_dir.mkdir(parents=True, exist_ok=True)
115118
ts_outputs[label] = compute_timeseries(
116-
cleaned_bold, atlas_path, out_dir=work_dir
119+
cleaned_bold, atlas_path, out_dir=atlas_dir
117120
)
118121

119122
return MetricsOutputs(

tests/unit/workflows/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Workflow module tests."""
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Unit tests for rbc.workflows.metrics."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import nibabel as nib
8+
import numpy as np
9+
10+
from rbc.workflows import metrics as metrics_mod
11+
from rbc.workflows.metrics import single_session_metrics
12+
13+
if TYPE_CHECKING:
14+
from pathlib import Path
15+
16+
import pytest
17+
18+
19+
def _save_nifti(path: Path, data: np.ndarray) -> None:
20+
nib.nifti1.Nifti1Image(data, affine=np.eye(4)).to_filename(str(path))
21+
22+
23+
def test_atlas_outputs_are_per_atlas(
24+
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
25+
) -> None:
26+
"""Each atlas's timeseries must land in its own file with its own ROI count.
27+
28+
Regression for a prior single_session_metrics bug where every atlas's
29+
``compute_timeseries`` call shared the same ``out_dir`` and overwrote
30+
the previous atlas's file; ``MetricsOutputs.timeseries[label]`` ended
31+
up pointing at the last-iterated atlas's data for every key.
32+
"""
33+
rng = np.random.default_rng(0)
34+
bold = rng.standard_normal((6, 6, 6, 8))
35+
mask = np.ones((6, 6, 6), dtype=np.int16)
36+
37+
atlas_3 = np.zeros((6, 6, 6), dtype=np.int16)
38+
atlas_3[0:2] = 1
39+
atlas_3[2:4] = 2
40+
atlas_3[4:6] = 3
41+
42+
atlas_5 = np.zeros((6, 6, 6), dtype=np.int16)
43+
for i in range(5):
44+
atlas_5[:, i, :] = i + 1
45+
# last column unlabeled (kept as 0) so label set is exactly {1..5}
46+
atlas_5[:, 5, :] = 0
47+
48+
bold_path = tmp_path / "bold.nii.gz"
49+
mask_path = tmp_path / "mask.nii.gz"
50+
atlas_3_path = tmp_path / "atlas3.nii.gz"
51+
atlas_5_path = tmp_path / "atlas5.nii.gz"
52+
_save_nifti(bold_path, bold)
53+
_save_nifti(mask_path, mask.astype(np.float64))
54+
_save_nifti(atlas_3_path, atlas_3.astype(np.float64))
55+
_save_nifti(atlas_5_path, atlas_5.astype(np.float64))
56+
57+
# Skip the scalar maps -- this test only cares about the atlas loop.
58+
from pathlib import Path as _Path
59+
60+
counter = {"n": 0}
61+
62+
def _next_scratch(name: str) -> _Path:
63+
counter["n"] += 1
64+
p = tmp_path / f"{name}_{counter['n']}.nii.gz"
65+
_save_nifti(p, np.zeros((6, 6, 6)))
66+
return p
67+
68+
def _scalar_pair(*_args: object, **kwargs: object) -> tuple[_Path, _Path]:
69+
out_file = kwargs.get("out_file")
70+
alff = (
71+
_Path(out_file) # type: ignore[arg-type]
72+
if out_file is not None
73+
else _next_scratch("alff")
74+
)
75+
if not alff.exists():
76+
_save_nifti(alff, np.zeros((6, 6, 6)))
77+
return alff, _next_scratch("falff")
78+
79+
def _scalar_single(*_args: object, **_kwargs: object) -> _Path:
80+
return _next_scratch("scalar")
81+
82+
def _smooth(in_path: _Path, _mask: _Path, **_kwargs: object) -> _Path:
83+
return in_path
84+
85+
monkeypatch.setattr(metrics_mod, "compute_alff", _scalar_pair)
86+
monkeypatch.setattr(metrics_mod, "compute_reho", _scalar_single)
87+
monkeypatch.setattr(metrics_mod, "smooth", _smooth)
88+
monkeypatch.setattr(metrics_mod, "compute_zscore", _scalar_single)
89+
90+
outputs = single_session_metrics(
91+
regressed_bold=bold_path,
92+
cleaned_bold=bold_path,
93+
template_brain_mask=mask_path,
94+
tr=2.0,
95+
atlas_files={"atl3": atlas_3_path, "atl5": atlas_5_path},
96+
fwhm=6.0,
97+
)
98+
99+
# Distinct files per atlas, never overwriting each other.
100+
assert outputs.timeseries["atl3"] != outputs.timeseries["atl5"]
101+
assert outputs.correlation_matrix["atl3"] != outputs.correlation_matrix["atl5"]
102+
103+
ts3 = np.loadtxt(outputs.timeseries["atl3"], delimiter="\t")
104+
ts5 = np.loadtxt(outputs.timeseries["atl5"], delimiter="\t")
105+
assert ts3.shape == (3, 8)
106+
assert ts5.shape == (5, 8)
107+
108+
corr3 = np.loadtxt(outputs.correlation_matrix["atl3"], delimiter="\t")
109+
corr5 = np.loadtxt(outputs.correlation_matrix["atl5"], delimiter="\t")
110+
assert corr3.shape == (3, 3)
111+
assert corr5.shape == (5, 5)

0 commit comments

Comments
 (0)