Skip to content

Commit 9ce335e

Browse files
committed
♻️ Don't fork motion reference inputs
1 parent c80cc72 commit 9ce335e

File tree

6 files changed

+121
-79
lines changed

6 files changed

+121
-79
lines changed

CPAC/func_preproc/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2012-2023 C-PAC Developers
1+
# Copyright (C) 2012-2025 C-PAC Developers
22

33
# This file is part of C-PAC.
44

@@ -22,17 +22,20 @@
2222
func_motion_correct_only,
2323
func_motion_estimates,
2424
get_motion_ref,
25+
get_motion_ref_fmriprep,
2526
motion_estimate_filter,
2627
)
2728
from .func_preproc import get_idx, slice_timing_wf
2829

30+
get_motion_refs = [get_motion_ref, get_motion_ref_fmriprep]
31+
2932
__all__ = [
3033
"calc_motion_stats",
3134
"func_motion_correct",
3235
"func_motion_correct_only",
3336
"func_motion_estimates",
3437
"get_idx",
35-
"get_motion_ref",
38+
"get_motion_refs",
3639
"motion_estimate_filter",
3740
"slice_timing_wf",
3841
]

CPAC/func_preproc/func_motion.py

Lines changed: 90 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2012-2024 C-PAC Developers
1+
# Copyright (C) 2012-2025 C-PAC Developers
22

33
# This file is part of C-PAC.
44

