Skip to content

Commit c5dbe2a

Browse files
authored
Fix dataset indexing and session loading in cli (#231)
1 parent e41d111 commit c5dbe2a

8 files changed

Lines changed: 145 additions & 77 deletions

File tree

src/rbc/cli/all.py

Lines changed: 50 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,12 @@
1414
import polars as pl
1515
from tqdm import tqdm
1616

17-
from rbc.cli import _DEFAULT_ENV_VARS, _FUNC_GROUP_ENTITIES, _SUB_SES_QUERY
17+
from rbc.cli import (
18+
_ANAT_GROUP_ENTITIES,
19+
_DEFAULT_ENV_VARS,
20+
_FUNC_GROUP_ENTITIES,
21+
_SUB_SES_QUERY,
22+
)
1823
from rbc.cli.base import BaseArgs, _validate_atlas, _validate_positive, _validate_task
1924
from rbc.cli.query import iter_session_files, load_session
2025
from rbc.context import PipelineContext
@@ -77,17 +82,22 @@ def main(args: AllArgs) -> int: # noqa: C901
7782
dataset_dir=args.input_dir, index_fpath=None, max_workers=0, verbose=ctx.verbose
7883
)
7984

80-
filters = [pl.col("space").is_null(), pl.col("desc").is_null()]
85+
filters = [
86+
pl.col("ses") != "longitudinal",
87+
pl.col("space").is_null(),
88+
pl.col("desc").is_null(),
89+
]
8190
if len(args.participant_label) > 0:
8291
filters.append(pl.col("sub").is_in(args.participant_label))
8392
if len(args.session_label) > 0:
8493
filters.append(pl.col("ses").is_in(args.session_label))
8594
if args.task is not None:
8695
filters.append(pl.col("task") == args.task)
87-
if filters:
88-
df = df.filter(pl.all_horizontal(filters))
96+
df = df.filter(pl.all_horizontal(filters))
8997

