Skip to content

Commit 0cf793f

Browse files
committed
👔 Handle connections from unconnected graphs that run separately
1 parent 4ddb58f commit 0cf793f

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

CPAC/longitudinal_pipeline/longitudinal_workflow.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import shutil
2121
import time
22-
from typing import Optional
22+
from typing import cast, Optional
2323

2424
from CPAC.pipeline.nodeblock import nodeblock
2525

@@ -455,8 +455,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur
455455

456456
# Loop over the sessions to create the input for the longitudinal
457457
# algorithm
458-
strats_dct: dict[str, list[tuple[pe.Node, str]]] = {"desc-brain_T1w": [],
459-
"desc-head_T1w": []}
458+
strats_dct: dict[str, list[tuple[pe.Node, str] | str]] = {"desc-brain_T1w": [],
459+
"desc-head_T1w": []}
460460
for i, session in enumerate(sub_list):
461461

462462
unique_id: str = session['unique_id']
@@ -489,13 +489,16 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur
489489
session_wfs[unique_id] = rpool
490490

491491
rpool.gather_pipes(workflow, config)
492-
for key in strats_dct.keys():
493-
_resource: tuple[pe.Node, str] = rpool.get_data(key)
494-
clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}")
495-
workflow.copy_input_connections(_resource[0], clone)
496-
strats_dct[key].append((clone, _resource[1]))
492+
if dry_run: # build the graphs with connections that may be in other graphs
493+
for key in strats_dct.keys():
494+
_resource = cast(tuple[pe.Node, str], rpool.get_data(key))
495+
clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}")
496+
workflow.copy_input_connections(_resource[0], clone)
497+
strats_dct[key].append((clone, _resource[1]))
497498
if not dry_run:
498499
workflow.run()
500+
for key in strats_dct.keys(): # get the outputs from run-nodes
501+
strats_dct[key].append(workflow.get_output_path(key, rpool))
499502

500503
wf = initialize_nipype_wf(config, sub_list[0],
501504
# just grab the first one for the name
@@ -533,8 +536,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur
533536
merge_skulls = pe.Node(Merge(num_sessions), name="merge_skulls")
534537

535538
for i in list(range(0, num_sessions)):
536-
wf.connect(*strats_dct["desc-brain_T1w"][i], merge_brains, f"in{i + 1}")
537-
wf.connect(*strats_dct["desc-head_T1w"][i], merge_skulls, f"in{i + 1}")
539+
_connect_node_or_path(wf, merge_brains, strats_dct, "desc-brain_T1w", i)
540+
_connect_node_or_path(wf, merge_skulls, strats_dct, "desc-head_T1w", i)
538541
wf.connect(merge_brains, "out", template_node, "input_brain_list")
539542
wf.connect(merge_skulls, "out", template_node, "input_skull_list")
540543

@@ -1198,3 +1201,11 @@ def func_longitudinal_template_wf(subject_id, strat_list, config):
11981201
workflow.run()
11991202

12001203
return
1204+
1205+
def _connect_node_or_path(wf: pe.Workflow, node: pe.Node, strats_dct: dict[str, list[tuple[pe.Node, str] | str]], key: str, index: int) -> None:
1206+
"""Set input appropriately for either a Node or a path string."""
1207+
input: str = f"in{index + 1}"
1208+
if isinstance(strats_dct[key][index], str):
1209+
setattr(node.inputs, input, strats_dct[key][index])
1210+
else:
1211+
wf.connect(*strats_dct[key][index], node, input)

CPAC/pipeline/nipype_pipeline_engine/engine.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
import re
5252
from copy import deepcopy
5353
from inspect import Parameter, Signature, signature
54-
from typing import ClassVar, Optional, Union
54+
from typing import ClassVar, Optional, TYPE_CHECKING, Union
5555
from nibabel import load
56+
from nipype.interfaces.base.support import InterfaceResult
5657
from nipype.interfaces.utility import Function
5758
from nipype.pipeline import engine as pe
5859
from nipype.pipeline.engine.utils import (
@@ -71,6 +72,8 @@
7172
from traits.trait_handlers import TraitListObject
7273
from CPAC.utils.monitoring.custom_logging import getLogger
7374
from CPAC.utils.typing import DICT
75+
if TYPE_CHECKING:
76+
from CPAC.pipeline.engine import ResourcePool
7477

7578
# set global default mem_gb
7679
DEFAULT_MEM_GB = 2.0
@@ -664,6 +667,19 @@ def _get_dot(
664667
logger.debug("cross connection: %s", dotlist[-1])
665668
return ("\n" + prefix).join(dotlist)
666669

670+
def get_output_path(self, key: str, rpool: "ResourcePool") -> str:
671+
"""Get an output path from an already-run Node."""
672+
_node, _out = rpool.get_data(key)
673+
assert isinstance(_node, pe.Node)
674+
assert isinstance(_out, str)
675+
try:
676+
_run_node: pe.Node = [_ for _ in self.run(updatehash=True).nodes if _.fullname == _node.fullname][0]
677+
except IndexError as index_error:
678+
msg = f"Could not find {key} in {self}'s run Nodes."
679+
raise LookupError(msg) from index_error
680+
_res: InterfaceResult = _run_node.run()
681+
return getattr(_res.outputs, _out)
682+
667683
def _handle_just_in_time_exception(self, node):
668684
# pylint: disable=protected-access
669685
if hasattr(self, '_local_func_scans'):

0 commit comments

Comments
 (0)