Skip to content

Commit 342dae2

Browse files
committed
👔 Clarify longitudinal xfms vs longitudinal warps
1 parent 322d660 commit 342dae2

File tree

5 files changed

+148
-62
lines changed

5 files changed

+148
-62
lines changed

CPAC/longitudinal/robust_template.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
traits,
2929
)
3030
from nipype.interfaces.freesurfer import longitudinal
31+
from nipype.interfaces.freesurfer.preprocess import MRIConvert
3132
from nipype.interfaces.freesurfer.utils import LTAConvert
3233

3334
from CPAC.pipeline import nipype_pipeline_engine as pe
@@ -51,7 +52,7 @@ class RobustTemplateInputSpec(longitudinal.RobustTemplateInputSpec): # noqa: D1
5152

5253
class RobustTemplateOutputSpec(longitudinal.RobustTemplateOutputSpec): # noqa: D101
5354
mapmov = OutputMultiPath(
54-
File(exists=True),
55+
File(),
5556
desc="each input mapped and resampled to longitudinal template",
5657
)
5758

@@ -127,7 +128,7 @@ def mri_robust_template(
127128
average_metric=cfg["longitudinal_template_generation", "average_method"],
128129
auto_detect_sensitivity=True,
129130
mapmov=True,
130-
out_file=f"{name}.nii.gz",
131+
out_file=f"{name}.mgz",
131132
transform_outputs=True,
132133
),
133134
name="mri_robust_template",
@@ -138,12 +139,23 @@ def mri_robust_template(
138139
if isinstance(max_iter, int):
139140
node.set_input("maxit", max_iter)
140141

142+
nifti_template = pe.Node(MRIConvert(out_type="niigz"), name="NIfTI-template")
143+
wf.connect(node, "out_file", nifti_template, "in_file")
144+
145+
nifti_outputs = pe.MapNode(
146+
MRIConvert(), name="NIfTI-mapmov", iterfield=["in_file", "out_file"]
147+
)
148+
wf.connect(node, "mapmov", nifti_outputs, "in_file")
149+
nifti_outputs.set_input(
150+
"out_file", [f"space-longitudinal{i + 1}.nii.gz" for i in range(num_sessions)]
151+
)
152+
141153
convert = pe.MapNode(
142154
LTAConvert(), name="convert-to-FSL", iterfield=["in_lta", "out_fsl"]
143155
)
144156
wf.connect(node, "transform_outputs", convert, "in_lta")
145157
convert.set_input(
146-
"out_fsl", [f"space-longitudinal{i}.mat" for i in range(num_sessions)]
158+
"out_fsl", [f"space-longitudinal{i + 1}.mat" for i in range(num_sessions)]
147159
)
148160

149161
return wf

CPAC/longitudinal/wf/anat.py

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from typing import cast, Optional
2121

22+
from networkx.classes.digraph import DiGraph
2223
from nipype import config as nipype_config
2324
from nipype.interfaces import fsl
2425
from nipype.interfaces.utility import Merge
@@ -28,7 +29,7 @@
2829
from CPAC.longitudinal.wf.utils import (
2930
check_creds_path,
3031
cross_graph_connections,
31-
cross_pool_resources,
32+
get_output_from_graph,
3233
select_session_node,
3334
)
3435
from CPAC.pipeline import nipype_pipeline_engine as pe
@@ -191,17 +192,19 @@ def warp_longitudinal_T1w_to_template(
191192
),
192193
"T1w-brain-template",
193194
],
194-
outputs=[
195-
"label-CSF_mask",
196-
"label-GM_mask",
197-
"label-WM_mask",
198-
"label-CSF_desc-preproc_mask",
199-
"label-GM_desc-preproc_mask",
200-
"label-WM_desc-preproc_mask",
201-
"label-CSF_probseg",
202-
"label-GM_probseg",
203-
"label-WM_probseg",
204-
],
195+
outputs={
196+
"from-longitudinal_to-T1w_mode-image_desc-linear_xfm": {},
197+
"from-longitudinal_to-T1w_mode-image_desc-linear_warp": {},
198+
"label-CSF_mask": {},
199+
"label-GM_mask": {},
200+
"label-WM_mask": {},
201+
"label-CSF_desc-preproc_mask": {},
202+
"label-GM_desc-preproc_mask": {},
203+
"label-WM_desc-preproc_mask": {},
204+
"label-CSF_probseg": {},
205+
"label-GM_probseg": {},
206+
"label-WM_probseg": {},
207+
},
205208
)
206209
def warp_longitudinal_seg_to_T1w(
207210
wf: pe.Workflow,
@@ -211,6 +214,7 @@ def warp_longitudinal_seg_to_T1w(
211214
opt: Optional[str] = None,
212215
) -> NODEBLOCK_RETURN:
213216
"""Transform anatomical images from longitudinal space template space."""
217+
outputs = {}
214218
if strat_pool.check_rpool("from-longitudinal_to-T1w_mode-image_desc-linear_xfm"):
215219
xfm_prov = strat_pool.get_cpac_provenance(
216220
"from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
@@ -233,13 +237,21 @@ def warp_longitudinal_seg_to_T1w(
233237
"in_file",
234238
)
235239
xfm = (invt, "out_file")
240+
outputs["from-longitudinal_to-T1w_mode-image_desc-linear_xfm"] = xfm
241+
if reg_tool != "fsl":
242+
msg = f"`warp_longitudinal_seg_to_T1w` not yet implemented for {reg_tool}."
243+
raise NotImplementedError(msg)
244+
warp = pe.Node(
245+
fsl.ConvertWarp(relwarp=True, out_relwarp=True), name=f"convert_warp_{pipe_num}"
246+
)
247+
wf.connect(*xfm, warp, "postmat")
248+
wf.connect(
249+
*strat_pool.get_data("space-longitudinal_desc-brain_T1w"), warp, "reference"
250+
)
251+
outputs["from-longitudinal_to-T1w_mode-image_desc-linear_warp"] = warp, "out_file"
236252

237253
num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]
238-
239254
num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]
240-
241-
outputs = {}
242-
243255
labels = [
244256
"CSF_mask",
245257
"CSF_desc-preproc_mask",
@@ -251,7 +263,6 @@ def warp_longitudinal_seg_to_T1w(
251263
"WM_desc-preproc_mask",
252264
"WM_probseg",
253265
]
254-
255266
for label in labels:
256267
apply_xfm = apply_transform(
257268
f"warp_longitudinal_seg_to_T1w_{label}_{pipe_num}",
@@ -276,11 +287,10 @@ def warp_longitudinal_seg_to_T1w(
276287
node, out = strat_pool.get_data("T1w-brain-template")
277288
wf.connect(node, out, apply_xfm, "inputspec.reference")
278289

279-
wf.connect(*xfm, apply_xfm, "inputspec.transform")
280-
290+
wf.connect(warp, "out_file", apply_xfm, "inputspec.transform")
281291
outputs[f"label-{label}"] = (apply_xfm, "outputspec.output_image")
282292

283-
return (wf, outputs)
293+
return wf, outputs
284294

285295

286296
def anat_longitudinal_wf(
@@ -345,11 +355,13 @@ def anat_longitudinal_wf(
345355
for key in strats_dct.keys():
346356
strats_dct[key].append(cast(tuple[pe.Node, str], rpool.get_data(key)))
347357
if not dry_run:
348-
workflow.run()
358+
workflow_graph: DiGraph = workflow.run()
349359
for key in strats_dct.keys(): # get the outputs from run-nodes
350360
for index, data in enumerate(list(strats_dct[key])):
351361
if isinstance(data, tuple):
352-
strats_dct[key][index] = workflow.get_output(*data)
362+
strats_dct[key][index] = get_output_from_graph(
363+
workflow_graph, *data
364+
)
353365

354366
wf = initialize_nipype_wf(
355367
config,
@@ -409,7 +421,7 @@ def anat_longitudinal_wf(
409421
wf.connect(merge_skulls, "out", wholehead_template_node, "input_skull_list")
410422

411423
case "mri_robust_template":
412-
brain_output = head_output = "mri_robust_template.out_file"
424+
brain_output = head_output = "NIfTI-template.out_file"
413425
brain_template_node = mri_robust_template(
414426
f"mri_robust_template_brain_{subject_id}", config, len(sub_list)
415427
)
@@ -420,7 +432,7 @@ def anat_longitudinal_wf(
420432
merge_brains, "out", brain_template_node, "mri_robust_template.in_files"
421433
)
422434
wf.connect(
423-
merge_brains,
435+
merge_skulls,
424436
"out",
425437
wholehead_template_node,
426438
"mri_robust_template.in_files",
@@ -471,15 +483,14 @@ def anat_longitudinal_wf(
471483
],
472484
)
473485
wf = connect_pipeline(wf, config, rpool, pipeline_blocks)
474-
if not dry_run:
475-
wf.run()
486+
487+
wf_graph: DiGraph | pe.Workflow = (
488+
cast(DiGraph, wf.run()) if not dry_run else cast(pe.Workflow, wf)
489+
)
476490

477491
# now, just write out a copy of the above to each session
478492
config.pipeline_setup["pipeline_name"] = orig_pipe_name
479493
longitudinal_rpool = rpool
480-
cpr = cross_pool_resources(
481-
f"fsl_longitudinal_{subject_id}"
482-
) # "fsl" for check_prov_for_regtool
483494
for i, session in enumerate(sub_list):
484495
unique_id = session["unique_id"]
485496
input_creds_path = check_creds_path(session.get("creds_path"), subject_id)
@@ -504,49 +515,61 @@ def anat_longitudinal_wf(
504515

505516
match config["longitudinal_template_generation", "using"]:
506517
case "C-PAC legacy":
507-
assert isinstance(brain_template_node, pe.Node)
518+
cross_graph_connections(
519+
wf_graph,
520+
ses_wf,
521+
merge_brains,
522+
brain_template_node,
523+
"out",
524+
"input_brain_list",
525+
)
526+
cross_graph_connections(
527+
wf_graph,
528+
ses_wf,
529+
merge_skulls,
530+
brain_template_node,
531+
"out",
532+
"input_skull_list",
533+
)
508534
for input_name, output_name in [
509535
("output_brains", "output_brain_list"),
510536
("warps", "warp_list"),
511537
]:
512538
cross_graph_connections(
513-
wf,
539+
wf_graph,
514540
ses_wf,
515541
brain_template_node,
516542
select_sess,
517543
output_name,
518544
input_name,
519-
dry_run,
520545
)
521546

522547
case "mri_robust_template":
523548
assert isinstance(brain_template_node, pe.Workflow)
524549
assert isinstance(wholehead_template_node, pe.Workflow)
525550
index = i + 1
526-
head_select_sess = select_session_node(unique_id, "-wholehead")
551+
head_select_sess = select_session_node(unique_id, "wholehead")
527552
select_sess.set_input("session", f"space-longitudinal{index}")
528553
head_select_sess.set_input("session", f"space-longitudinal{index}")
529554
for input_name, output_name in [
530-
("output_brains", "mri_robust_template.mapmov"),
555+
("output_brains", "NIfTI-mapmov_.out_file"),
531556
("warps", "convert-to-FSL_.out_fsl"),
532557
]:
533558
cross_graph_connections(
534-
wf,
559+
wf_graph,
535560
ses_wf,
536561
brain_template_node,
537562
select_sess,
538563
output_name,
539564
input_name,
540-
dry_run,
541565
)
542566
cross_graph_connections(
543-
wf,
567+
wf_graph,
544568
ses_wf,
545569
wholehead_template_node,
546570
head_select_sess,
547571
output_name,
548572
input_name,
549-
dry_run,
550573
)
551574

552575
rpool.set_data(
@@ -589,8 +612,20 @@ def anat_longitudinal_wf(
589612
cross_pool_keys = ["from-longitudinal_to-template_mode-image_xfm"]
590613
for key in cross_pool_keys:
591614
node, out = longitudinal_rpool.get_data(key)
592-
cross_graph_connections(wf, ses_wf, node, cpr, out, key, dry_run)
593-
rpool.set_data(key, cpr, key, {}, "", cpr.name)
615+
try:
616+
json_info: dict = longitudinal_rpool.get_json(
617+
key, next(iter(longitudinal_rpool.rpool[key].keys()))
618+
)
619+
except (AttributeError, KeyError, StopIteration):
620+
json_info = {}
621+
rpool.set_data(
622+
key,
623+
node,
624+
out,
625+
json_info,
626+
"",
627+
f"fsl_longitudinal_{subject_id}", # "fsl" for check_prov_for_regtool
628+
)
594629
if not dry_run:
595630
ses_wf.run()
596631

@@ -605,5 +640,5 @@ def anat_longitudinal_wf(
605640

606641
# this is going to run multiple times!
607642
# once for every strategy!
608-
if not dry_run:
643+
if not dry_run: # check select_sess
609644
ses_wf.run()

0 commit comments

Comments
 (0)