1919
2020from typing import cast , Optional
2121
22+ from networkx .classes .digraph import DiGraph
2223from nipype import config as nipype_config
2324from nipype .interfaces import fsl
2425from nipype .interfaces .utility import Merge
2829from CPAC .longitudinal .wf .utils import (
2930 check_creds_path ,
3031 cross_graph_connections ,
31- cross_pool_resources ,
32+ get_output_from_graph ,
3233 select_session_node ,
3334)
3435from CPAC .pipeline import nipype_pipeline_engine as pe
@@ -191,17 +192,19 @@ def warp_longitudinal_T1w_to_template(
191192 ),
192193 "T1w-brain-template" ,
193194 ],
194- outputs = [
195- "label-CSF_mask" ,
196- "label-GM_mask" ,
197- "label-WM_mask" ,
198- "label-CSF_desc-preproc_mask" ,
199- "label-GM_desc-preproc_mask" ,
200- "label-WM_desc-preproc_mask" ,
201- "label-CSF_probseg" ,
202- "label-GM_probseg" ,
203- "label-WM_probseg" ,
204- ],
195+ outputs = {
196+ "from-longitudinal_to-T1w_mode-image_desc-linear_xfm" : {},
197+ "from-longitudinal_to-T1w_mode-image_desc-linear_warp" : {},
198+ "label-CSF_mask" : {},
199+ "label-GM_mask" : {},
200+ "label-WM_mask" : {},
201+ "label-CSF_desc-preproc_mask" : {},
202+ "label-GM_desc-preproc_mask" : {},
203+ "label-WM_desc-preproc_mask" : {},
204+ "label-CSF_probseg" : {},
205+ "label-GM_probseg" : {},
206+ "label-WM_probseg" : {},
207+ },
205208)
206209def warp_longitudinal_seg_to_T1w (
207210 wf : pe .Workflow ,
@@ -211,6 +214,7 @@ def warp_longitudinal_seg_to_T1w(
211214 opt : Optional [str ] = None ,
212215) -> NODEBLOCK_RETURN :
213216 """Transform anatomical images from longitudinal space template space."""
217+ outputs = {}
214218 if strat_pool .check_rpool ("from-longitudinal_to-T1w_mode-image_desc-linear_xfm" ):
215219 xfm_prov = strat_pool .get_cpac_provenance (
216220 "from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
@@ -233,13 +237,21 @@ def warp_longitudinal_seg_to_T1w(
233237 "in_file" ,
234238 )
235239 xfm = (invt , "out_file" )
240+ outputs ["from-longitudinal_to-T1w_mode-image_desc-linear_xfm" ] = xfm
241+ if reg_tool != "fsl" :
242+ msg = f"`warp_longitudinal_seg_to_T1w` not yet implemented for { reg_tool } ."
243+ raise NotImplementedError (msg )
244+ warp = pe .Node (
245+ fsl .ConvertWarp (relwarp = True , out_relwarp = True ), name = f"convert_warp_{ pipe_num } "
246+ )
247+ wf .connect (* xfm , warp , "postmat" )
248+ wf .connect (
249+ * strat_pool .get_data ("space-longitudinal_desc-brain_T1w" ), warp , "reference"
250+ )
251+ outputs ["from-longitudinal_to-T1w_mode-image_desc-linear_warp" ] = warp , "out_file"
236252
237253 num_cpus = cfg .pipeline_setup ["system_config" ]["max_cores_per_participant" ]
238-
239254 num_ants_cores = cfg .pipeline_setup ["system_config" ]["num_ants_threads" ]
240-
241- outputs = {}
242-
243255 labels = [
244256 "CSF_mask" ,
245257 "CSF_desc-preproc_mask" ,
@@ -251,7 +263,6 @@ def warp_longitudinal_seg_to_T1w(
251263 "WM_desc-preproc_mask" ,
252264 "WM_probseg" ,
253265 ]
254-
255266 for label in labels :
256267 apply_xfm = apply_transform (
257268 f"warp_longitudinal_seg_to_T1w_{ label } _{ pipe_num } " ,
@@ -276,11 +287,10 @@ def warp_longitudinal_seg_to_T1w(
276287 node , out = strat_pool .get_data ("T1w-brain-template" )
277288 wf .connect (node , out , apply_xfm , "inputspec.reference" )
278289
279- wf .connect (* xfm , apply_xfm , "inputspec.transform" )
280-
290+ wf .connect (warp , "out_file" , apply_xfm , "inputspec.transform" )
281291 outputs [f"label-{ label } " ] = (apply_xfm , "outputspec.output_image" )
282292
283- return ( wf , outputs )
293+ return wf , outputs
284294
285295
286296def anat_longitudinal_wf (
@@ -345,11 +355,13 @@ def anat_longitudinal_wf(
345355 for key in strats_dct .keys ():
346356 strats_dct [key ].append (cast (tuple [pe .Node , str ], rpool .get_data (key )))
347357 if not dry_run :
348- workflow .run ()
358+ workflow_graph : DiGraph = workflow .run ()
349359 for key in strats_dct .keys (): # get the outputs from run-nodes
350360 for index , data in enumerate (list (strats_dct [key ])):
351361 if isinstance (data , tuple ):
352- strats_dct [key ][index ] = workflow .get_output (* data )
362+ strats_dct [key ][index ] = get_output_from_graph (
363+ workflow_graph , * data
364+ )
353365
354366 wf = initialize_nipype_wf (
355367 config ,
@@ -409,7 +421,7 @@ def anat_longitudinal_wf(
409421 wf .connect (merge_skulls , "out" , wholehead_template_node , "input_skull_list" )
410422
411423 case "mri_robust_template" :
412- brain_output = head_output = "mri_robust_template .out_file"
424+ brain_output = head_output = "NIfTI-template .out_file"
413425 brain_template_node = mri_robust_template (
414426 f"mri_robust_template_brain_{ subject_id } " , config , len (sub_list )
415427 )
@@ -420,7 +432,7 @@ def anat_longitudinal_wf(
420432 merge_brains , "out" , brain_template_node , "mri_robust_template.in_files"
421433 )
422434 wf .connect (
423- merge_brains ,
435+ merge_skulls ,
424436 "out" ,
425437 wholehead_template_node ,
426438 "mri_robust_template.in_files" ,
@@ -471,15 +483,14 @@ def anat_longitudinal_wf(
471483 ],
472484 )
473485 wf = connect_pipeline (wf , config , rpool , pipeline_blocks )
474- if not dry_run :
475- wf .run ()
486+
487+ wf_graph : DiGraph | pe .Workflow = (
488+ cast (DiGraph , wf .run ()) if not dry_run else cast (pe .Workflow , wf )
489+ )
476490
477491 # now, just write out a copy of the above to each session
478492 config .pipeline_setup ["pipeline_name" ] = orig_pipe_name
479493 longitudinal_rpool = rpool
480- cpr = cross_pool_resources (
481- f"fsl_longitudinal_{ subject_id } "
482- ) # "fsl" for check_prov_for_regtool
483494 for i , session in enumerate (sub_list ):
484495 unique_id = session ["unique_id" ]
485496 input_creds_path = check_creds_path (session .get ("creds_path" ), subject_id )
@@ -504,49 +515,61 @@ def anat_longitudinal_wf(
504515
505516 match config ["longitudinal_template_generation" , "using" ]:
506517 case "C-PAC legacy" :
507- assert isinstance (brain_template_node , pe .Node )
518+ cross_graph_connections (
519+ wf_graph ,
520+ ses_wf ,
521+ merge_brains ,
522+ brain_template_node ,
523+ "out" ,
524+ "input_brain_list" ,
525+ )
526+ cross_graph_connections (
527+ wf_graph ,
528+ ses_wf ,
529+ merge_skulls ,
530+ brain_template_node ,
531+ "out" ,
532+ "input_skull_list" ,
533+ )
508534 for input_name , output_name in [
509535 ("output_brains" , "output_brain_list" ),
510536 ("warps" , "warp_list" ),
511537 ]:
512538 cross_graph_connections (
513- wf ,
539+ wf_graph ,
514540 ses_wf ,
515541 brain_template_node ,
516542 select_sess ,
517543 output_name ,
518544 input_name ,
519- dry_run ,
520545 )
521546
522547 case "mri_robust_template" :
523548 assert isinstance (brain_template_node , pe .Workflow )
524549 assert isinstance (wholehead_template_node , pe .Workflow )
525550 index = i + 1
526- head_select_sess = select_session_node (unique_id , "- wholehead" )
551+ head_select_sess = select_session_node (unique_id , "wholehead" )
527552 select_sess .set_input ("session" , f"space-longitudinal{ index } " )
528553 head_select_sess .set_input ("session" , f"space-longitudinal{ index } " )
529554 for input_name , output_name in [
530- ("output_brains" , "mri_robust_template.mapmov " ),
555+ ("output_brains" , "NIfTI-mapmov_.out_file " ),
531556 ("warps" , "convert-to-FSL_.out_fsl" ),
532557 ]:
533558 cross_graph_connections (
534- wf ,
559+ wf_graph ,
535560 ses_wf ,
536561 brain_template_node ,
537562 select_sess ,
538563 output_name ,
539564 input_name ,
540- dry_run ,
541565 )
542566 cross_graph_connections (
543- wf ,
567+ wf_graph ,
544568 ses_wf ,
545569 wholehead_template_node ,
546570 head_select_sess ,
547571 output_name ,
548572 input_name ,
549- dry_run ,
550573 )
551574
552575 rpool .set_data (
@@ -589,8 +612,20 @@ def anat_longitudinal_wf(
589612 cross_pool_keys = ["from-longitudinal_to-template_mode-image_xfm" ]
590613 for key in cross_pool_keys :
591614 node , out = longitudinal_rpool .get_data (key )
592- cross_graph_connections (wf , ses_wf , node , cpr , out , key , dry_run )
593- rpool .set_data (key , cpr , key , {}, "" , cpr .name )
615+ try :
616+ json_info : dict = longitudinal_rpool .get_json (
617+ key , next (iter (longitudinal_rpool .rpool [key ].keys ()))
618+ )
619+ except (AttributeError , KeyError , StopIteration ):
620+ json_info = {}
621+ rpool .set_data (
622+ key ,
623+ node ,
624+ out ,
625+ json_info ,
626+ "" ,
627+ f"fsl_longitudinal_{ subject_id } " , # "fsl" for check_prov_for_regtool
628+ )
594629 if not dry_run :
595630 ses_wf .run ()
596631
@@ -605,5 +640,5 @@ def anat_longitudinal_wf(
605640
606641 # this is going to run multiple times!
607642 # once for every strategy!
608- if not dry_run :
643+ if not dry_run : # check select_sess
609644 ses_wf .run ()
0 commit comments