Skip to content

Commit 412f642

Browse files
committed
Add a brand new cluster estimation to find groups of triggers, i.e. takes or runs.
1 parent a2ba8e9 commit 412f642

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

phys2bids/cli/run.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,31 @@ def _get_parser():
134134
"or just one TR if it is consistent throughout the session.",
135135
default=None,
136136
)
137+
optional.add_argument(
138+
"-esttakes",
139+
"--estimate_takes",
140+
dest="estimate_takes",
141+
action="store_true",
142+
help="Run automatic algorithm to estimate clusters of triggers, i.e. the "
143+
"'takes' or 'runs' of fMRI. Useful when sequences were stopped and restarted, "
144+
"or when you don't know how many triggers or trs you have in each take. "
145+
"This might work 95% of the time. Default is False.",
146+
default=False,
147+
)
148+
optional.add_argument(
149+
"-ci",
150+
"--confidence-interval",
151+
dest="ci",
152+
# Here always as float, later it will check if the float is an integer instead.
153+
type=float,
154+
help="The Confidence Interval (CI) to use in the estimation of the trigger clusters. "
155+
"The cluster algorithm considers triggers with duration (in samples) within this "
156+
"CI as part of the same group. If CI is an integer, it will consider that amount "
157+
"of triggers. If CI is a float and < 1, it will consider that percentage of the "
158+
"trigger duration. CI cannot be a float > 1. Default is 1. Change to .25 if "
159+
"there is a CMRR DWI sequence or if you are recording sub-triggers.",
160+
default=1,
161+
)
137162
optional.add_argument(
138163
"-thr",
139164
"--threshold",

phys2bids/phys2bids.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from phys2bids import _version, bids, utils, viz
3939
from phys2bids.cli.run import _get_parser
4040
from phys2bids.physio_obj import BlueprintOutput
41-
from phys2bids.slice4phys import slice4phys
41+
from phys2bids.slice4phys import estimate_ntp_and_tr, slice4phys
4242

4343
from . import __version__
4444
from .due import Doi, due
@@ -141,6 +141,8 @@ def phys2bids(
141141
chsel=None,
142142
num_timepoints_expected=None,
143143
tr=None,
144+
estimate_takes=False,
145+
ci=1,
144146
thr=None,
145147
pad=9,
146148
ch_name=[],
@@ -304,6 +306,11 @@ def phys2bids(
304306
LGR.info("Renaming channels with given names")
305307
phys_in.rename_channels(ch_name)
306308

309+
# If requested, run the automatic detection of timepoints and groups
310+
311+
if estimate_takes:
312+
num_timepoints_expected, tr = estimate_ntp_and_tr(phys_in, thr=None, ci=1)
313+
307314
# Checking acquisition type via user's input
308315
if tr is not None and num_timepoints_expected is not None:
309316
# Multi-run acquisition type section

phys2bids/slice4phys.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,97 @@
88
LGR = logging.getLogger(__name__)
99

1010

11+
def estimate_ntp_and_tr(phys_in, thr=None, ci=1):
12+
"""
13+
Find groups of trigger in a spiky signal like the trigger channel signal.
14+
"""
15+
LGR.info('Running automatic clustering of triggers to find timepoints and tr of each "take"')
16+
trigger = phys_in.timeseries[phys_in.trigger_idx]
17+
18+
thr = np.mean(trigger) if thr is None else thr
19+
timepoints = trigger > thr
20+
spikes = np.flatnonzero(np.ediff1d(timepoints.astype(np.int8)) > 0)
21+
interspike_interval = np.diff(spikes)
22+
unique_isi, counts = np.unique(interspike_interval, return_counts=True)
23+
24+
# The following line is for python < 3.12. From 3.12, ci.is_integer() is enough.
25+
if isinstance(ci, int) or isinstance(ci, float) and ci.is_integer():
26+
upper_ci_isi = unique_isi + ci
27+
elif isinstance(ci, float) and ci < 1:
28+
upper_ci_isi = unique_isi * (1 + ci)
29+
elif isinstance(ci, float) and ci > 1:
30+
raise ValueError("Confidence intervals above 1 are not supported.")
31+
else:
32+
raise ValueError("Confidence intervals must be either integers or floats.")
33+
34+
# Loop through the uniques ISI and group them within the specified CI.
35+
# Also compute the average TR of the group.
36+
isi_groups = {}
37+
average_tr = {}
38+
k = 0
39+
current_group = [unique_isi[0]]
40+
41+
for n, i in enumerate(range(1, len(unique_isi))):
42+
if unique_isi[i] <= upper_ci_isi[n]:
43+
current_group.append(unique_isi[i])
44+
else:
45+
isi_groups[k] = current_group
46+
average_tr[k] = np.mean(current_group) / phys_in.freq[0]
47+
k += 1
48+
current_group = [unique_isi[i]]
49+
50+
isi_groups[k] = current_group
51+
average_tr[k] = np.mean(current_group) / phys_in.freq[0]
52+
53+
# Invert the isi_group into value per group
54+
group_by_isi = {isi: group for group, isis in isi_groups.items() for isi in isis}
55+
56+
# Use the found groups to find the number of timepoints and assign the right TR
57+
estimated_ntp = []
58+
estimated_tr = []
59+
60+
i = 0
61+
while i < interspike_interval.size - 1:
62+
current_group = group_by_isi.get(interspike_interval[i])
63+
for n in range(i + 1, interspike_interval.size):
64+
if current_group != group_by_isi.get(interspike_interval[n]):
65+
break
66+
# Repeat one last time outside of for loop
67+
estimated_ntp += [n - i]
68+
estimated_tr += [average_tr[current_group]]
69+
i = n
70+
71+
if len(estimated_ntp) < 1:
72+
raise Exception("This should not happen. Something went very wrong.")
73+
# The algorithm found n groups, the last of which has two timepoints less due to
74+
# diff computations. Each real group of n>1 triggers counts one trigger less but is
75+
# followed by a "fake" group of 1 trigger that is actually the interval to the next
76+
# group. That does not hold if there is a real group of 1 trigger.
77+
# Loop through the estiamtions to fix all that.
78+
ntp = []
79+
tr = []
80+
i = 0
81+
82+
while i < len(estimated_ntp):
83+
if estimated_ntp[i] == 1:
84+
ntp.append(estimated_ntp[i])
85+
tr.append(estimated_tr[i])
86+
i += 1
87+
elif i + 1 < len(estimated_ntp):
88+
ntp.append(estimated_ntp[i] + estimated_ntp[i + 1])
89+
tr.append(estimated_tr[i])
90+
i += 2
91+
else:
92+
ntp.append(estimated_ntp[i] + 2)
93+
tr.append(estimated_tr[i])
94+
i += 1
95+
96+
LGR.info(
97+
f"The automatic clustering found {len(ntp)} groups of triggers long: {ntp} with respective TR: {tr}"
98+
)
99+
return ntp, tr
100+
101+
11102
def find_takes(phys_in, ntp_list, tr_list, thr=None, padding=9):
12103
"""
13104
Find takes slicing index.

0 commit comments

Comments
 (0)