@@ -16,7 +16,8 @@
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Functions for calculating motion parameters."""
1818

19-
# pylint: disable=ungrouped-imports,wrong-import-order,wrong-import-position
19+
from typing import Literal
20+
2021
from nipype.interfaces import afni, fsl, utility as util
2122
from nipype.interfaces.afni import preprocess, utils as afni_utils
2223

@@ -31,8 +32,10 @@
3132
motion_power_statistics,
3233
)
3334
from CPAC.pipeline import nipype_pipeline_engine as pe
35+
from CPAC.pipeline.engine import ResourcePool
3436
from CPAC.pipeline.nodeblock import nodeblock
3537
from CPAC.pipeline.schema import valid_options
38+
from CPAC.utils.configuration import Configuration
3639
from CPAC.utils.interfaces.function import Function
3740
from CPAC.utils.utils import check_prov_for_motion_tool
3841

@@ -364,79 +367,97 @@ def get_mcflirt_rms_abs(rms_files):
364367
"motion_correction",
365368
"motion_correction_reference",
366369
],
367-
option_val=["mean", "median", "selected_volume", "fmriprep_reference"],
368-
inputs=["desc-preproc_bold", "desc-reorient_bold"],
370+
option_val=["mean", "median", "selected_volume"],
371+
inputs=["desc-preproc_bold"],
369372
outputs=["motion-basefile"],
370373
)
371-
def get_motion_ref(wf, cfg, strat_pool, pipe_num, opt=None):
372-
if opt not in get_motion_ref.option_val:
373-
msg = (
374-
"\n\n[!] Error: The 'motion_correction_reference' "
375-
"parameter of the 'motion_correction' workflow "
376-
"must be one of:\n\t{0}.\n\nTool input: '{1}'"
377-
"\n\n".format(
378-
" or ".join([f"'{val}'" for val in get_motion_ref.option_val]), opt
374+
def get_motion_ref(
375+
wf: pe.Workflow,
376+
cfg: Configuration,
377+
strat_pool: ResourcePool,
378+
pipe_num: int,
379+
opt: Literal["mean", "median", "selected_volume"],
380+
) -> tuple[pe.Workflow, dict[str, tuple[pe.Node, str]]]:
381+
"""Get the reference image for motion correction."""
382+
node, out = strat_pool.get_data("desc-preproc_bold")
383+
in_label = "in_file"
384+
match opt:
385+
case "mean":
386+
func_get_RPI = pe.Node(
387+
interface=afni_utils.TStat(options="-mean"),
388+
name=f"func_get_mean_RPI_{pipe_num}",
389+
mem_gb=0.48,
390+
mem_x=(1435097126797993 / 302231454903657293676544, in_label),
379391
)
380-
)
381-
raise ValueError(msg)
382-
383-
if opt == "mean":
384-
func_get_RPI = pe.Node(
385-
interface=afni_utils.TStat(),
386-
name=f"func_get_mean_RPI_{pipe_num}",
387-
mem_gb=0.48,
388-
mem_x=(1435097126797993 / 302231454903657293676544, "in_file"),
389-
)
390-
391-
func_get_RPI.inputs.options = "-mean"
392-
func_get_RPI.inputs.outputtype = "NIFTI_GZ"
393-
394-
node, out = strat_pool.get_data("desc-preproc_bold")
395-
wf.connect(node, out, func_get_RPI, "in_file")
396-
397-
elif opt == "median":
398-
func_get_RPI = pe.Node(
399-
interface=afni_utils.TStat(), name=f"func_get_median_RPI_{pipe_num}"
400-
)
401-
402-
func_get_RPI.inputs.options = "-median"
403-
func_get_RPI.inputs.outputtype = "NIFTI_GZ"
404-
405-
node, out = strat_pool.get_data("desc-preproc_bold")
406-
wf.connect(node, out, func_get_RPI, "in_file")
407-
408-
elif opt == "selected_volume":
409-
func_get_RPI = pe.Node(
410-
interface=afni.Calc(), name=f"func_get_selected_RPI_{pipe_num}"
411-
)
412-
413-
func_get_RPI.inputs.set(
414-
expr="a",
415-
single_idx=cfg.functional_preproc["motion_estimates_and_correction"][
416-
"motion_correction"
417-
]["motion_correction_reference_volume"],
418-
outputtype="NIFTI_GZ",
419-
)
392+
case "median":
393+
func_get_RPI = pe.Node(
394+
interface=afni_utils.TStat(options="-median"),
395+
name=f"func_get_median_RPI_{pipe_num}",
396+
)
397+
case "selected_volume":
398+
func_get_RPI = pe.Node(
399+
interface=afni.Calc(
400+
expr="a",
401+
single_idx=cfg.functional_preproc[
402+
"motion_estimates_and_correction"
403+
]["motion_correction"]["motion_correction_reference_volume"],
404+
),
405+
name=f"func_get_selected_RPI_{pipe_num}",
406+
)
407+
in_label = "in_file_a"
408+
case _:
409+
msg = (
410+
"\n\n[!] Error: The 'motion_correction_reference' "
411+
"parameter of the 'motion_correction' workflow "
412+
"must be one of:\n\t{0}.\n\nTool input: '{1}'"
413+
"\n\n".format(
414+
" or ".join([f"'{val}'" for val in get_motion_ref.option_val]), opt
415+
)
416+
)
417+
raise ValueError(msg)
418+
func_get_RPI.inputs.outputtype = "NIFTI_GZ"
419+
wf.connect(node, out, func_get_RPI, in_label)
420+
outputs = {"motion-basefile": (func_get_RPI, "out_file")}
421+
return wf, outputs
420422

421-
node, out = strat_pool.get_data("desc-preproc_bold")
422-
wf.connect(node, out, func_get_RPI, "in_file_a")
423423

424-
elif opt == "fmriprep_reference":
425-
func_get_RPI = pe.Node(
426-
Function(
427-
input_names=["in_file"],
428-
output_names=["out_file"],
429-
function=estimate_reference_image,
430-
),
431-
name=f"func_get_fmriprep_ref_{pipe_num}",
432-
)
424+
@nodeblock(
425+
name="get_motion_ref_fmriprep",
426+
switch=["functional_preproc", "motion_estimates_and_correction", "run"],
427+
option_key=[
428+
"functional_preproc",
429+
"motion_estimates_and_correction",
430+
"motion_correction",
431+
"motion_correction_reference",
432+
],
433+
option_val=["fmriprep_reference"],
434+
inputs=["desc-reorient_bold"],
435+
outputs=["motion-basefile"],
436+
)
437+
def get_motion_ref_fmriprep(
438+
wf: pe.Workflow,
439+
cfg: Configuration,
440+
strat_pool: ResourcePool,
441+
pipe_num: int,
442+
opt: Literal["fmriprep_reference"],
443+
) -> tuple[pe.Workflow, dict[str, tuple[pe.Node, str]]]:
444+
"""Get the fMRIPrep-style reference image for motion correction."""
445+
assert opt == "fmriprep_reference"
446+
func_get_RPI = pe.Node(
447+
Function(
448+
input_names=["in_file"],
449+
output_names=["out_file"],
450+
function=estimate_reference_image,
451+
),
452+
name=f"func_get_fmriprep_ref_{pipe_num}",
453+
)
433454

434-
node, out = strat_pool.get_data("desc-reorient_bold")
435-
wf.connect(node, out, func_get_RPI, "in_file")
455+
node, out = strat_pool.get_data("desc-reorient_bold")
456+
wf.connect(node, out, func_get_RPI, "in_file")
436457

437458
outputs = {"motion-basefile": (func_get_RPI, "out_file")}
438459

439-
return (wf, outputs)
460+
return wf, outputs
440461

441462

442463
def motion_correct_3dvolreg(wf, cfg, strat_pool, pipe_num):
@@ -728,7 +749,9 @@ def motion_correct_mcflirt(wf, cfg, strat_pool, pipe_num):
728749
}
729750

730751

731-
def motion_correct_connections(wf, cfg, strat_pool, pipe_num, opt):
752+
def motion_correct_connections(
753+
wf, cfg, strat_pool, pipe_num, opt
754+
): # -> tuple[Any, dict[str, tuple[Node, str]]]:
732755
"""Check opt for valid option, then connect that option."""
733756
motion_correct_options = valid_options["motion_correction"]
734757
if opt not in motion_correct_options:

CPAC/func_preproc/tests/test_preproc_connections.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
func_motion_correct,
3232
func_motion_correct_only,
3333
func_motion_estimates,
34-
get_motion_ref,
34+
get_motion_refs,
3535
motion_estimate_filter,
3636
)
3737
from CPAC.func_preproc.func_preproc import func_normalize
@@ -219,13 +219,13 @@ def test_motion_filter_connections(
219219
"calculate_motion_first",
220220
]:
221221
func_motion_blocks = [
222-
get_motion_ref,
222+
*get_motion_refs,
223223
func_motion_estimates,
224224
motion_estimate_filter,
225225
]
226226
else:
227227
func_motion_blocks = [
228-
get_motion_ref,
228+
*get_motion_refs,
229229
func_motion_correct,
230230
motion_estimate_filter,
231231
]

CPAC/pipeline/cpac_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2012-2024 C-PAC Developers
1+
# Copyright (C) 2012-2025 C-PAC Developers
22

33
# This file is part of C-PAC.
44

@@ -87,7 +87,7 @@
8787
func_motion_correct,
8888
func_motion_correct_only,
8989
func_motion_estimates,
90-
get_motion_ref,
90+
get_motion_refs,
9191
motion_estimate_filter,
9292
)
9393
from CPAC.func_preproc.func_preproc import (
@@ -1318,7 +1318,7 @@ def build_workflow(subject_id, sub_dict, cfg, pipeline_name=None):
13181318
"motion_estimates"
13191319
]["calculate_motion_first"]:
13201320
func_motion_blocks = [
1321-
get_motion_ref,
1321+
*get_motion_refs,
13221322
func_motion_estimates,
13231323
motion_estimate_filter,
13241324
]
@@ -1332,7 +1332,7 @@ def build_workflow(subject_id, sub_dict, cfg, pipeline_name=None):
13321332
)
13331333
else:
13341334
func_motion_blocks = [
1335-
get_motion_ref,
1335+
*get_motion_refs,
13361336
func_motion_correct,
13371337
motion_estimate_filter,
13381338
]

CPAC/utils/monitoring/monitoring.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
# You should have received a copy of the GNU Lesser General Public
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
17-
# pylint: disable=too-many-lines,ungrouped-imports,wrong-import-order
1817
"""Monitoring utilities for C-PAC."""
1918

2019
from datetime import datetime, timedelta

CPAC/utils/tests/test_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Copyright (C) 2018-2025 C-PAC Developers
2+
3+
# This file is part of C-PAC.
4+
5+
# C-PAC is free software: you can redistribute it and/or modify it under
6+
# the terms of the GNU Lesser General Public License as published by the
7+
# Free Software Foundation, either version 3 of the License, or (at your
8+
# option) any later version.
9+
10+
# C-PAC is distributed in the hope that it will be useful, but WITHOUT
11+
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12+
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
13+
# License for more details.
14+
15+
# You should have received a copy of the GNU Lesser General Public
16+
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
117
"""Tests of CPAC utility functions."""
218

319
from datetime import datetime, timedelta
@@ -7,7 +23,7 @@
723
from _pytest.logging import LogCaptureFixture
824
import pytest
925

10-
from CPAC.func_preproc import get_motion_ref
26+
from CPAC.func_preproc import get_motion_refs
1127
from CPAC.pipeline.nodeblock import NodeBlockFunction
1228
from CPAC.utils.configuration import Configuration
1329
from CPAC.utils.monitoring.custom_logging import log_subprocess
@@ -153,7 +169,8 @@ def test_executable(executable):
153169
_installation_check(executable, "-help")
154170

155171

156-
def test_NodeBlock_option_SSOT(): # pylint: disable=invalid-name
172+
@pytest.mark.parametrize("get_motion_ref", get_motion_refs)
173+
def test_NodeBlock_option_SSOT(get_motion_ref: NodeBlockFunction): # pylint: disable=invalid-name
157174
"""Test using NodeBlock dictionaries for SSOT for options."""
158175
assert isinstance(get_motion_ref, NodeBlockFunction)
159176
with pytest.raises(ValueError) as value_error:

0 commit comments

Comments
 (0)