Skip to content

Commit 6df44d7

Browse files
committed
Refactor longitudinal script to depend on rbc
1 parent fe3c61a commit 6df44d7

1 file changed

Lines changed: 77 additions & 182 deletions

File tree

scripts/build_robust_template.py

Lines changed: 77 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# /// script
22
# dependencies = [
3-
# "bids2table>=2.1.2",
43
# "niwrap>=0.9.1",
54
# "polars>=1.38.1",
6-
# "styxpodman",
5+
# "rbc",
76
# "tqdm>=4.67.3",
87
# ]
9-
# requires-python = ">=3.11"
8+
# requires-python = ">=3.12"
109
#
1110
# [tool.uv.sources]
12-
# styxpodman = { git = "https://github.com/styx-api/styxpodman", rev = "1382977" }
11+
# rbc = { git = "https://github.com/childmindresearch/rbc-mirror" }
1312
#
1413
# ///
1514
"""Generate a robust, longitudinal T1w template using Freesurfer's mri_robust_template.
@@ -21,33 +20,25 @@
2120
from __future__ import annotations
2221

2322
import argparse
24-
import logging
23+
from dataclasses import dataclass
2524
import os
26-
import shutil
27-
import tempfile
2825
from pathlib import Path
2926
from typing import TYPE_CHECKING, NamedTuple
3027

31-
import bids2table as b2t
3228
import polars as pl
33-
from niwrap import (
34-
c3d, # uv script dependency with private repo; using c3d to convert transform
35-
Runner,
36-
freesurfer,
37-
get_global_runner,
38-
set_global_runner,
39-
use_docker,
40-
use_local,
41-
use_singularity,
42-
)
43-
from styxpodman import PodmanRunner
29+
from niwrap import c3d, Runner, freesurfer
30+
from rbc.cli import _DEFAULT_ENV_VARS
31+
from rbc.cli.base import BaseArgs
32+
from rbc.cli.main import _global_opts
33+
from rbc.context import PipelineContext
34+
from rbc.core.bids import Datatype, Suffix
35+
from rbc.core.bids2table import load_table
36+
from rbc.core.niwrap import setup_runner
4437
from tqdm import tqdm
4538

4639
if TYPE_CHECKING:
4740
from collections.abc import Sequence
48-
from typing import Literal
4941

50-
_LOG_LEVELS = [logging.WARNING, logging.INFO, logging.DEBUG]
5142
CONTAINER_LICENSE_PATH = "/opt/freesurfer/license.txt"
5243

5344

@@ -57,14 +48,7 @@ def create_parser() -> argparse.ArgumentParser:
5748
prog="create_template",
5849
description="Create a robust template using Freesurfer's mri_robust_template",
5950
formatter_class=argparse.RawDescriptionHelpFormatter,
60-
usage="%(prog)s in_files [in_files...] output_file [options]",
61-
)
62-
parser.add_argument(
63-
"-v",
64-
"--verbose",
65-
action="count",
66-
default=0,
67-
help="Increase verbosity (can be repeated: -v, -vv, -vvv)",
51+
usage="%(prog)s input_dir output_dir [options]",
6852
)
6953
parser.add_argument(
7054
"input_dir",
@@ -76,103 +60,17 @@ def create_parser() -> argparse.ArgumentParser:
7660
type=Path,
7761
help="Directory where output data should be stored",
7862
)
63+
_global_opts()
7964
parser.add_argument(
8065
"--fs-license",
8166
required=False,
8267
type=Path,
8368
help="Path to Freesurfer license",
8469
)
85-
parser.add_argument(
86-
"--participant-label",
87-
nargs="+",
88-
default=[],
89-
type=lambda x: x.removeprefix("sub-"),
90-
help="Space-delimited participant identifier ('sub-' prefix can be removed)",
91-
)
92-
parser.add_argument(
93-
"--runner",
94-
choices=["local", "docker", "podman", "singularity"],
95-
default="local",
96-
type=lambda x: x.lower(),
97-
help="NiWrap runner to use for executing workflow",
98-
)
99-
10070
return parser
10171

102-
103-
class StyxContext(NamedTuple):
104-
"""Styx execution context with logger and runner."""
105-
106-
logger: logging.Logger
107-
runner: Runner
108-
verbose: bool
109-
110-
111-
def setup_runner(
112-
runner: Literal["local", "docker", "podman", "singularity"] = "local",
113-
tmp_dir: str | Path | None = None,
114-
image_overrides: dict[str, str] | None = None,
115-
verbose: int = 0,
116-
**kwargs, # noqa: ANN003 (ignore annotation for kwargs)
117-
) -> StyxContext:
118-
"""Setup Styx with appropriate runner for NiWrap.
119-
120-
Args:
121-
runner: Type of runner to use - choices include
122-
['local', 'docker', 'podman', 'singularity']
123-
tmp_dir: Working directory to output to
124-
image_overrides: Dictionary containing overrides for container tags.
125-
verbose: Verbosity level (0=WARNING, 1=INFO, 2+=DEBUG)
126-
**kwargs: Additional keyword arguments passed for runner setup.
127-
128-
Returns:
129-
Configured logger instance and initialized runner
130-
"""
131-
match runner_exec := runner.lower():
132-
case "local":
133-
use_local()
134-
case "docker":
135-
use_docker(
136-
docker_executable=runner_exec,
137-
image_overrides=image_overrides,
138-
docker_user_id=0,
139-
**kwargs,
140-
)
141-
case "podman":
142-
set_global_runner(
143-
runner=PodmanRunner(
144-
podman_executable=runner_exec,
145-
image_overrides=image_overrides,
146-
podman_user_id=0,
147-
**kwargs,
148-
)
149-
)
150-
case "singularity":
151-
use_singularity(
152-
singularity_executable=runner_exec,
153-
image_overrides=image_overrides,
154-
**kwargs,
155-
)
156-
case _:
157-
raise NotImplementedError(
158-
f"Unknown runner selection '{runner}' - please select one of "
159-
"'local', 'docker', or 'singularity'"
160-
)
161-
162-
styx_runner = get_global_runner()
163-
if tmp_dir is None:
164-
tmp_dir = Path(tempfile.gettempdir()) / f"robust_template_{os.urandom(8).hex()}"
165-
tmp_dir.mkdir(exist_ok=False, parents=True)
166-
styx_runner.data_dir = tmp_dir
167-
logger = logging.getLogger(styx_runner.logger_name)
168-
log_level = min(verbose, len(_LOG_LEVELS) - 1)
169-
logger.setLevel(_LOG_LEVELS[log_level])
170-
return StyxContext(logger=logger, runner=styx_runner, verbose=verbose > 0)
171-
172-
173-
def _get_mount_arg(runner: str, host_path: Path) -> list[str]:
72+
def _get_mount_arg(runner: str, src: str, dst: str) -> list[str]:
17473
"""Return runner-specific mount CLI args."""
175-
src, dst = str(host_path), CONTAINER_LICENSE_PATH
17674
if runner in ("podman", "docker"):
17775
return ["--mount", f"type=bind,source={src},target={dst},readonly"]
17876
return ["--bind", f"{src}:{dst}"] # singularity
@@ -181,14 +79,15 @@ def _get_mount_arg(runner: str, host_path: Path) -> list[str]:
18179
def mount_fs_license(runner: Runner, fs_license: str) -> None:
18280
"""Mount FreeSurfer license file into an existing runner."""
18381
runner_name = type(runner).__name__.lower().replace("runner", "")
184-
license_path = Path(fs_license).resolve()
18582

18683
if runner_name == "local":
187-
os.environ["FS_LICENSE"] = str(license_path)
84+
os.environ["FS_LICENSE"] = fs_license
18885
return
18986

19087
extra_args_attr = f"{runner_name}_extra_args"
191-
getattr(runner, extra_args_attr).extend(_get_mount_arg(runner_name, license_path))
88+
getattr(runner, extra_args_attr).extend(
89+
_get_mount_arg(runner=runner_name, src=fs_license, dst=CONTAINER_LICENSE_PATH)
90+
)
19291
runner.environ["FS_LICENSE"] = CONTAINER_LICENSE_PATH
19392

19493

@@ -211,26 +110,16 @@ def generate_robust_template(in_files: Sequence[Path]) -> RobustTemplateOutputs:
211110
"""
212111
lta_files = []
213112
entities = None
214-
for in_file in in_files:
215-
if not Path(in_file).exists():
216-
raise FileNotFoundError(f"{in_file} not found.")
217-
entities = b2t.parse_bids_entities(in_file)
218-
lta_fname = b2t.format_bids_path(
219-
{
220-
"sub": entities["sub"],
221-
"ses": "longitudinal",
222-
"from": entities["ses"],
223-
"suffix": "xfm",
224-
"ext": ".lta",
225-
}
226-
).name
227-
lta_files.append(lta_fname)
113+
for idx, in_file in enumerate(in_files):
114+
if not in_file.exists():
115+
raise FileNotFoundError(f"Input file not found: {in_file}.")
116+
lta_files.append(f"xfm_{idx:04d}.lta")
228117

