-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsession.py
More file actions
167 lines (125 loc) · 5.41 KB
/
session.py
File metadata and controls
167 lines (125 loc) · 5.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""BIDS session-level discovery and iteration.
Provides utilities for loading subject/session data from a BIDS dataset
and iterating over run/task groups with matched anatomical files.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
import polars as pl
from rbc.bids import Datatype, extract_entities
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from rbc.bids import EntityKwargs
#: Entity columns that together identify a subject/session pair.
SUB_SES_QUERY = ("sub", "ses")

#: BIDS entities used to group anatomical files in dataframes generated by
#: Bids2Table. ``suffix`` is part of the key so T1w, T2w, etc. form
#: separate groups.
ANAT_GROUP_ENTITIES = ("run", "acq", "suffix", "part", "echo", "ce", "rec", "inv")

#: BIDS entities used to group functional files in dataframes generated by
#: Bids2Table. ``suffix`` is intentionally left out so 'sbref' files land in
#: the same group as their 'bold' counterpart.
FUNC_GROUP_ENTITIES = ("task", "run", "acq", "dir", "echo", "part", "rec")
class SessionTables(NamedTuple):
    """Anatomical and functional rows for one subject/session combination.

    Attributes:
        anat: Rows of the session whose datatype is anatomical.
        func: Rows of the session whose datatype is functional; set to
            ``None`` by :func:`load_session` when there are no such rows.
    """

    anat: pl.DataFrame
    func: pl.DataFrame | None
def load_session(df: pl.DataFrame, subject: str, session: str | None) -> SessionTables:
    """Split one subject/session of a bids2table dataframe by datatype.

    Args:
        df: Complete dataframe produced by bids2table.
        subject: Subject label without the ``sub-`` prefix (e.g. ``'01'``).
        session: Session label without the ``ses-`` prefix (e.g. ``'02'``),
            or ``None`` to skip filtering on the ``ses`` column.

    Returns:
        A :class:`SessionTables` whose ``anat`` holds the anatomical rows and
        whose ``func`` holds the functional rows, or ``None`` when the
        selection contains no functional rows.
    """
    selector = pl.col("sub") == subject
    if session is not None:
        selector = selector & (pl.col("ses") == session)

    # Narrow to the requested subject/session once, then split by datatype.
    subset = df.filter(selector)
    anat_rows = subset.filter(pl.col("datatype") == Datatype.ANAT)
    func_rows = subset.filter(pl.col("datatype") == Datatype.FUNC)

    return SessionTables(
        anat=anat_rows,
        func=None if func_rows.is_empty() else func_rows,
    )
def _resolve_anat(
    primary_group: pl.DataFrame,
    anat: pl.DataFrame,
    fallback_anat: pl.DataFrame,
    *,
    runs_correspond: bool,
) -> pl.DataFrame:
    """Pick the anatomical rows to pair with one primary (e.g. func) group.

    When ``runs_correspond`` is true, the anat rows whose ``run`` label
    appears in ``primary_group`` are returned. If that selection is empty, or
    runs do not correspond at all, ``fallback_anat`` is returned instead.
    """
    if not runs_correspond:
        return fallback_anat

    wanted_runs = primary_group["run"].drop_nulls().unique().implode()
    candidates = anat.filter(pl.col("run").is_in(wanted_runs))
    if candidates.is_empty():
        return fallback_anat
    return candidates
def _run_sort_key(label: object) -> tuple[int, int, str]:
    """Sort key ordering run labels numerically, falling back to text.

    Run labels may or may not be zero-padded, so a plain string sort would
    place ``'10'`` before ``'2'``. Labels that parse as integers sort first,
    by value (ties broken by text). Any other label sorts after them as text.
    """
    try:
        return (0, int(label), str(label))
    except (TypeError, ValueError):
        return (1, 0, str(label))


