Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion src/rbc/core/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

Provides :class:`Volume` for type-safe loading, deriving, and saving of NIfTI
images, plus lightweight metadata queries (:func:`nifti_num_volumes`,
:func:`nifti_num_slices`) that avoid loading full image data.
:func:`nifti_num_slices`, :func:`log_image_summary`) that avoid loading full
image data.
"""

from __future__ import annotations

import logging
import warnings
from enum import Enum, IntEnum
from pathlib import Path
Expand All @@ -24,11 +26,14 @@
"Space",
"Units",
"Volume",
"log_image_summary",
"nifti_num_slices",
"nifti_num_volumes",
"strip_afni_volatile_metadata",
]

_logger = logging.getLogger(__name__)

# AFNI embeds a NIfTI extension (code 4) with an XML payload that contains
# wall-clock timestamps and a random per-invocation UUID. Those poison
# content-hash based caching for any tool downstream. Drop the extension
Expand Down Expand Up @@ -533,6 +538,101 @@ def nifti_num_slices(in_file: str | Path) -> int:
return img.shape[2] if len(img.shape) >= 3 else 1


def _space_label(code: int) -> str:
"""Human-readable name for a NIfTI sform/qform code, or the raw int."""
try:
return Space(code).name
except ValueError:
return str(code)


def _human_bytes(n: int) -> str:
"""Format a byte count with a binary (B/KiB/MiB/GiB) suffix."""
size = float(n)
for unit in ("B", "KiB", "MiB"):
if size < 1024:
return f"{size:.0f} {unit}" if unit == "B" else f"{size:.1f} {unit}"
size /= 1024
return f"{size:.1f} GiB"


def log_image_summary(in_file: str | Path, *, label: str = "Raw input") -> None:
"""Log array shape, dtype, and geometry of a raw NIfTI input.

Reads only the NIfTI header (no voxel data is loaded), then emits an
INFO-level summary so the run log records exactly what entered the
pipeline: array shape, on-disk dtype, data size (``shape`` x dtype
itemsize), voxel size, axis orientation, sform/qform coordinate spaces,
and (for 4D+ images) volume count, slice axis/count, slice acquisition
order, and TR.

This is best-effort diagnostics only: a header that cannot be read
produces a warning, not an exception, so the real failure surfaces
later when processing actually touches the file.

Args:
in_file: Path to a ``.nii``/``.nii.gz`` file.
label: Short prefix identifying the input in the log (e.g.
``"Anatomical T1w"``).
"""
path = Path(in_file)
try:
img = nib.nifti1.load(path)
hdr = img.header
shape = img.shape
dtype = hdr.get_data_dtype()
zooms = hdr.get_zooms()
spatial_unit = hdr.get_xyzt_units()[0]

n_bytes = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize
voxel = " x ".join(f"{z:.3g}" for z in zooms[:3])
voxel += f" {spatial_unit}" if spatial_unit != "unknown" else " (units unknown)"
orientation = "".join(nib.aff2axcodes(img.affine))

_logger.info("%s: %s", label, path)
_logger.info(
"%s: shape=%s, dtype=%s, size=%s, voxel size=%s",
label,
shape,
dtype,
_human_bytes(n_bytes),
voxel,
)
_logger.info(
"%s: orientation=%s, sform=%s, qform=%s",
label,
orientation,
_space_label(int(hdr["sform_code"])),
_space_label(int(hdr["qform_code"])),
)

if len(shape) > 3:
raw_tr = float(hdr["pixdim"][4])
tr = f"{raw_tr:.4g} s" if raw_tr > 0 else "unknown"
# dim_info names the slice axis; BOLD usually omits it, so fall
# back to the conventional third axis.
slice_axis = hdr.get_dim_info()[2]
if slice_axis is not None:
n_slices, axis_desc = shape[slice_axis], f"axis {slice_axis}"
else:
n_slices, axis_desc = shape[2], "axis 2 (assumed; no dim_info)"
slice_order = hdr.get_value_label("slice_code") # "unknown" if unset
extra = f", extra dims={tuple(shape[4:])}" if len(shape) > 4 else ""
_logger.info(
"%s: volumes=%d, slices=%d along %s, slice order=%s%s, header TR=%s",
label,
shape[3],
n_slices,
axis_desc,
slice_order,
extra,
tr,
)
except Exception as exc:
# Diagnostics must never abort a run; the real failure surfaces later.
_logger.warning("%s: could not read NIfTI header for %s (%s)", label, path, exc)


def strip_afni_volatile_metadata(path: str | Path) -> None:
"""Rewrite a NIfTI file with AFNI's non-deterministic extension removed.