229118
# Initialize with same defaults as fmriprep
230119
assert entities is not None, "No entities found"
231120
robust_template = freesurfer.mri_robust_template(
232121
mov=list(in_files),
233-
template=f"sub-{entities['sub']}_ses-longitudinal_T1w.nii.gz",
122+
template=f"long_template_T1w.nii.gz",
234123
lta=lta_files,
235124
inittp=1, # map everything to first time point
236125
fixtp=True,
@@ -277,59 +166,63 @@ def fs_to_ants_xfm(
277166
return result.transforms
278167

279168

280-
if __name__ == "__main__":
281-
parser = create_parser()
282-
args = parser.parse_args()
169+
@dataclass(frozen=True)
170+
class TemplateArgs(BaseArgs):
171+
"""Arguments for template-building CLI."""
172+
173+
fs_license: str
174+
175+
@classmethod
176+
def validate_namespace(cls, ns: argparse.Namespace) -> TemplateArgs:
177+
"""Validation of template-building script specific arguments to NamedTuple."""
178+
fs_license = ns.fs_license or os.getenv("FS_LICENSE")
179+
if fs_license is None or not Path(fs_license).exists():
180+
raise ValueError(f"FreeSurfer license file not found: {fs_license}")
181+
return cls(
182+
**BaseArgs.validate_namespace(ns).__dict__,
183+
fs_license=str(fs_license)
184+
)
283185

186+
def process(args: TemplateArgs) -> int:
187+
"""Main processing layer for script."""
284188
# 1. Setup
285-
fs_license = args.fs_license or os.getenv("FS_LICENSE")
286-
if fs_license is None or not Path(fs_license).exists():
287-
raise FileNotFoundError(f"Freesurfer license not found: {fs_license}")
288-
ctx = setup_runner(runner=args.runner, verbose=args.verbose)
289-
# Taken from rbc's _DEFAULT_ENVS (uses CPAC ANTs seed)
290-
ctx.runner.environ = {
291-
"ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS": "1",
292-
"ANTS_RANDOM_SEED": 77742777,
293-
"FSLOUTPUTTYPE": "NIFTI_GZ" # Needed for FreeSurfer xfm conversion
294-
}
189+
ctx = setup_runner(runner=args.runner, verbose=args.verbose, tmp_dir=args.tmp_dir)
190+
ctx.runner.environ = {**_DEFAULT_ENV_VARS, "FSLOUTPUTTYPE": "NIFTI_GZ"}
295191
ctx.logger.warning(
296-
"This script is experimental and may be sensitive to input file naming "
297-
"conventions."
192+
"This script is experimental and may be sensitive to file naming conventions."
298193
)
299-
mount_fs_license(ctx.runner, fs_license)
194+
mount_fs_license(ctx.runner, args.fs_license)
300195

301196
ctx.logger.info("Preparing to generate longitudinal templates")
302-
tables = b2t.batch_index_dataset(
303-
b2t.find_bids_datasets(args.input_dir),
304-
max_workers=0,
305-
show_progress=ctx.verbose,
197+
df = load_table(
198+
dataset_dir=args.input_dir, index_fpath=None, max_workers=0, verbose=ctx.verbose
306199
)
307-
dfs: list[pl.DataFrame] = []
308-
for table in tables:
309-
result = pl.from_arrow(table)
310-
if not isinstance(result, pl.DataFrame):
311-
raise TypeError(f"Expected DataFrame, got {type(result)}")
312-
dfs.append(result)
313-
df = pl.concat(dfs)
314-
# Filters for preprocessed T1w to create longitudinal template
200+
315201
filters = [
316202
pl.col("ses") != "longitudinal",
317203
pl.col("datatype") == "anat",
204+
pl.col("space").is_null(),
318205
pl.col("desc") == "brain",
319206
pl.col("suffix") == "T1w",
320207
]
321208
if len(args.participant_label) > 0:
322209
filters.append(pl.col("sub").is_in(args.participant_label))
210+
if len(args.session_label) > 0:
211+
filters.append(pl.col("ses").is_in(args.session_label))
323212
df = df.filter(pl.all_horizontal(filters))
324-
del dfs
325213

326-
ctx.logger.info("Starting processing")
327-
if len(df) == 1:
328-
raise ValueError("Only a single volume found")
329-
for _, sub_group in tqdm(df.group_by("sub"), disable=not ctx.verbose):
214+
for _, sub_group in tqdm(
215+
df.group_by(("sub"), maintain_order=True), disable=not ctx.verbose
216+
):
217+
sessions = sub_group["ses"].to_list()
218+
pipe_ctx = PipelineContext(
219+
sub=sub_group["sub"][0], ses=None, output_dir=args.output_dir
220+
)
221+
330222
# 2. Construct template
331-
sub = sub_group["sub"][0]
332-
ctx.logger.info(f"Building robust template for sub-{sub}")
223+
ctx.logger.info(f"Building robust template for subject: {pipe_ctx.sub}")
224+
if len(sub_group) <= 1:
225+
raise ValueError("At least 2 volumes needed to generate a template.")
333226
in_files = [
334227
Path(row["root"]) / row["path"] for row in sub_group.iter_rows(named=True)
335228
]
@@ -344,15 +237,17 @@ def fs_to_ants_xfm(
344237
)
345238

346239
# 4. Save outputs
347-
ctx.logger.info("Saving files")
348-
output_dir = (
349-
Path(args.output_dir)
350-
/ b2t.format_bids_path(
351-
{"sub": sub, "ses": "longitudinal", "datatype": "anat"}
352-
).parent
353-
)
354-
output_dir.mkdir(exist_ok=True, parents=True)
355-
for fpath in [robust_template.template, *subj_to_temp]:
356-
shutil.copy2(fpath, output_dir)
357-
ctx.logger.info("Robust template creation complete")
358-
ctx.logger.info("Completed creating all templates.")
240+
long = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")
241+
long.save(robust_template.template, suffix=Suffix.T1W, desc="brain")
242+
for idx, fpath in enumerate(subj_to_temp):
243+
long.save(fpath, session=sessions[idx], suffix="xfm", extension=".lta")
244+
245+
ctx.logger.info("Robust template creation complete")
246+
return 0
247+
248+
249+
if __name__ == "__main__":
250+
parser = create_parser()
251+
args = parser.parse_args()
252+
253+
process(TemplateArgs.validate_namespace(args))

0 commit comments

Comments
 (0)