Skip to content
2 changes: 1 addition & 1 deletion src/rbc/bids/longitudinal/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def resolve_longitudinal_anat(
Dict with keys matching ``longitudinal_process`` parameters.
"""
return {
"template": tpl_q.expect(tpl_df, suffix=Suffix.T1W),
"template": tpl_q.expect(tpl_df, suffix=Suffix.T1W, without=["res"]),
"subj_to_template_xfm": tpl_q.expect(
tpl_df,
suffix="xfm",
Expand Down
4 changes: 3 additions & 1 deletion src/rbc/bids/longitudinal/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def resolve_longitudinal_func(
tpl_df: pl.DataFrame,
*,
ses: str,
task: str,
regressors: Sequence[str] = ("36-parameter",),
) -> dict[str, Path | dict[str, Path]]:
"""Resolve inputs for longitudinal functional processing.
Expand All @@ -33,6 +34,7 @@ def resolve_longitudinal_func(
func_df: DataFrame of functional derivatives.
tpl_df: DataFrame of longitudinal template files.
ses: Session label (used for template xfm lookup).
task: Task entity value to denote BOLD reference for template resampling.
regressors: Regressor strategy names to resolve raw regressor
files for.

Expand All @@ -51,7 +53,7 @@ def resolve_longitudinal_func(
)

return {
"template": tpl_q.expect(tpl_df, suffix="T1w"),
"template": tpl_q.expect(tpl_df, suffix="T1w", res=task),
"anat_to_template_xfm": tpl_q.expect(
tpl_df,
suffix="xfm",
Expand Down
37 changes: 35 additions & 2 deletions src/rbc/bids/longitudinal/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import polars as pl

from rbc.bids import Suffix, bids_safe_label
from rbc.bids import FUNC_GROUP_ENTITIES, Suffix, bids_safe_label

if TYPE_CHECKING:
from rbc.bids import Bids
Expand All @@ -21,11 +21,13 @@ class TemplateInputs(NamedTuple):
sub: Subject label.
sessions: Per-input session labels (parallel to ``files``).
files: Per-session preprocessed T1w brain volumes.
bold_ref: Per-session preprocessed BOLD volumes (all tasks).
Comment thread
kaitj marked this conversation as resolved.
Outdated
"""

sub: str
sessions: list[str]
files: list[Path]
bold_files: dict[str, Path]


def discover_template_inputs(
Expand Down Expand Up @@ -56,6 +58,13 @@ def discover_template_inputs(
# the mri_robust_template invocation.
pl.col("space").is_null(),
)
bold_rows = df.filter(
pl.col("ses") != "longitudinal",
pl.col("datatype") == "func",
pl.col("desc") == "preproc",
pl.col("suffix") == "bold",
pl.col("space").is_null(),
)

inputs: list[TemplateInputs] = []
skipped: list[str] = []
Expand All @@ -68,7 +77,29 @@ def discover_template_inputs(
files = [
Path(row["root"]) / row["path"] for row in sub_group.iter_rows(named=True)
]
inputs.append(TemplateInputs(sub=sub, sessions=sessions, files=files))
sub_bold = bold_rows.filter(
(pl.col("sub") == sub) & (pl.col("ses") == sessions[0])
).unique(subset=(*FUNC_GROUP_ENTITIES, "root", "path"))
# Check each task is unique, otherwise raise assertion error with details
if sub_bold.height != sub_bold.unique().height:
Comment thread
kaitj marked this conversation as resolved.
Outdated
conflicts = (
sub_bold.filter(pl.struct(FUNC_GROUP_ENTITIES).is_duplicated())
.group_by(FUNC_GROUP_ENTITIES)
.agg(pl.format("{}/{}", "root", "path").alias("paths"))
)
raise AssertionError(
f"Found multiple non-matching grids for subject {sub}:\n"
+ "\n".join(str(dict(row)) for row in conflicts.iter_rows(named=True))
)
bold_files = {
row["task"]: Path(row["root"]) / row["path"]
Comment thread
kaitj marked this conversation as resolved.
Outdated
for row in sub_bold.iter_rows(named=True)
}
Comment thread
kaitj marked this conversation as resolved.
Outdated
inputs.append(
TemplateInputs(
sub=sub, sessions=sessions, files=files, bold_files=bold_files
)
)
return inputs, skipped


Expand All @@ -81,6 +112,8 @@ def export_template(tpl: Bids, outputs: LongitudinalTemplateOutputs) -> None:
outputs: Results from the longitudinal template workflow.
"""
tpl.save(outputs.template, suffix=Suffix.T1W)
for btask, bold_template in outputs.bold_templates.items():
tpl.save(bold_template, res=btask, suffix=Suffix.T1W)
Comment thread
kaitj marked this conversation as resolved.
Outdated
for ses, xfm in zip(outputs.sessions, outputs.transforms, strict=True):
ses_label = bids_safe_label(ses)
tpl.save(
Expand Down
42 changes: 42 additions & 0 deletions src/rbc/core/longitudinal/resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Resampling utilities for longitudinal templates."""

from __future__ import annotations

from typing import TYPE_CHECKING

import nibabel as nib
from nibabel.processing import resample_from_to

if TYPE_CHECKING:
from pathlib import Path

from rbc.core.niwrap import generate_exec_folder


def resample_img_to_bold_grid(bold_ref: Path, img: Path, order: int = 3) -> Path:
"""Resample template to BOLD grid if shapes differ.

Args:
bold_ref: BOLD reference volume.
img: 3D image in target space to resample.
order: Interpolation order used during resampling

Returns:
Resampled 3D image with BOLD grid
"""
bold_ref_img = nib.nifti1.load(bold_ref)
img_obj = nib.nifti1.load(img)

# If 4D, extract first volume
if len(bold_ref_img.shape) > 3:
bold_ref_img = bold_ref_img.slicer[..., 0]
# If same shape, no need to resample
if bold_ref_img.shape == img_obj.shape:
return img
Comment thread
kaitj marked this conversation as resolved.
Outdated

img_resampled = resample_from_to(img_obj, bold_ref_img, order=order)
img_resampled_path = (
generate_exec_folder("img_resample_to_bold_grid") / "resampled.nii.gz"
)
nib.save(img_resampled, img_resampled_path)
return img_resampled_path
4 changes: 2 additions & 2 deletions src/rbc/orchestration/longitudinal/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from rbc.bids import FUNC_GROUP_ENTITIES, Datatype, Suffix, extract_entities, load_table
from rbc.bids.longitudinal.template import discover_template_inputs
from rbc.bids.metrics import export_metrics
from rbc.bids.session import iter_session_files
from rbc.bids.session import _FUNC_ENTITY_KEYS, iter_session_files
from rbc.context import RunContext
from rbc.orchestration import Filters, RunnerConfig, init_runner
from rbc.orchestration.longitudinal._iter import iter_sessions_with_template
Expand Down Expand Up @@ -124,7 +124,7 @@ def run(
)

row = func_df.filter(suffix=Suffix.BOLD).row(0, named=True)
ents = extract_entities(row, ["task", "run"])
ents = extract_entities(row, _FUNC_ENTITY_KEYS)
func_q = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents)
func_long = func_q.derive(space="longitudinal")

Expand Down
2 changes: 1 addition & 1 deletion src/rbc/orchestration/longitudinal/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def process_anat(
Workflow outputs for in-memory handoff to downstream stages.
"""
anat_df = anat_df.filter(pl.col("space").is_null())
ents = extract_entities(anat_df.row(0, named=True), ["run"])
ents = extract_entities(anat_df.row(0, named=True), ["run", "acq", "rec", "echo"])

anat_q = pipe_ctx.bids(datatype=Datatype.ANAT)
tpl_q = anat_q.derive(ses="longitudinal")
Expand Down
5 changes: 3 additions & 2 deletions src/rbc/orchestration/longitudinal/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
export_longitudinal_func,
resolve_longitudinal_func,
)
from rbc.bids.session import iter_session_files
from rbc.bids.session import _FUNC_ENTITY_KEYS, iter_session_files
from rbc.orchestration import Filters, RunnerConfig, init_runner
from rbc.orchestration.longitudinal._iter import iter_sessions_with_template
from rbc.workflows.longitudinal.functional import (
Expand Down Expand Up @@ -53,7 +53,7 @@ def process_func(
Workflow outputs for in-memory handoff to downstream stages.
"""
row = func_df.filter(suffix=Suffix.BOLD).row(0, named=True)
ents = extract_entities(row, ["task", "run"])
ents = extract_entities(row, list(_FUNC_ENTITY_KEYS))

func_q = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents)
tpl_q = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")
Expand All @@ -64,6 +64,7 @@ def process_func(
func_df,
tpl_df,
ses=pipe_ctx.ses, # type: ignore[arg-type]
task=ents["task"],
regressors=regressors,
)
func_outputs = functional_longitudinal(**resolved) # type: ignore[arg-type]
Expand Down
5 changes: 5 additions & 0 deletions src/rbc/orchestration/longitudinal/qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
write_longitudinal_qc_tsv,
)
from rbc.bids.session import iter_session_files
from rbc.core.longitudinal.resampling import resample_img_to_bold_grid
from rbc.core.niwrap import generate_exec_folder
from rbc.core.qc.registration import registration_qc_metrics
from rbc.orchestration import Filters, RunnerConfig, init_runner
Expand Down Expand Up @@ -67,6 +68,10 @@ def process_qc(
Returns:
QC outputs with overlap metrics and pass/fail flag.
"""
# Resample longitudinal anatomical mask to bold grid for QC purposes.
# Longitudinal processed data are registered to the longitudinal template with
# respective modality's native resolution
anat_brain_mask = resample_img_to_bold_grid(bold_mask, anat_brain_mask, order=0)
anat_mask_arr = nib.nifti1.load(anat_brain_mask).get_fdata()
bold_mask_arr = nib.nifti1.load(bold_mask).get_fdata()
reg_metrics = registration_qc_metrics(anat_mask_arr, bold_mask_arr)
Expand Down
1 change: 1 addition & 0 deletions src/rbc/orchestration/longitudinal/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def process_subject(
sub=inputs.sub,
sessions=inputs.sessions,
in_files=inputs.files,
bold_files=inputs.bold_files,
)
tpl = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")
export_template(tpl, outputs)
Expand Down
15 changes: 14 additions & 1 deletion src/rbc/workflows/longitudinal/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
fs_to_itk_xfm,
generate_robust_template,
)
from rbc.core.longitudinal.resampling import resample_img_to_bold_grid

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from pathlib import Path

_logger = logging.getLogger("rbc")
Expand All @@ -27,11 +28,14 @@ class LongitudinalTemplateOutputs(NamedTuple):

Attributes:
template: Robust within-subject template volume.
bold_templates: Within-subject template volumes resampled to task-specific BOLD
resolutions.
sessions: Session labels in the same order as ``transforms``.
transforms: Per-session ITK-format session-to-template transforms.
"""

template: Path
bold_templates: dict[str, Path]
sessions: list[str]
transforms: list[Path]

Expand All @@ -40,13 +44,15 @@ def generate_subject_template(
sub: str,
sessions: Sequence[str],
in_files: Sequence[Path],
bold_files: Mapping[str, Path],
) -> LongitudinalTemplateOutputs:
"""Build a robust template and ITK transforms for one subject.

Args:
sub: Subject label (without the ``sub-`` prefix).
sessions: Session labels parallel to ``in_files``.
in_files: Per-session preprocessed T1w volumes (e.g. brain-extracted).
bold_files: Reference bold volumes to resample template for functional data.

Returns:
:class:`LongitudinalTemplateOutputs` ready for BIDS export.
Expand All @@ -65,8 +71,15 @@ def generate_subject_template(
in_xfms=robust.transforms,
)

_logger.info("Creating reference volumes for each functional task")
bold_templates = {
btask: resample_img_to_bold_grid(bfile, robust.template)
for btask, bfile in bold_files.items()
}

return LongitudinalTemplateOutputs(
template=robust.template,
bold_templates=bold_templates,
sessions=list(sessions),
transforms=itk_xfms,
)
1 change: 1 addition & 0 deletions tests/unit/bids/test_longitudinal_anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _anat_row(
"sub": sub,
"ses": ses,
"desc": desc,
"res": None,
"root": "/data",
"path": path,
"extra_entities": extra or [],
Expand Down
Loading
Loading