Expand Down
3 changes: 2 additions & 1 deletion src/rbc/orchestration/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rbc.bids.anatomical import discover_anatomical, export_anatomical
from rbc.bids.session import load_session
from rbc.context import RunContext
from rbc.core.nifti import log_image_summary
from rbc.orchestration import Filters, RunnerConfig, init_runner
from rbc.workflows.anatomical import AnatomicalOutputs, single_session_preprocess
from rbc_resources import (
Expand Down Expand Up @@ -48,7 +49,7 @@ def process_session(
"""
outputs: AnatomicalOutputs | None = None
for anat_run in discover_anatomical(session):
_logger.info("Anatomical: %s", anat_run.path)
log_image_summary(anat_run.path, label="Anatomical T1w")
outputs = single_session_preprocess(
in_t1w=anat_run.path,
brain_extraction_templates=brain_extraction_templates,
Expand Down
3 changes: 2 additions & 1 deletion src/rbc/orchestration/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from rbc.bids.session import load_session
from rbc.context import RunContext
from rbc.core.nifti import log_image_summary
from rbc.metadata import FunctionalMetadata
from rbc.orchestration import Filters, RunnerConfig, init_runner
from rbc.workflows.functional import single_session_preprocess
Expand Down Expand Up @@ -64,7 +65,7 @@ def process_session(
"""
results = []
for func_run in discover_functional(session):
_logger.info("Functional: %s", func_run.path)
log_image_summary(func_run.path, label="Functional BOLD")

if anat_inputs is not None:
resolved = anat_inputs
Expand Down
1 change: 1 addition & 0 deletions tests/unit/orchestration/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _patch_process_session() -> Generator[tuple[Mock, Mock, FunctionalRun], None
"rbc.orchestration.functional.discover_functional",
return_value=[func_run],
),
patch("rbc.orchestration.functional.log_image_summary"),
patch(
"rbc.orchestration.functional.resolve_functional",
return_value=_ANAT_INPUTS,
Expand Down
122 changes: 122 additions & 0 deletions tests/unit/test_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import nibabel as nib
Expand All @@ -12,6 +13,7 @@
Space,
Units,
Volume,
log_image_summary,
nifti_num_slices,
nifti_num_volumes,
)
Expand Down Expand Up @@ -626,3 +628,123 @@ def test_nifti_num_volumes_3d(self, nifti_3d: Path) -> None:
def test_nifti_num_slices_3d(self, nifti_3d: Path) -> None:
"""3D image reports correct slice count."""
assert nifti_num_slices(nifti_3d) == 7


class TestLogImageSummary:
"""Tests for log_image_summary()."""

def test_3d_summary(self, nifti_3d: Path, caplog: pytest.LogCaptureFixture) -> None:
"""3D input logs shape, dtype, size, voxel size, orientation, spaces."""
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(nifti_3d, label="Anatomical T1w")
text = "\n".join(caplog.messages)
assert "Anatomical T1w" in text
assert "shape=(5, 6, 7)" in text
assert "dtype=float64" in text
assert "size=1.6 KiB" in text # 5*6*7 * 8 = 1680 bytes
assert "voxel size=1 x 1 x 1 mm" in text
assert "orientation=RAS" in text
assert "sform=MNI" in text
assert "qform=MNI" in text

def test_3d_summary_omits_4d_fields(
self, nifti_3d: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""3D input does not log volume/slice/TR fields."""
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(nifti_3d)
assert not any("volumes=" in m for m in caplog.messages)
assert not any("TR=" in m for m in caplog.messages)

def test_4d_summary_includes_volumes_and_tr(
self, nifti_4d: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""4D input logs volume count, slice axis/count/order, and header TR."""
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(nifti_4d, label="Functional BOLD")
text = "\n".join(caplog.messages)
assert "shape=(5, 6, 7, 10)" in text
assert "size=16.4 KiB" in text # 5*6*7*10 * 8 bytes
assert "volumes=10" in text
assert "slices=7 along axis 2 (assumed; no dim_info)" in text
assert "slice order=unknown" in text
assert "header TR=2 s" in text
assert "extra dims" not in text

def test_4d_slice_axis_and_order_from_header(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""A header that sets dim_info / slice_code has them reported."""
rng = np.random.default_rng(0)
img = nib.Nifti1Image(rng.standard_normal((4, 5, 6, 3)), np.eye(4))
hdr = img.header
hdr.set_dim_info(slice=1)
hdr["slice_code"] = 1 # sequential increasing
pixdim = hdr["pixdim"].copy()
pixdim[4] = 2.0
hdr["pixdim"] = pixdim
path = tmp_path / "slices.nii.gz"
img.to_filename(str(path))

caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(path)
text = "\n".join(caplog.messages)
assert "slices=5 along axis 1" in text
assert "slice order=sequential increasing" in text

def test_5d_reports_extra_dims(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""5D input reports the trailing dims rather than mislabeling them."""
path = _make_nifti(tmp_path, "multi.nii.gz", (4, 5, 6, 7, 2))
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(path)
assert any("extra dims=(2,)" in m for m in caplog.messages)

def test_dtype_and_size_reflect_on_disk_type(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Logged dtype/size use the on-disk dtype, not float64 get_fdata()."""
path = _make_nifti(tmp_path, "int16.nii.gz", (4, 5, 6), dtype=np.int16)
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(path)
text = "\n".join(caplog.messages)
assert "dtype=int16" in text
assert "size=240 B" in text # 4*5*6 * 2 bytes

def test_size_uses_binary_units(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Data size scales to binary units."""
path = _make_nifti(tmp_path, "big.nii.gz", (64, 64, 64), dtype=np.int16)
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(path)
assert any("size=512.0 KiB" in m for m in caplog.messages)

def test_unknown_units_flagged(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Voxel size notes when spatial units are unset in the header."""
path = _make_nifti(tmp_path, "nounit.nii.gz", (4, 5, 6), xyzt_units=0)
caplog.set_level(logging.INFO, logger="rbc.core.nifti")
log_image_summary(path)
assert any("voxel size=1 x 1 x 1 (units unknown)" in m for m in caplog.messages)

def test_emitted_at_info_level(
self, nifti_3d: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Summary is emitted at INFO level (suppressed by default)."""
caplog.set_level(logging.WARNING, logger="rbc.core.nifti")
log_image_summary(nifti_3d)
assert caplog.messages == []

def test_unreadable_file_warns_without_raising(
self, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""A missing/corrupt file logs a warning instead of aborting the run."""
caplog.set_level(logging.WARNING, logger="rbc.core.nifti")
log_image_summary(tmp_path / "does_not_exist.nii.gz", label="Anatomical T1w")
assert any(
"could not read NIfTI header" in m and m.startswith("Anatomical T1w")
for m in caplog.messages
)
Loading