def iter_session_files(
    session: SessionTables,
    groupby: Sequence[str] = ("run",),
) -> Iterator[tuple[pl.DataFrame, pl.DataFrame]]:
    """Iterate over grouped session files, paired with matching anat files.

    When functional data is present it drives iteration. For a pure
    anatomical pipeline (``session.func is None``), iteration is driven by the
    anat groups instead, and each yield is ``(anat_group, anat_group)``.
    Groups are yielded in order of first appearance in the input dataframe.

    Anat matching follows this precedence:

    1. **1-to-1**: anat and func have the same number of distinct run labels.
       Each func group gets the anat rows whose run label matches. If none
       match, the fallback below is used.
    2. **1-to-many**: run counts differ. Use the anat rows of the first anat
       run, ordered numerically when labels are integers.
    3. **No runs**: anat has no run labels. Use all available anat rows
       (e.g. a single T1w).

    Args:
        session: A :class:`SessionTables` for a single subject/session.
        groupby: BIDS entity columns to group the primary dataframe by.
            Defaults to ``("run",)``.

    Yields:
        ``(primary_group, anat_subset)`` tuples. For functional pipelines
        ``primary_group`` is a func group; for anat-only pipelines both values
        are the same anat group.
    """
    # Anat-only pipelines: the anat groups drive iteration themselves, so no
    # run matching is needed. maintain_order=True makes the yield order
    # deterministic (polars does not guarantee group order otherwise).
    if session.func is None:
        for _, group in session.anat.group_by(groupby, maintain_order=True):
            yield group, group
        return

    anat_runs = (
        session.anat["run"].drop_nulls().unique()
        if "run" in session.anat.columns
        else pl.Series([], dtype=pl.Utf8)
    )
    has_anat_runs = len(anat_runs) > 0

    if has_anat_runs:
        # Numeric ordering so that e.g. run '2' is chosen over run '10'.
        first_run = min(anat_runs.to_list(), key=_run_sort_key)
        fallback_anat = session.anat.filter(pl.col("run") == first_run)
    else:
        fallback_anat = session.anat

    func_runs = (
        session.func["run"].drop_nulls().unique()
        if "run" in session.func.columns
        else pl.Series([], dtype=pl.Utf8)
    )
    # Loop-invariant: whether anat runs can be paired with func runs by label.
    runs_correspond = has_anat_runs and len(anat_runs) == len(func_runs)

    for _, func_group in session.func.group_by(groupby, maintain_order=True):
        anat_subset = _resolve_anat(
            func_group,
            session.anat,
            fallback_anat,
            runs_correspond=runs_correspond,
        )
        yield func_group, anat_subset
#: Entity keys copied from each derivative run group into
#: ``DerivativeRun.entities``.
# NOTE(review): FUNC_GROUP_ENTITIES also groups by ``part``, which is not
# listed here, so runs differing only in ``part`` would presumably produce
# identical entities. Confirm whether EntityKwargs supports ``part`` first.
_FUNC_ENTITY_KEYS = ("task", "run", "acq", "rec", "dir", "echo")


class DerivativeRun(NamedTuple):
    """One functional run found in derivative data.

    Attributes:
        entities: BIDS entities describing the run, drawn from the
            ``task``, ``run``, ``acq``, ``rec``, ``dir`` and ``echo`` keys.
    """

    entities: EntityKwargs
def discover_derivative_runs(
    group: pl.DataFrame,
) -> Iterator[DerivativeRun]:
    """Discover functional runs within a sub/ses derivative group.

    Groups ``group`` by :data:`FUNC_GROUP_ENTITIES` and extracts the standard
    functional entities (:data:`_FUNC_ENTITY_KEYS`) from the first row of
    each resulting group. Runs are yielded in order of first appearance in
    ``group``, so repeated calls on the same input give the same order.

    Args:
        group: DataFrame of derivative BOLD runs for a single sub/ses.

    Yields:
        A :class:`DerivativeRun` for each functional run group.
    """
    # maintain_order=True: polars does not otherwise guarantee group order,
    # which would make the sequence of discovered runs vary between calls.
    for _, run_group in group.group_by(FUNC_GROUP_ENTITIES, maintain_order=True):
        # Rows in one group share the grouping entities; the first row is
        # used as the representative.
        first_row = run_group.row(0, named=True)
        yield DerivativeRun(
            entities=extract_entities(first_row, list(_FUNC_ENTITY_KEYS)),
        )