Skip to content

Commit 4d69a07

Browse files
committed
fixed bug w/ overwriting original images
1 parent 6ba6c60 commit 4d69a07

File tree

2 files changed

+162
-71
lines changed

2 files changed

+162
-71
lines changed

petdeface/petdeface.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def deface(args: Union[dict, argparse.Namespace]) -> None:
297297
write_out_dataset_description_json(args.bids_dir)
298298

299299
# remove temp outputs - this is commented out to enable easier testing for now
300-
if str(os.getenv("DEBUG", "false")).lower() != "true":
300+
if str(os.getenv("PETDEFACE_DEBUG", "false")).lower() != "true":
301301
shutil.rmtree(os.path.join(output_dir, "petdeface_wf"))
302302

303303
return {"subjects": subjects}
@@ -671,21 +671,21 @@ def wrap_up_defacing(
671671
should_exclude = False
672672
for excluded_subject in participant_label_exclude:
673673
# Handle both cases: excluded_subject with or without 'sub-' prefix
674-
if excluded_subject.startswith('sub-'):
674+
if excluded_subject.startswith("sub-"):
675675
subject_pattern = f"/{excluded_subject}/"
676676
subject_pattern_underscore = f"/{excluded_subject}_"
677677
else:
678678
subject_pattern = f"/sub-{excluded_subject}/"
679679
subject_pattern_underscore = f"/sub-{excluded_subject}_"
680-
680+
681681
if subject_pattern in entry or subject_pattern_underscore in entry:
682682
should_exclude = True
683683
break
684-
684+
685685
# Skip excluded subject files, but copy everything else (including dataset-level files)
686686
if should_exclude:
687687
continue
688-
688+
689689
copy_path = entry.replace(str(path_to_dataset), str(final_destination))
690690
pathlib.Path(copy_path).parent.mkdir(
691691
parents=True, exist_ok=True, mode=0o775
@@ -730,7 +730,7 @@ def wrap_up_defacing(
730730
desc="defaced",
731731
return_type="file",
732732
)
733-
if str(os.getenv("DEBUG", "false")).lower() != "true":
733+
if str(os.getenv("PETDEFAC_DEBUG", "false")).lower() != "true":
734734
for extraneous in derivatives:
735735
os.remove(extraneous)
736736

@@ -741,15 +741,16 @@ def wrap_up_defacing(
741741
"placement must be one of ['adjacent', 'inplace', 'derivatives']"
742742
)
743743

744-
# clean up any errantly leftover files with globe in destination folder
745-
leftover_files = [
746-
pathlib.Path(defaced_nii)
747-
for defaced_nii in glob.glob(
748-
f"{final_destination}/**/*_defaced*.nii*", recursive=True
749-
)
750-
]
751-
for leftover in leftover_files:
752-
leftover.unlink()
744+
if not os.getenv("PETDEFACE_DEBUG"):
745+
# clean up any errantly leftover files with glob in destination folder
746+
leftover_files = [
747+
pathlib.Path(defaced_nii)
748+
for defaced_nii in glob.glob(
749+
f"{final_destination}/**/*_defaced*.nii*", recursive=True
750+
)
751+
]
752+
for leftover in leftover_files:
753+
leftover.unlink()
753754

754755
print(f"completed copying dataset to {final_destination}")
755756

@@ -770,7 +771,9 @@ def move_defaced_images(
770771
:param move_files: delete defaced images in "working" directory, e.g. move them to the destination dir instead of copying them there, defaults to False
771772
:type move_files: bool, optional
772773
"""
773-
# update paths in mapping dict
774+
# Create a new mapping with destination paths
775+
dest_mapping = {}
776+
774777
for defaced, raw in mapping_dict.items():
775778
# get common path and replace with final_destination to get new path
776779
common_path = os.path.commonpath([defaced.path, raw.path])
@@ -791,15 +794,13 @@ def move_defaced_images(
791794
]
792795
)
793796
)
794-
mapping_dict[defaced] = new_path
797+
dest_mapping[defaced] = new_path
795798

796799
# copy defaced images to new location
797-
for defaced, raw in mapping_dict.items():
798-
if pathlib.Path(raw).exists() and pathlib.Path(defaced).exists():
799-
shutil.copy(defaced.path, raw)
800-
else:
801-
pathlib.Path(raw).parent.mkdir(parents=True, exist_ok=True)
802-
shutil.copy(defaced.path, raw)
800+
for defaced, dest_path in dest_mapping.items():
801+
if pathlib.Path(defaced).exists():
802+
pathlib.Path(dest_path).parent.mkdir(parents=True, exist_ok=True)
803+
shutil.copy(defaced.path, dest_path)
803804

804805
if move_files:
805806
os.remove(defaced.path)

petdeface/qa.py

Lines changed: 138 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,73 @@
2222
from pathlib import Path
2323

2424

25-
def generate_simple_before_and_after(subjects: dict, output_dir):
25+
def preprocess_single_subject(s, output_dir):
26+
"""Preprocess a single subject's images (for parallel processing)."""
27+
temp_dir = os.path.join(output_dir, "temp_3d_images")
28+
29+
# Preprocess original image
30+
orig_result = load_and_preprocess_image(s["orig_path"])
31+
if isinstance(orig_result, nib.Nifti1Image):
32+
# Need to save the averaged image
33+
orig_3d_path = os.path.join(temp_dir, f"orig_{s['id']}.nii.gz")
34+
nib.save(orig_result, orig_3d_path)
35+
orig_img = orig_result
36+
else:
37+
# Already 3D, use original path and load image
38+
orig_3d_path = orig_result
39+
orig_img = nib.load(orig_result)
40+
41+
# Preprocess defaced image
42+
defaced_result = load_and_preprocess_image(s["defaced_path"])
43+
if isinstance(defaced_result, nib.Nifti1Image):
44+
# Need to save the averaged image
45+
defaced_3d_path = os.path.join(temp_dir, f"defaced_{s['id']}.nii.gz")
46+
nib.save(defaced_result, defaced_3d_path)
47+
defaced_img = defaced_result
48+
else:
49+
# Already 3D, use original path and load image
50+
defaced_3d_path = defaced_result
51+
defaced_img = nib.load(defaced_result)
52+
53+
# Create new subject dict with preprocessed paths (update paths to 3D versions)
54+
preprocessed_subject = {
55+
"id": s["id"],
56+
"orig_path": orig_3d_path, # Update to 3D path
57+
"defaced_path": defaced_3d_path, # Update to 3D path
58+
"orig_img": orig_img, # Keep loaded image for direct use
59+
"defaced_img": defaced_img, # Keep loaded image for direct use
60+
}
61+
62+
print(f" Preprocessed {s['id']}")
63+
return preprocessed_subject
64+
65+
66+
def preprocess_images(subjects: dict, output_dir, n_jobs=None):
67+
"""Preprocess all images once: load and convert 4D to 3D if needed."""
68+
print("Preprocessing images (4D→3D conversion)...")
69+
70+
# Create temp directory
71+
temp_dir = os.path.join(output_dir, "temp_3d_images")
72+
os.makedirs(temp_dir, exist_ok=True)
73+
74+
# Set number of jobs for parallel processing
75+
if n_jobs is None:
76+
n_jobs = mp.cpu_count()
77+
print(f"Using {n_jobs} parallel processes for preprocessing")
78+
79+
# Process subjects in parallel
80+
with mp.Pool(processes=n_jobs) as pool:
81+
# Create a partial function with the output_dir fixed
82+
preprocess_func = partial(preprocess_single_subject, output_dir=output_dir)
83+
84+
# Process all subjects in parallel
85+
preprocessed_subjects = pool.map(preprocess_func, subjects)
86+
87+
print(f"Preprocessed {len(preprocessed_subjects)} subjects")
88+
return preprocessed_subjects
89+
90+
91+
def generate_simple_before_and_after(preprocessed_subjects: dict, output_dir):
2692
if not output_dir:
2793
output_dir = TemporaryDirectory()
2894
wf = Workflow(
@@ -32,12 +98,24 @@ def generate_simple_before_and_after(subjects: dict, output_dir):
3298
# Create a list to store all nodes
3399
nodes = []
34100

35-
for s in subjects:
101+
for s in preprocessed_subjects:
36102
# only run this on the T1w images for now
37103
if "T1w" in s["orig_path"]:
38104
o_path = Path(s["orig_path"])
39-
# Create a valid node name by replacing invalid characters
40-
valid_name = f"before_after_{s['id'].replace('-', '_').replace('_', '')}"
105+
# Create a valid node name by replacing invalid characters but preserving session info
106+
# Use the full path to ensure uniqueness
107+
path_parts = s["orig_path"].split(os.sep)
108+
subject_part = next(
109+
(p for p in path_parts if p.startswith("sub-")), s["id"]
110+
)
111+
session_part = next((p for p in path_parts if p.startswith("ses-")), "")
112+
113+
if session_part:
114+
valid_name = f"before_after_{subject_part}_{session_part}".replace(
115+
"-", "_"
116+
)
117+
else:
118+
valid_name = f"before_after_{subject_part}".replace("-", "_")
41119
node = Node(
42120
SimpleBeforeAfterRPT(
43121
before=s["orig_path"],
@@ -52,10 +130,14 @@ def generate_simple_before_and_after(subjects: dict, output_dir):
52130

53131
# Add all nodes to the workflow
54132
wf.add_nodes(nodes)
55-
wf.run(plugin="MultiProc", plugin_args={"n_procs": mp.cpu_count()})
56133

57-
# Collect SVG files and move them to images folder
58-
collect_svg_reports(wf, output_dir)
134+
# Only run if we have nodes to process
135+
if nodes:
136+
wf.run(plugin="MultiProc", plugin_args={"n_procs": mp.cpu_count()})
137+
# Collect SVG files and move them to images folder
138+
collect_svg_reports(wf, output_dir)
139+
else:
140+
print("No T1w images found for SVG report generation")
59141

60142

61143
def collect_svg_reports(wf, output_dir):
@@ -343,7 +425,8 @@ def create_overlay_gif(image_files, subject_id, output_dir):
343425

344426

345427
def load_and_preprocess_image(img_path):
346-
"""Load image and take mean if it has more than 3 dimensions."""
428+
"""Load image and take mean if it has more than 3 dimensions.
429+
Returns nibabel image if averaging was needed, otherwise returns original path."""
347430
img = nib.load(img_path)
348431

349432
# Check if image has more than 3 dimensions
@@ -356,13 +439,14 @@ def load_and_preprocess_image(img_path):
356439
mean_data = np.mean(data, axis=3)
357440
# Create new 3D image
358441
img = nib.Nifti1Image(mean_data, img.affine, img.header)
359-
360-
return img
442+
return img # Return nibabel image object
443+
else:
444+
return img_path # Return original path if already 3D
361445

362446

363447
def create_comparison_html(
364-
orig_path,
365-
defaced_path,
448+
orig_img,
449+
defaced_img,
366450
subject_id,
367451
output_dir,
368452
display_mode="side-by-side",
@@ -371,21 +455,15 @@ def create_comparison_html(
371455
"""Create HTML comparison page for a subject using nilearn ortho views."""
372456

373457
# Get basenames for display
374-
orig_basename = os.path.basename(orig_path)
375-
defaced_basename = os.path.basename(defaced_path)
458+
orig_basename = f"orig_{subject_id}.nii.gz"
459+
defaced_basename = f"defaced_{subject_id}.nii.gz"
376460

377461
# Generate images and get their filenames
378462
image_files = []
379-
for label, img_path, cmap in [
380-
("original", orig_path, "hot"), # Colored for original
381-
("defaced", defaced_path, "gray"), # Grey for defaced
463+
for label, img, basename, cmap in [
464+
("original", orig_img, orig_basename, "hot"), # Colored for original
465+
("defaced", defaced_img, defaced_basename, "gray"), # Grey for defaced
382466
]:
383-
# Get the basename for display
384-
basename = os.path.basename(img_path)
385-
386-
# Load and preprocess image (handle 4D if needed)
387-
img = load_and_preprocess_image(img_path)
388-
389467
# save image to temp folder for later loading
390468

391469
# Create single sagittal slice using matplotlib directly
@@ -659,25 +737,14 @@ def process_subject(subject, output_dir, size="compact"):
659737
"""Process a single subject (for parallel processing)."""
660738
print(f"Processing {subject['id']}...")
661739
try:
662-
subject_temp_dir = tempfile.TemporaryDirectory()
663-
# load each image file then save it to temp
664-
original_image = load_and_preprocess_image(subject["orig_path"])
665-
defaced_image = load_and_preprocess_image(subject["defaced_path"])
666-
original_image = nib.Nifti1Image(original_image.get_fdata(), original_image.affine, original_image.header)
667-
defaced_image = nib.Nifti1Image(defaced_image.get_fdata(), defaced_image.affine, defaced_image.header)
668-
669-
nib.save(original_image, Path(subject_temp_dir.name) / Path(subject["orig_path"]).name)
670-
nib.save(defaced_image, Path(subject_temp_dir.name) / Path(subject["defaced_path"]).name)
671-
672740
comparison_file = create_comparison_html(
673-
Path(subject_temp_dir.name) / Path(subject["orig_path"]).name,
674-
Path(subject_temp_dir.name) / Path(subject["defaced_path"]).name,
741+
subject["orig_img"],
742+
subject["defaced_img"],
675743
subject["id"],
676744
output_dir,
677745
"side-by-side", # Always generate side-by-side for individual pages
678746
size,
679747
)
680-
681748
print(f" Completed: {subject['id']}")
682749
return comparison_file
683750
except Exception as e:
@@ -687,8 +754,22 @@ def process_subject(subject, output_dir, size="compact"):
687754

688755
def build_subjects_from_datasets(orig_dir, defaced_dir):
689756
"""Build subject list with file paths."""
690-
orig_files = glob(os.path.join(orig_dir, "**", "*.nii*"), recursive=True)
691-
defaced_files = glob(os.path.join(defaced_dir, "**", "*.nii*"), recursive=True)
757+
758+
# Get all NIfTI files but exclude derivatives and workflow folders
759+
def get_nifti_files(directory):
760+
all_files = glob(os.path.join(directory, "**", "*.nii*"), recursive=True)
761+
# Filter out files in derivatives, workflow, or other processing folders
762+
filtered_files = []
763+
for file_path in all_files:
764+
# Skip files in derivatives, workflow, or processing-related directories
765+
path_parts = file_path.split(os.sep)
766+
skip_dirs = ["derivatives", "work", "wf", "tmp", "temp", "scratch", "cache"]
767+
if not any(skip_dir in path_parts for skip_dir in skip_dirs):
768+
filtered_files.append(file_path)
769+
return filtered_files
770+
771+
orig_files = get_nifti_files(orig_dir)
772+
defaced_files = get_nifti_files(defaced_dir)
692773

693774
def strip_ext(path):
694775
base = os.path.basename(path)
@@ -751,8 +832,8 @@ def create_side_by_side_index_html(subjects, output_dir, size="compact"):
751832
comparisons_html = ""
752833
for subject in subjects:
753834
subject_id = subject["id"]
754-
orig_basename = os.path.basename(subject["orig_path"])
755-
defaced_basename = os.path.basename(subject["defaced_path"])
835+
orig_basename = f"orig_{subject_id}.nii.gz"
836+
defaced_basename = f"defaced_{subject_id}.nii.gz"
756837

757838
# Check if the PNG files exist
758839
orig_png = f"images/original_{subject_id}.png"
@@ -1157,13 +1238,18 @@ def main():
11571238
print(f" - {s['id']}")
11581239
exit(1)
11591240

1160-
# create nireports svg's for comparison
1161-
generate_simple_before_and_after(subjects=subjects, output_dir=output_dir)
1162-
11631241
# Set number of jobs for parallel processing
11641242
n_jobs = args.n_jobs if args.n_jobs else mp.cpu_count()
11651243
print(f"Using {n_jobs} parallel processes")
11661244

1245+
# Preprocess all images once (4D→3D conversion)
1246+
preprocessed_subjects = preprocess_images(subjects, output_dir, n_jobs)
1247+
1248+
# create nireports svg's for comparison
1249+
generate_simple_before_and_after(
1250+
preprocessed_subjects=preprocessed_subjects, output_dir=output_dir
1251+
)
1252+
11671253
# Process subjects in parallel
11681254
print("Generating comparisons...")
11691255
with mp.Pool(processes=n_jobs) as pool:
@@ -1175,15 +1261,19 @@ def main():
11751261
)
11761262

11771263
# Process all subjects in parallel
1178-
results = pool.map(process_func, subjects)
1264+
results = pool.map(process_func, preprocessed_subjects)
11791265

11801266
# Count successful results
11811267
successful = [r for r in results if r is not None]
1182-
print(f"Successfully processed {len(successful)} out of {len(subjects)} subjects")
1268+
print(
1269+
f"Successfully processed {len(successful)} out of {len(preprocessed_subjects)} subjects"
1270+
)
11831271

11841272
# Create both HTML files
1185-
side_by_side_file = create_side_by_side_index_html(subjects, output_dir, args.size)
1186-
animated_file = create_gif_index_html(subjects, output_dir, args.size)
1273+
side_by_side_file = create_side_by_side_index_html(
1274+
preprocessed_subjects, output_dir, args.size
1275+
)
1276+
animated_file = create_gif_index_html(preprocessed_subjects, output_dir, args.size)
11871277

11881278
# Create a simple index that links to both
11891279
index_html = f"""
@@ -1241,7 +1331,7 @@ def main():
12411331
<a href="SimpleBeforeAfterRPT.html" class="link-button">SVG Reports View</a>
12421332
12431333
<p style="margin-top: 30px; color: #999; font-size: 14px;">
1244-
Generated with {len(subjects)} subjects
1334+
Generated with {len(preprocessed_subjects)} subjects
12451335
</p>
12461336
</div>
12471337
</body>

0 commit comments

Comments
 (0)