90-
for _, sub_ses_group in tqdm(df.group_by(_SUB_SES_QUERY), disable=not ctx.verbose):
98+
for _, sub_ses_group in tqdm(
99+
df.group_by(_SUB_SES_QUERY, maintain_order=True), disable=not ctx.verbose
100+
):
91101
pipe_ctx = PipelineContext(
92102
sub=sub_ses_group["sub"][0],
93103
ses=sub_ses_group["ses"][0] or None,
@@ -96,29 +106,41 @@ def main(args: AllArgs) -> int: # noqa: C901
96106
session = load_session(sub_ses_group, pipe_ctx.sub, pipe_ctx.ses)
97107

98108
# --- Anatomical (once per session, first T1w) ---
99-
anat_row = session.anat.row(0, named=True)
100-
t1w_fpath = Path(anat_row["root"]) / anat_row["path"]
101-
ctx.logger.info(f"Anatomical: {t1w_fpath}")
102-
103-
anat_outputs = anatomical_preprocess(in_t1w=t1w_fpath)
104-
105-
anat = pipe_ctx.bids(datatype=Datatype.ANAT)
106-
anat.save(anat_outputs.brain, suffix=Suffix.T1W, desc="brain")
107-
anat.save(anat_outputs.brain_mask, suffix=Suffix.MASK, desc="T1w")
108-
anat.save(anat_outputs.csf_mask, suffix=Suffix.MASK, desc="csf")
109-
anat.save(anat_outputs.gm_mask, suffix=Suffix.MASK, desc="gm")
110-
anat.save(anat_outputs.wm_mask, suffix=Suffix.MASK, desc="wm")
111-
anat.save(anat_outputs.wm_bbr_mask, suffix=Suffix.MASK, desc="wmBBR")
112-
anat.save(
113-
anat_outputs.forward_xfm,
114-
suffix="xfm",
115-
extra={"from": "T1w", "to": TemplateSpace.MNI152NLIN6ASYM, "mode": "image"},
116-
)
117-
anat.save(
118-
anat_outputs.inverse_xfm,
119-
suffix="xfm",
120-
extra={"from": TemplateSpace.MNI152NLIN6ASYM, "to": "T1w", "mode": "image"},
121-
)
109+
for _, anat_df in session.anat.filter(pl.col("suffix") == "T1w").group_by(
110+
_ANAT_GROUP_ENTITIES, maintain_order=True
111+
):
112+
anat_row = anat_df.filter(suffix="T1w").row(0, named=True)
113+
t1w_fpath = Path(anat_row["root"]) / anat_row["path"]
114+
ents = extract_entities(anat_row, ["run", "acq", "rec", "echo"])
115+
ctx.logger.info(f"Anatomical: {t1w_fpath}")
116+
117+
anat_outputs = anatomical_preprocess(in_t1w=t1w_fpath)
118+
119+
anat = pipe_ctx.bids(datatype=Datatype.ANAT, entities=ents)
120+
anat.save(anat_outputs.brain, suffix=Suffix.T1W, desc="brain")
121+
anat.save(anat_outputs.brain_mask, suffix=Suffix.MASK, desc="T1w")
122+
anat.save(anat_outputs.csf_mask, suffix=Suffix.MASK, desc="csf")
123+
anat.save(anat_outputs.gm_mask, suffix=Suffix.MASK, desc="gm")
124+
anat.save(anat_outputs.wm_mask, suffix=Suffix.MASK, desc="wm")
125+
anat.save(anat_outputs.wm_bbr_mask, suffix=Suffix.MASK, desc="wmBBR")
126+
anat.save(
127+
anat_outputs.forward_xfm,
128+
suffix="xfm",
129+
extra={
130+
"from": "T1w",
131+
"to": TemplateSpace.MNI152NLIN6ASYM,
132+
"mode": "image",
133+
},
134+
)
135+
anat.save(
136+
anat_outputs.inverse_xfm,
137+
suffix="xfm",
138+
extra={
139+
"from": TemplateSpace.MNI152NLIN6ASYM,
140+
"to": "T1w",
141+
"mode": "image",
142+
},
143+
)
122144

123145
# --- Functional + Metrics + QC (per BOLD run) ---
124146
for func_df, _anat_df in iter_session_files(

src/rbc/cli/anatomical.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass
66
from typing import TYPE_CHECKING
77

8-
from rbc.cli.query import iter_session_files, load_session
8+
from rbc.cli.query import load_session
99

1010
if TYPE_CHECKING:
1111
import argparse
@@ -46,33 +46,37 @@ def main(args: AnatomicalArgs) -> int:
4646
dataset_dir=args.input_dir, index_fpath=None, max_workers=0, verbose=ctx.verbose
4747
)
4848

49-
filters = [pl.col("space").is_null(), pl.col("desc").is_null()]
49+
filters = [
50+
pl.col("ses") != "longitudinal",
51+
pl.col("space").is_null(),
52+
pl.col("desc").is_null(),
53+
]
5054
if len(args.participant_label) > 0:
5155
filters.append(pl.col("sub").is_in(args.participant_label))
5256
if len(args.session_label) > 0:
5357
filters.append(pl.col("ses").is_in(args.session_label))
54-
if filters:
55-
df = df.filter(pl.all_horizontal(filters))
58+
df = df.filter(pl.all_horizontal(filters))
5659

57-
for _, sub_ses_group in tqdm(df.group_by(_SUB_SES_QUERY), disable=not ctx.verbose):
60+
for _, sub_ses_group in tqdm(
61+
df.group_by(_SUB_SES_QUERY, maintain_order=True), disable=not ctx.verbose
62+
):
5863
pipe_ctx = PipelineContext(
5964
sub=sub_ses_group["sub"][0],
6065
ses=sub_ses_group["ses"][0] or None,
6166
output_dir=args.output_dir,
6267
)
6368
session = load_session(sub_ses_group, pipe_ctx.sub, pipe_ctx.ses)
6469

65-
for _, anat_df in iter_session_files(session, groupby=_ANAT_GROUP_ENTITIES):
66-
row = anat_df.filter(suffix="T1w").row(0, named=True)
70+
for _, anat_df in session.anat.filter(pl.col("suffix") == "T1w").group_by(
71+
_ANAT_GROUP_ENTITIES, maintain_order=True
72+
):
73+
row = anat_df.row(0, named=True)
6774
t1w_fpath = Path(row["root"]) / row["path"]
6875
ents = extract_entities(row, ["run", "acq", "rec", "echo"])
6976
ctx.logger.info(f"Processing {t1w_fpath}")
7077

7178
outputs = single_session_preprocess(in_t1w=t1w_fpath)
7279

73-
pipe_ctx = PipelineContext(
74-
sub=row["sub"], ses=row.get("ses"), output_dir=args.output_dir
75-
)
7680
anat = pipe_ctx.bids(datatype=Datatype.ANAT, entities=ents)
7781
anat.save(outputs.brain, suffix=Suffix.T1W, desc="brain")
7882
anat.save(outputs.brain_mask, suffix=Suffix.MASK, desc="T1w")

src/rbc/cli/functional.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -65,27 +65,30 @@ def main(args: FunctionalArgs) -> int:
6565
dataset_dir=args.input_dir, index_fpath=None, max_workers=0, verbose=ctx.verbose
6666
)
6767

