2222from pathlib import Path
2323
2424
25- def generate_simple_before_and_after (subjects : dict , output_dir ):
25+ def preprocess_single_subject (s , output_dir ):
26+ """Preprocess a single subject's images (for parallel processing)."""
27+ temp_dir = os .path .join (output_dir , "temp_3d_images" )
28+
29+ # Preprocess original image
30+ orig_result = load_and_preprocess_image (s ["orig_path" ])
31+ if isinstance (orig_result , nib .Nifti1Image ):
32+ # Need to save the averaged image
33+ orig_3d_path = os .path .join (temp_dir , f"orig_{ s ['id' ]} .nii.gz" )
34+ nib .save (orig_result , orig_3d_path )
35+ orig_img = orig_result
36+ else :
37+ # Already 3D, use original path and load image
38+ orig_3d_path = orig_result
39+ orig_img = nib .load (orig_result )
40+
41+ # Preprocess defaced image
42+ defaced_result = load_and_preprocess_image (s ["defaced_path" ])
43+ if isinstance (defaced_result , nib .Nifti1Image ):
44+ # Need to save the averaged image
45+ defaced_3d_path = os .path .join (temp_dir , f"defaced_{ s ['id' ]} .nii.gz" )
46+ nib .save (defaced_result , defaced_3d_path )
47+ defaced_img = defaced_result
48+ else :
49+ # Already 3D, use original path and load image
50+ defaced_3d_path = defaced_result
51+ defaced_img = nib .load (defaced_result )
52+
53+ # Create new subject dict with preprocessed paths (update paths to 3D versions)
54+ preprocessed_subject = {
55+ "id" : s ["id" ],
56+ "orig_path" : orig_3d_path , # Update to 3D path
57+ "defaced_path" : defaced_3d_path , # Update to 3D path
58+ "orig_img" : orig_img , # Keep loaded image for direct use
59+ "defaced_img" : defaced_img , # Keep loaded image for direct use
60+ }
61+
62+ print (f" Preprocessed { s ['id' ]} " )
63+ return preprocessed_subject
64+
65+
66+ def preprocess_images (subjects : dict , output_dir , n_jobs = None ):
67+ """Preprocess all images once: load and convert 4D to 3D if needed."""
68+ print ("Preprocessing images (4D→3D conversion)..." )
69+
70+ # Create temp directory
71+ temp_dir = os .path .join (output_dir , "temp_3d_images" )
72+ os .makedirs (temp_dir , exist_ok = True )
73+
74+ # Set number of jobs for parallel processing
75+ if n_jobs is None :
76+ n_jobs = mp .cpu_count ()
77+ print (f"Using { n_jobs } parallel processes for preprocessing" )
78+
79+ # Process subjects in parallel
80+ with mp .Pool (processes = n_jobs ) as pool :
81+ # Create a partial function with the output_dir fixed
82+ preprocess_func = partial (preprocess_single_subject , output_dir = output_dir )
83+
84+ # Process all subjects in parallel
85+ preprocessed_subjects = pool .map (preprocess_func , subjects )
86+
87+ print (f"Preprocessed { len (preprocessed_subjects )} subjects" )
88+ return preprocessed_subjects
89+
90+
91+ def generate_simple_before_and_after (preprocessed_subjects : dict , output_dir ):
2692 if not output_dir :
2793 output_dir = TemporaryDirectory ()
2894 wf = Workflow (
@@ -32,12 +98,24 @@ def generate_simple_before_and_after(subjects: dict, output_dir):
3298 # Create a list to store all nodes
3399 nodes = []
34100
35- for s in subjects :
101+ for s in preprocessed_subjects :
36102 # only run this on the T1w images for now
37103 if "T1w" in s ["orig_path" ]:
38104 o_path = Path (s ["orig_path" ])
39- # Create a valid node name by replacing invalid characters
40- valid_name = f"before_after_{ s ['id' ].replace ('-' , '_' ).replace ('_' , '' )} "
105+ # Create a valid node name by replacing invalid characters but preserving session info
106+ # Use the full path to ensure uniqueness
107+ path_parts = s ["orig_path" ].split (os .sep )
108+ subject_part = next (
109+ (p for p in path_parts if p .startswith ("sub-" )), s ["id" ]
110+ )
111+ session_part = next ((p for p in path_parts if p .startswith ("ses-" )), "" )
112+
113+ if session_part :
114+ valid_name = f"before_after_{ subject_part } _{ session_part } " .replace (
115+ "-" , "_"
116+ )
117+ else :
118+ valid_name = f"before_after_{ subject_part } " .replace ("-" , "_" )
41119 node = Node (
42120 SimpleBeforeAfterRPT (
43121 before = s ["orig_path" ],
@@ -52,10 +130,14 @@ def generate_simple_before_and_after(subjects: dict, output_dir):
52130
53131 # Add all nodes to the workflow
54132 wf .add_nodes (nodes )
55- wf .run (plugin = "MultiProc" , plugin_args = {"n_procs" : mp .cpu_count ()})
56133
57- # Collect SVG files and move them to images folder
58- collect_svg_reports (wf , output_dir )
134+ # Only run if we have nodes to process
135+ if nodes :
136+ wf .run (plugin = "MultiProc" , plugin_args = {"n_procs" : mp .cpu_count ()})
137+ # Collect SVG files and move them to images folder
138+ collect_svg_reports (wf , output_dir )
139+ else :
140+ print ("No T1w images found for SVG report generation" )
59141
60142
61143def collect_svg_reports (wf , output_dir ):
@@ -343,7 +425,8 @@ def create_overlay_gif(image_files, subject_id, output_dir):
343425
344426
345427def load_and_preprocess_image (img_path ):
346- """Load image and take mean if it has more than 3 dimensions."""
428+ """Load image and take mean if it has more than 3 dimensions.
429+ Returns nibabel image if averaging was needed, otherwise returns original path."""
347430 img = nib .load (img_path )
348431
349432 # Check if image has more than 3 dimensions
@@ -356,13 +439,14 @@ def load_and_preprocess_image(img_path):
356439 mean_data = np .mean (data , axis = 3 )
357440 # Create new 3D image
358441 img = nib .Nifti1Image (mean_data , img .affine , img .header )
359-
360- return img
442+ return img # Return nibabel image object
443+ else :
444+ return img_path # Return original path if already 3D
361445
362446
363447def create_comparison_html (
364- orig_path ,
365- defaced_path ,
448+ orig_img ,
449+ defaced_img ,
366450 subject_id ,
367451 output_dir ,
368452 display_mode = "side-by-side" ,
@@ -371,21 +455,15 @@ def create_comparison_html(
371455 """Create HTML comparison page for a subject using nilearn ortho views."""
372456
373457 # Get basenames for display
374- orig_basename = os . path . basename ( orig_path )
375- defaced_basename = os . path . basename ( defaced_path )
458+ orig_basename = f"orig_ { subject_id } .nii.gz"
459+ defaced_basename = f"defaced_ { subject_id } .nii.gz"
376460
377461 # Generate images and get their filenames
378462 image_files = []
379- for label , img_path , cmap in [
380- ("original" , orig_path , "hot" ), # Colored for original
381- ("defaced" , defaced_path , "gray" ), # Grey for defaced
463+ for label , img , basename , cmap in [
464+ ("original" , orig_img , orig_basename , "hot" ), # Colored for original
465+ ("defaced" , defaced_img , defaced_basename , "gray" ), # Grey for defaced
382466 ]:
383- # Get the basename for display
384- basename = os .path .basename (img_path )
385-
386- # Load and preprocess image (handle 4D if needed)
387- img = load_and_preprocess_image (img_path )
388-
389467 # save image to temp folder for later loading
390468
391469 # Create single sagittal slice using matplotlib directly
@@ -659,25 +737,14 @@ def process_subject(subject, output_dir, size="compact"):
659737 """Process a single subject (for parallel processing)."""
660738 print (f"Processing { subject ['id' ]} ..." )
661739 try :
662- subject_temp_dir = tempfile .TemporaryDirectory ()
663- # load each image file then save it to temp
664- original_image = load_and_preprocess_image (subject ["orig_path" ])
665- defaced_image = load_and_preprocess_image (subject ["defaced_path" ])
666- original_image = nib .Nifti1Image (original_image .get_fdata (), original_image .affine , original_image .header )
667- defaced_image = nib .Nifti1Image (defaced_image .get_fdata (), defaced_image .affine , defaced_image .header )
668-
669- nib .save (original_image , Path (subject_temp_dir .name ) / Path (subject ["orig_path" ]).name )
670- nib .save (defaced_image , Path (subject_temp_dir .name ) / Path (subject ["defaced_path" ]).name )
671-
672740 comparison_file = create_comparison_html (
673- Path ( subject_temp_dir . name ) / Path ( subject ["orig_path" ]). name ,
674- Path ( subject_temp_dir . name ) / Path ( subject ["defaced_path" ]). name ,
741+ subject ["orig_img" ] ,
742+ subject ["defaced_img" ] ,
675743 subject ["id" ],
676744 output_dir ,
677745 "side-by-side" , # Always generate side-by-side for individual pages
678746 size ,
679747 )
680-
681748 print (f" Completed: { subject ['id' ]} " )
682749 return comparison_file
683750 except Exception as e :
@@ -687,8 +754,22 @@ def process_subject(subject, output_dir, size="compact"):
687754
688755def build_subjects_from_datasets (orig_dir , defaced_dir ):
689756 """Build subject list with file paths."""
690- orig_files = glob (os .path .join (orig_dir , "**" , "*.nii*" ), recursive = True )
691- defaced_files = glob (os .path .join (defaced_dir , "**" , "*.nii*" ), recursive = True )
757+
758+ # Get all NIfTI files but exclude derivatives and workflow folders
759+ def get_nifti_files (directory ):
760+ all_files = glob (os .path .join (directory , "**" , "*.nii*" ), recursive = True )
761+ # Filter out files in derivatives, workflow, or other processing folders
762+ filtered_files = []
763+ for file_path in all_files :
764+ # Skip files in derivatives, workflow, or processing-related directories
765+ path_parts = file_path .split (os .sep )
766+ skip_dirs = ["derivatives" , "work" , "wf" , "tmp" , "temp" , "scratch" , "cache" ]
767+ if not any (skip_dir in path_parts for skip_dir in skip_dirs ):
768+ filtered_files .append (file_path )
769+ return filtered_files
770+
771+ orig_files = get_nifti_files (orig_dir )
772+ defaced_files = get_nifti_files (defaced_dir )
692773
693774 def strip_ext (path ):
694775 base = os .path .basename (path )
@@ -751,8 +832,8 @@ def create_side_by_side_index_html(subjects, output_dir, size="compact"):
751832 comparisons_html = ""
752833 for subject in subjects :
753834 subject_id = subject ["id" ]
754- orig_basename = os . path . basename ( subject [ "orig_path" ])
755- defaced_basename = os . path . basename ( subject [ "defaced_path" ])
835+ orig_basename = f"orig_ { subject_id } .nii.gz"
836+ defaced_basename = f"defaced_ { subject_id } .nii.gz"
756837
757838 # Check if the PNG files exist
758839 orig_png = f"images/original_{ subject_id } .png"
@@ -1157,13 +1238,18 @@ def main():
11571238 print (f" - { s ['id' ]} " )
11581239 exit (1 )
11591240
1160- # create nireports svg's for comparison
1161- generate_simple_before_and_after (subjects = subjects , output_dir = output_dir )
1162-
11631241 # Set number of jobs for parallel processing
11641242 n_jobs = args .n_jobs if args .n_jobs else mp .cpu_count ()
11651243 print (f"Using { n_jobs } parallel processes" )
11661244
1245+ # Preprocess all images once (4D→3D conversion)
1246+ preprocessed_subjects = preprocess_images (subjects , output_dir , n_jobs )
1247+
1248+ # create nireports svg's for comparison
1249+ generate_simple_before_and_after (
1250+ preprocessed_subjects = preprocessed_subjects , output_dir = output_dir
1251+ )
1252+
11671253 # Process subjects in parallel
11681254 print ("Generating comparisons..." )
11691255 with mp .Pool (processes = n_jobs ) as pool :
@@ -1175,15 +1261,19 @@ def main():
11751261 )
11761262
11771263 # Process all subjects in parallel
1178- results = pool .map (process_func , subjects )
1264+ results = pool .map (process_func , preprocessed_subjects )
11791265
11801266 # Count successful results
11811267 successful = [r for r in results if r is not None ]
1182- print (f"Successfully processed { len (successful )} out of { len (subjects )} subjects" )
1268+ print (
1269+ f"Successfully processed { len (successful )} out of { len (preprocessed_subjects )} subjects"
1270+ )
11831271
11841272 # Create both HTML files
1185- side_by_side_file = create_side_by_side_index_html (subjects , output_dir , args .size )
1186- animated_file = create_gif_index_html (subjects , output_dir , args .size )
1273+ side_by_side_file = create_side_by_side_index_html (
1274+ preprocessed_subjects , output_dir , args .size
1275+ )
1276+ animated_file = create_gif_index_html (preprocessed_subjects , output_dir , args .size )
11871277
11881278 # Create a simple index that links to both
11891279 index_html = f"""
@@ -1241,7 +1331,7 @@ def main():
12411331 <a href="SimpleBeforeAfterRPT.html" class="link-button">SVG Reports View</a>
12421332
12431333 <p style="margin-top: 30px; color: #999; font-size: 14px;">
1244- Generated with { len (subjects )} subjects
1334+ Generated with { len (preprocessed_subjects )} subjects
12451335 </p>
12461336 </div>
12471337 </body>
0 commit comments