Skip to content

Commit 1a90cc0

Browse files
committed
👔 Make cross-graph and cross-pool connections
1 parent a5d2c60 commit 1a90cc0

File tree

5 files changed

+117
-54
lines changed

5 files changed

+117
-54
lines changed

CPAC/longitudinal/robust_template.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,21 @@ def _list_outputs(self):
113113

114114
def mri_robust_template(name: str, cfg: Configuration) -> pe.Node:
115115
"""Return a Node to run `mri_robust_template` with common options."""
116-
node = pe.Node(RobustTemplate(), name=name)
117-
node.set_input("mapmov", True)
118-
node.set_input("transform_outputs", True)
119-
node.set_input(
120-
"average_metric", cfg["longitudinal_template_generation", "average_method"]
116+
node = pe.Node(
117+
RobustTemplate(
118+
affine=cfg["longitudinal_template_generation", "dof"] == 12, # noqa: PLR2004
119+
average_metric=cfg["longitudinal_template_generation", "average_method"],
120+
auto_detect_sensitivity=True,
121+
mapmov=True,
122+
out_file=f"{name}.nii.gz",
123+
transform_outputs=True,
124+
),
125+
name=name,
121126
)
122-
node.set_input("affine", cfg["longitudinal_template_generation", "dof"] == 12) # noqa: PLR2004
123127
max_iter = cast(
124128
int | Literal["default"], cfg["longitudinal_template_generation", "max_iter"]
125129
)
126130
if isinstance(max_iter, int):
127131
node.set_input("maxit", max_iter)
128-
node.set_input("auto_detect_sensitivity", True)
129132

130133
return node

CPAC/longitudinal/wf/anat.py

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525

2626
from CPAC.longitudinal.preproc import subject_specific_template
2727
from CPAC.longitudinal.robust_template import mri_robust_template
28-
from CPAC.longitudinal.wf.utils import check_creds_path, select_session_node
28+
from CPAC.longitudinal.wf.utils import (
29+
check_creds_path,
30+
cross_graph_connections,
31+
cross_pool_resources,
32+
select_session_node,
33+
)
2934
from CPAC.pipeline import nipype_pipeline_engine as pe
3035
from CPAC.pipeline.cpac_pipeline import (
3136
build_anat_preproc_stack,
@@ -345,7 +350,7 @@ def anat_longitudinal_wf(
345350
for key in strats_dct.keys(): # get the outputs from run-nodes
346351
for index, data in enumerate(list(strats_dct[key])):
347352
if isinstance(data, tuple):
348-
strats_dct[key][index] = workflow.get_output_path(*data)
353+
strats_dct[key][index] = workflow.get_output(*data)
349354

350355
wf = initialize_nipype_wf(
351356
config,
@@ -460,24 +465,24 @@ def anat_longitudinal_wf(
460465
],
461466
)
462467
wf = connect_pipeline(wf, config, rpool, pipeline_blocks)
463-
464468
if not dry_run:
465469
wf.run()
466470

467471
# now, just write out a copy of the above to each session
468472
config.pipeline_setup["pipeline_name"] = orig_pipe_name
473+
longitudinal_rpool = rpool
474+
cpr = cross_pool_resources(f"longitudinal_{subject_id}")
469475
for session in sub_list:
470476
unique_id = session["unique_id"]
471477
input_creds_path = check_creds_path(session.get("creds_path"), subject_id)
472478

473-
wf = initialize_nipype_wf(config, subject_id, unique_id)
474-
475-
wf, rpool = initiate_rpool(wf, config, session)
479+
ses_wf = initialize_nipype_wf(config, subject_id, unique_id)
476480

481+
ses_wf, rpool = initiate_rpool(ses_wf, config, session)
477482
config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}"
478483
if "derivatives_dir" in session:
479-
rpool = ingress_output_dir(
480-
wf,
484+
ses_wf, rpool = ingress_output_dir(
485+
ses_wf,
481486
config,
482487
rpool,
483488
long_id,
@@ -491,32 +496,45 @@ def anat_longitudinal_wf(
491496

492497
match config["longitudinal_template_generation", "using"]:
493498
case "C-PAC legacy":
494-
wf.connect(
495-
brain_template_node,
496-
"output_brain_list",
497-
select_sess,
498-
"output_brains",
499-
)
500-
wf.connect(brain_template_node, "warp_list", select_sess, "warps")
499+
for input_name, output_name in [
500+
("output_brains", "output_brain_list"),
501+
("warps", "warp_list"),
502+
]:
503+
cross_graph_connections(
504+
wf,
505+
ses_wf,
506+
brain_template_node,
507+
select_sess,
508+
output_name,
509+
input_name,
510+
dry_run,
511+
)
501512

502513
case "mri_robust_template":
503-
wf.connect(brain_template_node, "mapmov", select_sess, "output_brains")
504-
wf.connect(
505-
brain_template_node, "transform_outputs", select_sess, "warps"
506-
)
507514
head_select_sess = select_session_node(unique_id, "wholehead")
508-
wf.connect(
509-
wholehead_template_node,
510-
"mapmov",
511-
head_select_sess,
512-
"output_brains",
513-
)
514-
wf.connect(
515-
wholehead_template_node,
516-
"transform_outputs",
517-
head_select_sess,
518-
"warps",
519-
)
515+
for input_name, output_name in [
516+
("output_brains", "mapmov"),
517+
("warps", "transform_outputs"),
518+
]:
519+
cross_graph_connections(
520+
wf,
521+
ses_wf,
522+
brain_template_node,
523+
select_sess,
524+
output_name,
525+
input_name,
526+
dry_run,
527+
)
528+
cross_graph_connections(
529+
wf,
530+
ses_wf,
531+
wholehead_template_node,
532+
head_select_sess,
533+
output_name,
534+
input_name,
535+
dry_run,
536+
)
537+
520538
rpool.set_data(
521539
"space-longitudinal_desc-head_T1w",
522540
head_select_sess,
@@ -553,29 +571,25 @@ def anat_longitudinal_wf(
553571

554572
config.pipeline_setup["pipeline_name"] = orig_pipe_name
555573
excl = ["space-template_desc-brain_T1w", "space-T1w_desc-brain_mask"]
556-
rpool.gather_pipes(wf, config, add_excl=excl)
574+
rpool.gather_pipes(ses_wf, config, add_excl=excl)
575+
cross_pool_keys = ["from-longitudinal_to-template_mode-image_xfm"]
576+
for key in cross_pool_keys:
577+
node, out = longitudinal_rpool.get_data(key)
578+
cross_graph_connections(wf, ses_wf, node, cpr, out, key, dry_run)
579+
rpool.set_data(key, cpr, key, {}, "", cpr.name)
557580
if not dry_run:
558-
wf.run()
559-
560-
# begin single-session stuff again
561-
for session in sub_list:
562-
unique_id = session["unique_id"]
563-
input_creds_path = check_creds_path(session.get("creds_path"), subject_id)
564-
565-
wf = initialize_nipype_wf(config, subject_id, unique_id)
566-
567-
wf, rpool = initiate_rpool(wf, config, session)
581+
ses_wf.run()
568582

569583
pipeline_blocks = [
570584
warp_longitudinal_T1w_to_template,
571585
warp_longitudinal_seg_to_T1w,
572586
]
573587

574-
wf = connect_pipeline(wf, config, rpool, pipeline_blocks)
588+
ses_wf = connect_pipeline(ses_wf, config, rpool, pipeline_blocks)
575589

576-
rpool.gather_pipes(wf, config)
590+
rpool.gather_pipes(ses_wf, config)
577591

578592
# this is going to run multiple times!
579593
# once for every strategy!
580594
if not dry_run:
581-
wf.run()
595+
ses_wf.run()

CPAC/longitudinal/wf/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from pathlib import Path
2121
from typing import Optional
2222

23+
from nipype.interfaces.utility import IdentityInterface
24+
2325
from CPAC.pipeline import nipype_pipeline_engine as pe
2426
from CPAC.utils.interfaces.function import Function
2527

@@ -38,6 +40,43 @@ def check_creds_path(creds_path: Optional[str], subject_id: str) -> Optional[str
3840
return None
3941

4042

43+
def cross_graph_connections(
44+
wf1: pe.Workflow,
45+
wf2: pe.Workflow,
46+
node1: pe.Node,
47+
node2: pe.Node,
48+
output_name: str,
49+
input_name: str,
50+
dry_run: bool,
51+
) -> None:
52+
"""Make cross-graph connections appropriate to dry-run status.
53+
54+
Parameters
55+
----------
56+
wf1
57+
The graph that runs first
58+
59+
wf2
60+
The graph that runs second
61+
62+
node1
63+
The node from ``wf1``
64+
65+
node2
66+
The node from ``wf2``
67+
68+
output_name
69+
The output name from ``node1``
70+
71+
input_name
72+
The input name from ``node2``
73+
"""
74+
if dry_run:
75+
wf2.connect(node1, output_name, node2, input_name)
76+
else:
77+
node2.set_input(input_name, wf1.get_output(node1, output_name))
78+
79+
4180
def select_session(
4281
session: str, output_brains: list[str], warps: list[str]
4382
) -> tuple[Optional[str], Optional[str]]:
@@ -74,3 +113,11 @@ def select_session_node(unique_id: str, suffix: str = "") -> pe.Node:
74113
)
75114
select_sess.inputs.session = unique_id
76115
return select_sess
116+
117+
118+
def cross_pool_resources(name: str) -> pe.Node:
119+
"""Return an IdentityInterface for cross-pool resources."""
120+
return pe.Node(
121+
IdentityInterface(fields=["from-longitudinal_to-template_mode-image_xfm"]),
122+
name=name,
123+
)

CPAC/pipeline/nipype_pipeline_engine/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def _get_dot(
717717
WFLOGGER.debug("cross connection: %s", dotlist[-1])
718718
return ("\n" + prefix).join(dotlist)
719719

720-
def get_output_path(self, node: pe.Node, out: str) -> str:
720+
def get_output(self, node: pe.Node, out: str) -> Any:
721721
"""Get an output path from an already-run Node."""
722722
try:
723723
_run_node: pe.Node = next(

CPAC/seg_preproc/seg_preproc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,6 @@ def tissue_seg_fsl_fast(wf, cfg, strat_pool, pipe_num, opt=None):
519519
# triggered by 'segments' boolean input (-g or --segments)
520520
# 'probability_maps' output is a list of individual probability maps
521521
# triggered by 'probability_maps' boolean input (-p)
522-
523522
segment = pe.Node(
524523
interface=fsl.FAST(),
525524
name=f"segment_{pipe_num}",

0 commit comments

Comments
 (0)