68-
filters = [pl.col("space").is_null(), pl.col("desc").is_null()]
68+
filters = [pl.col("ses") != "longitudinal", pl.col("space").is_null()]
6969
if len(args.participant_label) > 0:
7070
filters.append(pl.col("sub").is_in(args.participant_label))
7171
if len(args.session_label) > 0:
7272
filters.append(pl.col("ses").is_in(args.session_label))
7373
if args.task is not None:
7474
filters.append(pl.col("task") == args.task)
75-
if filters:
76-
df = df.filter(pl.all_horizontal(filters))
75+
df = df.filter(pl.all_horizontal(filters))
7776

78-
for _, sub_ses_group in tqdm(df.group_by(_SUB_SES_QUERY), disable=not ctx.verbose):
77+
for _, sub_ses_group in tqdm(
78+
df.group_by(_SUB_SES_QUERY, maintain_order=True), disable=not ctx.verbose
79+
):
7980
pipe_ctx = PipelineContext(
8081
sub=sub_ses_group["sub"][0],
8182
ses=sub_ses_group["ses"][0] or None,
8283
output_dir=args.output_dir,
8384
)
85+
8486
session = load_session(sub_ses_group, pipe_ctx.sub, pipe_ctx.ses)
8587

8688
for func_df, anat_df in iter_session_files(
8789
session, groupby=_FUNC_GROUP_ENTITIES
8890
):
91+
func_df = func_df.filter(pl.col("desc").is_null())
8992
row = func_df.filter(suffix="bold").row(0, named=True)
9093
bold_fpath = Path(row["root"]) / row["path"]
9194
ents = extract_entities(row, ["task", "run", "acq", "rec", "dir", "echo"])

src/rbc/cli/longitudinal.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -55,11 +55,12 @@ def _process_anat(
5555
pipe_ctx: PipelineContext, anat_df: pl.DataFrame, tpl_df: pl.DataFrame
5656
) -> None:
5757
"""Handle anatomical longitudinal processing."""
58-
row = anat_df.filter(suffix="T1w").row(0, named=True)
58+
anat_df = anat_df.filter(pl.col("space").is_null())
59+
row = anat_df.row(0, named=True)
5960
ents = extract_entities(row, ["run"])
6061

6162
anat_q = pipe_ctx.bids(datatype=Datatype.ANAT)
62-
tpl_q = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")
63+
tpl_q = anat_q.derive(ses="longitudinal")
6364

