Skip to content

Commit 35ad7b1

Browse files
committed
Extract row entities to find refs + file naming
1 parent 6dc432d commit 35ad7b1

6 files changed

Lines changed: 49 additions & 16 deletions

File tree

src/rbc/bids/longitudinal/template.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import polars as pl
99

10-
from rbc.bids import Suffix, bids_safe_label
10+
from rbc.bids import FUNC_GROUP_ENTITIES, Suffix, bids_safe_label
1111

1212
if TYPE_CHECKING:
1313
from rbc.bids import Bids
@@ -77,23 +77,19 @@ def discover_template_inputs(
7777
files = [
7878
Path(row["root"]) / row["path"] for row in sub_group.iter_rows(named=True)
7979
]
80-
# Filter for first found session; only single reference per task is needed
8180
sub_bold = bold_rows.filter(
8281
(pl.col("sub") == sub) & (pl.col("ses") == sessions[0])
83-
).unique(subset=("task", "root", "path"))
82+
).unique(subset=(*FUNC_GROUP_ENTITIES, "root", "path"))
8483
# Check each task is unique, otherwise raise assertion error with details
85-
if sub_bold.height != sub_bold["task"].n_unique():
84+
if sub_bold.height != sub_bold.unique().height:
8685
conflicts = (
87-
sub_bold.filter(pl.col("task").is_duplicated())
88-
.group_by("task")
86+
sub_bold.filter(pl.struct(FUNC_GROUP_ENTITIES).is_duplicated())
87+
.group_by(FUNC_GROUP_ENTITIES)
8988
.agg(pl.format("{}/{}", "root", "path").alias("paths"))
9089
)
9190
raise AssertionError(
9291
f"Found multiple non-matching grids for subject {sub}:\n"
93-
+ "\n".join(
94-
f"Task '{row['task']}': {row['paths']}"
95-
for row in conflicts.iter_rows(named=True)
96-
)
92+
+ "\n".join(str(dict(row)) for row in conflicts.iter_rows(named=True))
9793
)
9894
bold_files = {
9995
row["task"]: Path(row["root"]) / row["path"]

src/rbc/orchestration/longitudinal/all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from rbc.bids import FUNC_GROUP_ENTITIES, Datatype, Suffix, extract_entities, load_table
1919
from rbc.bids.longitudinal.template import discover_template_inputs
2020
from rbc.bids.metrics import export_metrics
21-
from rbc.bids.session import iter_session_files
21+
from rbc.bids.session import _FUNC_ENTITY_KEYS, iter_session_files
2222
from rbc.context import RunContext
2323
from rbc.orchestration import Filters, RunnerConfig, init_runner
2424
from rbc.orchestration.longitudinal._iter import iter_sessions_with_template
@@ -124,7 +124,7 @@ def run(
124124
)
125125

126126
row = func_df.filter(suffix=Suffix.BOLD).row(0, named=True)
127-
ents = extract_entities(row, ["task", "run"])
127+
ents = extract_entities(row, _FUNC_ENTITY_KEYS)
128128
func_q = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents)
129129
func_long = func_q.derive(space="longitudinal")
130130

src/rbc/orchestration/longitudinal/anatomical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def process_anat(
5151
Workflow outputs for in-memory handoff to downstream stages.
5252
"""
5353
anat_df = anat_df.filter(pl.col("space").is_null())
54-
ents = extract_entities(anat_df.row(0, named=True), ["run"])
54+
ents = extract_entities(anat_df.row(0, named=True), ["run", "acq", "rec", "echo"])
5555

5656
anat_q = pipe_ctx.bids(datatype=Datatype.ANAT)
5757
tpl_q = anat_q.derive(ses="longitudinal")

src/rbc/orchestration/longitudinal/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
export_longitudinal_func,
1111
resolve_longitudinal_func,
1212
)
13-
from rbc.bids.session import iter_session_files
13+
from rbc.bids.session import _FUNC_ENTITY_KEYS, iter_session_files
1414
from rbc.orchestration import Filters, RunnerConfig, init_runner
1515
from rbc.orchestration.longitudinal._iter import iter_sessions_with_template
1616
from rbc.workflows.longitudinal.functional import (
@@ -53,7 +53,7 @@ def process_func(
5353
Workflow outputs for in-memory handoff to downstream stages.
5454
"""
5555
row = func_df.filter(suffix=Suffix.BOLD).row(0, named=True)
56-
ents = extract_entities(row, ["task", "run"])
56+
ents = extract_entities(row, list(_FUNC_ENTITY_KEYS))
5757

5858
func_q = pipe_ctx.bids(datatype=Datatype.FUNC, entities=ents)
5959
tpl_q = pipe_ctx.bids(datatype=Datatype.ANAT).derive(ses="longitudinal")

tests/unit/bids/test_longitudinal_template.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
"sub",
2424
"ses",
2525
"space",
26+
"acq",
27+
"dir",
28+
"echo",
29+
"part",
30+
"rec",
2631
"task",
2732
"run",
2833
"desc",
@@ -50,6 +55,11 @@ def _anat_row(sub: str, ses: str, space: str | None = None) -> tuple:
5055
space,
5156
None,
5257
None,
58+
None,
59+
None,
60+
None,
61+
None,
62+
None,
5363
"brain",
5464
"/data",
5565
path,
@@ -58,7 +68,24 @@ def _anat_row(sub: str, ses: str, space: str | None = None) -> tuple:
5868

5969
def _func_row(sub: str, ses: str, task: str = "rest") -> tuple:
6070
path = f"sub-{sub}/ses-{ses}/func/sub-{sub}_ses-{ses}_task-{task}_bold.nii.gz"
61-
return ("func", "bold", ".nii.gz", sub, ses, None, task, None, None, "/data", path)
71+
return (
72+
"func",
73+
"bold",
74+
".nii.gz",
75+
sub,
76+
ses,
77+
None,
78+
None,
79+
None,
80+
None,
81+
None,
82+
None,
83+
task,
84+
None,
85+
None,
86+
"/data",
87+
path,
88+
)
6289

6390

6491
class TestDiscoverTemplateInputs:

tests/unit/orchestration/test_longitudinal_template.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
"sub",
2424
"ses",
2525
"space",
26+
"acq",
27+
"dir",
28+
"echo",
29+
"part",
30+
"rec",
2631
"task",
2732
"run",
2833
"desc",
@@ -42,6 +47,11 @@ def _brain_row(sub: str, ses: str) -> tuple:
4247
None,
4348
None,
4449
None,
50+
None,
51+
None,
52+
None,
53+
None,
54+
None,
4555
"brain",
4656
"/data",
4757
path,

0 commit comments

Comments
 (0)