6465
outputs = anatomical_longitudinal(
6566
template=tpl_q.expect(tpl_df, suffix=Suffix.T1W),
@@ -76,7 +77,7 @@ def _process_anat(
7677
wm_mask=anat_q.find(anat_df, suffix=Suffix.MASK, desc="wm"),
7778
)
7879

79-
aex = pipe_ctx.bids(datatype=Datatype.ANAT, entities=ents, space="longitudinal")
80+
aex = anat_q.derive(entities=ents, space="longitudinal")
8081
aex.save(outputs.brain, suffix=Suffix.T1W, desc="brain")
8182
aex.save(
8283
_require_file(outputs.brain_mask, "brain_mask"),
@@ -108,7 +109,7 @@ def _process_func(
108109
ents = extract_entities(row, ["task", "run"])
109110

110111
func_q = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents)
111-
tpl_q = pipe_ctx.bids(datatype="anat").derive(ses="longitudinal")
112+
tpl_q = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")
112113

113114
outputs = functional_longitudinal(
114115
template=tpl_q.expect(tpl_df, suffix="T1w"),
@@ -134,7 +135,7 @@ def _process_func(
134135
),
135136
)
136137

137-
fex = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents, space="longitudinal")
138+
fex = func_q.derive(space="longitudinal")
138139
fex.save(outputs.sbref, suffix=Suffix.SBREF)
139140
fex.save(outputs.bold, suffix=Suffix.BOLD, desc="preproc")
140141
fex.save(
@@ -163,20 +164,15 @@ def main(args: LongitudinalArgs) -> int:
163164
)
164165

165166
group_df = df
166-
filters = [
167-
pl.col("ses") != "longitudinal",
168-
pl.col("space").is_null(),
169-
pl.col("desc").is_null(),
170-
]
167+
filters = [pl.col("ses") != "longitudinal"]
171168
if len(args.participant_label) > 0:
172169
filters.append(pl.col("sub").is_in(args.participant_label))
173170
if len(args.session_label) > 0:
174171
filters.append(pl.col("ses").is_in(args.session_label))
175-
if filters:
176-
group_df = df.filter(pl.all_horizontal(filters))
172+
group_df = df.filter(pl.all_horizontal(filters))
177173

178174
for _, sub_ses_group in tqdm(
179-
group_df.group_by(_SUB_SES_QUERY), disable=not ctx.verbose
175+
group_df.group_by(_SUB_SES_QUERY, maintain_order=True), disable=not ctx.verbose
180176
):
181177
pipe_ctx = PipelineContext(
182178
sub=sub_ses_group["sub"][0],
@@ -196,12 +192,14 @@ def main(args: LongitudinalArgs) -> int:
196192
if tpl_df.is_empty():
197193
raise ValueError("No longitudinal template found")
198194

199-
for func_df, anat_df in iter_session_files(
200-
session, groupby=_FUNC_GROUP_ENTITIES
201-
):
202-
if args.anatomical:
195+
if args.anatomical:
196+
for _, anat_df in session.anat.filter(pl.col("suffix") == "T1w").group_by(
197+
("run", "acq"), maintain_order=True
198+
):
203199
_process_anat(pipe_ctx=pipe_ctx, anat_df=anat_df, tpl_df=tpl_df)
204-
if args.functional:
200+
201+
if args.functional:
202+
for func_df, _ in iter_session_files(session, groupby=_FUNC_GROUP_ENTITIES):
205203
_process_func(pipe_ctx=pipe_ctx, func_df=func_df, tpl_df=tpl_df)
206204
pipe_ctx.ensure_dataset_description()
207205

tests/unit/cli/test_all.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def _mock_qc_outputs(
9898
@contextmanager
9999
def _patch_all(
100100
filtered_df: pl.DataFrame,
101-
groups: list[list[str]],
101+
groups: list[list[tuple[pl.DataFrame, pl.DataFrame]]],
102102
*,
103103
qc_passed: bool = True,
104104
) -> Generator[tuple[Mock, Mock, Mock, Mock, Mock], None, None]:
@@ -113,17 +113,24 @@ def _patch_all(
113113
"suffix": ["T1w"],
114114
"ext": [".nii.gz"],
115115
"run": [None],
116+
"acq": [None],
117+
"part": [None],
118+
"echo": [None],
119+
"ce": [None],
120+
"rec": [None],
121+
"inv": [None],
116122
"space": [None],
117123
"desc": [None],
118124
"root": ["/data"],
119125
"path": ["sub-01/ses-baseline/anat/sub-01_ses-baseline_T1w.nii.gz"],
120126
}
121127
)
122128
mock_session = SessionTables(anat=mock_anat_df, func=None)
129+
iter_calls = list(groups)
123130
with (
124131
patch("rbc.cli.all.load_table", return_value=filtered_df),
125132
patch("rbc.cli.all.load_session", return_value=mock_session),
126-
patch("rbc.cli.all.iter_session_files", side_effect=groups),
133+
patch("rbc.cli.all.iter_session_files", side_effect=iter_calls),
127134
patch(
128135
"rbc.cli.all.anatomical_preprocess", return_value=_mock_anat_outputs()
129136
) as mock_anat,
@@ -147,7 +154,7 @@ def _make_groups(
147154
participant: list[str],
148155
session: list[str],
149156
task: str | None = None,
150-
) -> tuple[pl.DataFrame, list[list[str]]]:
157+
) -> tuple[pl.DataFrame, list[list[tuple[pl.DataFrame, pl.DataFrame]]]]:
151158
"""Filter sample dataframe and build iter_session_files groups."""
152159
filtered_df = sample_dataframe.filter(
153160
pl.col("suffix") == "bold",
@@ -165,7 +172,9 @@ def _make_groups(
165172
)
166173
key = (row["sub"], row["ses"])
167174
sub_ses_groups.setdefault(key, [])
168-
sub_ses_groups[key].append((func_group, pl.DataFrame()))
175+
sub_ses_groups[key].append(
176+
(func_group, pl.DataFrame({"space": [], "desc": []}))
177+
)
169178

170179
full_df = _make_filtered_df(sample_dataframe, participant, session, task)
171180
return full_df, list(sub_ses_groups.values())

0 commit comments

Comments
 (0)