Skip to content

Commit 52a04fc

Browse files
committed
clean
Signed-off-by: dongyang0122 <[email protected]>
1 parent 466e9c6 commit 52a04fc

File tree

1 file changed

+56
-63
lines changed

1 file changed

+56
-63
lines changed

generation/maisi/scripts/compute_fid2p5d_ct.py renamed to generation/maisi/scripts/compute_fid2-5d_ct.py

+56-63
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python
21
# Copyright (c) MONAI Consortium
32
# Licensed under the Apache License, Version 2.0 (the "License");
43
# you may not use this file except in compliance with the License.
@@ -13,7 +12,7 @@
1312
# and limitations under the License.
1413

1514
"""
16-
Compute 2.5D FID using distributed GPU processing, **without** external fid_utils dependencies.
15+
Compute 2.5D FID using distributed GPU processing.
1716
1817
SHELL Usage Example:
1918
-------------------
@@ -22,7 +21,7 @@
2221
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
2322
NUM_GPUS=7
2423
25-
torchrun --nproc_per_node=${NUM_GPUS} compute_fid2p5d_ct.py \
24+
torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
2625
--model_name "radimagenet_resnet50" \
2726
--data0_dataroot "path/to/datasetA" \
2827
--data0_filelist "path/to/filelistA.txt" \
@@ -82,10 +81,16 @@
8281
from monai.metrics.fid import FIDMetric
8382
from monai.transforms import Compose
8483

84+
import logging
85+
8586
# ------------------------------------------------------------------------------
86-
# Below are the core utilities originally in fid_utils.py, now inlined here
87-
# to remove external dependency.
87+
# Create logger
8888
# ------------------------------------------------------------------------------
89+
logger = logging.getLogger("fid_2-5d_ct")
90+
if not logger.handlers:
91+
# Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
92+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
93+
logger.setLevel(logging.INFO)
8994

9095

9196
def drop_empty_slice(slices, empty_threshold: float):
@@ -111,7 +116,7 @@ def drop_empty_slice(slices, empty_threshold: float):
111116
else:
112117
outputs.append(True)
113118

114-
print(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%")
119+
logger.info(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%")
115120
return outputs
116121

117122

@@ -183,7 +188,7 @@ def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = Fal
183188
volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D).
184189
norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean.
185190
"""
186-
print(f"norm2d: {norm2d}")
191+
logger.info(f"norm2d: {norm2d}")
187192
dim = len(volume.shape)
188193
# If norm2d is True, only meaningful for 4D data (B, C, H, W):
189194
if dim == 4 and norm2d:
@@ -236,20 +241,18 @@ def get_features_2p5d(
236241
Returns:
237242
tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features).
238243
"""
239-
print(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")
244+
logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")
240245

241246
# If there's only 1 channel, replicate to 3 channels
242247
if image.shape[1] == 1:
243248
image = image.repeat(1, 3, 1, 1, 1)
244249

245-
# Convert from 'RGB'→(R,G,B) to (B, G, R) ordering
250+
# Convert from 'RGB'→(R,G,B) to (B,G,R)
246251
image = image[:, [2, 1, 0], ...]
247252

248253
B, C, H, W, D = image.size()
249254
with torch.no_grad():
250-
# ---------------------------------------------------------------------
251-
# 1) XY-plane slicing along D
252-
# ---------------------------------------------------------------------
255+
# ---------------------- XY-plane slicing along D ----------------------
253256
if center_slices:
254257
start_d = int((1.0 - center_slices_ratio) / 2.0 * D)
255258
end_d = int((1.0 + center_slices_ratio) / 2.0 * D)
@@ -268,13 +271,10 @@ def get_features_2p5d(
268271

269272
feature_image_xy = feature_network.forward(images_2d)
270273
feature_image_xy = spatial_average(feature_image_xy, keepdim=False)
271-
272274
if xy_only:
273275
return feature_image_xy, None, None
274276

275-
# ---------------------------------------------------------------------
276-
# 2) YZ-plane slicing along H
277-
# ---------------------------------------------------------------------
277+
# ---------------------- YZ-plane slicing along H ----------------------
278278
if center_slices:
279279
start_h = int((1.0 - center_slices_ratio) / 2.0 * H)
280280
end_h = int((1.0 + center_slices_ratio) / 2.0 * H)
@@ -294,9 +294,7 @@ def get_features_2p5d(
294294
feature_image_yz = feature_network.forward(images_2d)
295295
feature_image_yz = spatial_average(feature_image_yz, keepdim=False)
296296

297-
# ---------------------------------------------------------------------
298-
# 3) ZX-plane slicing along W
299-
# ---------------------------------------------------------------------
297+
# ---------------------- ZX-plane slicing along W ----------------------
300298
if center_slices:
301299
start_w = int((1.0 - center_slices_ratio) / 2.0 * W)
302300
end_w = int((1.0 + center_slices_ratio) / 2.0 * W)
@@ -319,11 +317,6 @@ def get_features_2p5d(
319317
return feature_image_xy, feature_image_yz, feature_image_zx
320318

321319

322-
# ------------------------------------------------------------------------------
323-
# End inline fid_utils code
324-
# ------------------------------------------------------------------------------
325-
326-
327320
def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
328321
"""
329322
Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size.
@@ -336,7 +329,6 @@ def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float =
336329
Returns:
337330
torch.Tensor: Padded tensor matching `max_size` along dim=0.
338331
"""
339-
# For a shape (B, C, ...), we only pad the B dimension up to `max_size`.
340332
pad_size = [0, 0] * (len(tensor.shape) - 1) + [0, max_size - tensor.shape[0]]
341333
return F.pad(tensor, pad_size, "constant", padding_value)
342334

@@ -395,11 +387,9 @@ def main(
395387
world_size = int(dist.get_world_size())
396388
device = torch.device("cuda", local_rank)
397389
torch.cuda.set_device(device)
398-
print(f"[INFO] Running process on {device} of total {world_size} ranks.")
390+
logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.")
399391

400-
# -------------------------------------------------------------------------
401392
# Convert potential string bools to actual bools (Fire sometimes passes strings)
402-
# -------------------------------------------------------------------------
403393
if not isinstance(enable_center_slices, bool):
404394
enable_center_slices = enable_center_slices.lower() == "true"
405395
if not isinstance(enable_padding, bool):
@@ -413,46 +403,44 @@ def main(
413403

414404
# Print out some flags on rank 0
415405
if local_rank == 0:
416-
print(f"[INFO] enable_center_slices: {enable_center_slices}")
417-
print(f"[INFO] enable_padding: {enable_padding}")
418-
print(f"[INFO] enable_center_cropping: {enable_center_cropping}")
419-
print(f"[INFO] enable_resampling: {enable_resampling}")
420-
print(f"[INFO] ignore_existing: {ignore_existing}")
406+
logger.info(f"enable_center_slices: {enable_center_slices}")
407+
logger.info(f"enable_padding: {enable_padding}")
408+
logger.info(f"enable_center_cropping: {enable_center_cropping}")
409+
logger.info(f"enable_resampling: {enable_resampling}")
410+
logger.info(f"ignore_existing: {ignore_existing}")
421411

422412
# -------------------------------------------------------------------------
423413
# Load feature extraction model
424414
# -------------------------------------------------------------------------
425415
if model_name == "radimagenet_resnet50":
426-
# Using a model from Warvito/radimagenet-models on Torch Hub
427416
feature_network = torch.hub.load(
428417
"Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
429418
)
430419
suffix = "radimagenet_resnet50"
431420
else:
432421
import torchvision
433-
434422
feature_network = torchvision.models.squeezenet1_1(pretrained=True)
435423
suffix = "squeezenet1_1"
436424

437425
feature_network.to(device)
438426
feature_network.eval()
439427

440428
# -------------------------------------------------------------------------
441-
# Parse shape/spacings from string
429+
# Parse shape/spacings
442430
# -------------------------------------------------------------------------
443431
t_shape = [int(x) for x in target_shape.split("x")]
444432
target_shape_tuple = tuple(t_shape)
445433
if enable_resampling:
446434
rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")]
447435
rs_spacing_tuple = tuple(rs_spacing)
448436
if local_rank == 0:
449-
print(f"[INFO] resampling spacing: {rs_spacing_tuple}")
437+
logger.info(f"resampling spacing: {rs_spacing_tuple}")
450438
else:
451439
rs_spacing_tuple = (1.0, 1.0, 1.0)
452440

453441
center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
454442
if local_rank == 0:
455-
print(f"[INFO] center_slices_ratio: {center_slices_ratio_final}")
443+
logger.info(f"center_slices_ratio: {center_slices_ratio_final}")
456444

457445
# -------------------------------------------------------------------------
458446
# Prepare dataset 0
@@ -490,25 +478,20 @@ def main(
490478
monai.transforms.EnsureChannelFirstd(keys=["image"]),
491479
monai.transforms.Orientationd(keys=["image"], axcodes="RAS"),
492480
]
493-
494481
if enable_resampling:
495482
transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))
496483
if enable_padding:
497484
transform_list.append(
498-
monai.transforms.SpatialPadd(
499-
keys=["image"], spatial_size=target_shape_tuple, mode=["constant"], value=-1000
500-
)
485+
monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
501486
)
502487
if enable_center_cropping:
503488
transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))
504489

505-
# Intensity scaling to clamp between [-1000, 1000]
506490
transform_list.append(
507491
monai.transforms.ScaleIntensityRanged(
508492
keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True
509493
)
510494
)
511-
512495
transforms = Compose(transform_list)
513496

514497
# -------------------------------------------------------------------------
@@ -527,7 +510,7 @@ def main(
527510
for idx, batch_data in enumerate(real_loader, start=1):
528511
img = batch_data["image"].to(device)
529512
fn = img.meta["filename_or_obj"][0]
530-
print(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}")
513+
logger.info(f"[Rank {local_rank}] Real data {idx}/{len(filenames0)}: {fn}")
531514

532515
out_fp = fn.replace(data0_dataroot, output_root0).replace(".nii.gz", ".pt")
533516
out_fp = Path(out_fp)
@@ -537,17 +520,19 @@ def main(
537520
feats = torch.load(out_fp)
538521
else:
539522
img_t = img.as_tensor()
540-
print(f"[INFO] image shape: {tuple(img_t.shape)}")
523+
logger.info(f"image shape: {tuple(img_t.shape)}")
541524

542-
# Inline get_features_2p5d
543525
feats = get_features_2p5d(
544526
img_t,
545527
feature_network,
546528
center_slices=enable_center_slices,
547529
center_slices_ratio=center_slices_ratio_final,
548530
xy_only=False,
549531
)
550-
print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
532+
logger.info(
533+
f"feats shapes: {feats[0].shape}, "
534+
f"{feats[1].shape}, {feats[2].shape}"
535+
)
551536
torch.save(feats, out_fp)
552537

553538
real_features_xy.append(feats[0])
@@ -557,7 +542,10 @@ def main(
557542
real_features_xy = torch.vstack(real_features_xy)
558543
real_features_yz = torch.vstack(real_features_yz)
559544
real_features_zx = torch.vstack(real_features_zx)
560-
print(f"[INFO] Real feature shapes: {real_features_xy.shape}, {real_features_yz.shape}, {real_features_zx.shape}")
545+
logger.info(
546+
f"Real feature shapes: {real_features_xy.shape}, "
547+
f"{real_features_yz.shape}, {real_features_zx.shape}"
548+
)
561549

562550
# -------------------------------------------------------------------------
563551
# Extract features for dataset 1
@@ -566,7 +554,7 @@ def main(
566554
for idx, batch_data in enumerate(synt_loader, start=1):
567555
img = batch_data["image"].to(device)
568556
fn = img.meta["filename_or_obj"][0]
569-
print(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}")
557+
logger.info(f"[Rank {local_rank}] Synthetic data {idx}/{len(filenames1)}: {fn}")
570558

571559
out_fp = fn.replace(data1_dataroot, output_root1).replace(".nii.gz", ".pt")
572560
out_fp = Path(out_fp)
@@ -576,7 +564,7 @@ def main(
576564
feats = torch.load(out_fp)
577565
else:
578566
img_t = img.as_tensor()
579-
print(f"[INFO] image shape: {tuple(img_t.shape)}")
567+
logger.info(f"image shape: {tuple(img_t.shape)}")
580568

581569
feats = get_features_2p5d(
582570
img_t,
@@ -585,7 +573,10 @@ def main(
585573
center_slices_ratio=center_slices_ratio_final,
586574
xy_only=False,
587575
)
588-
print(f"[INFO] feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
576+
logger.info(
577+
f"feats shapes: {feats[0].shape}, "
578+
f"{feats[1].shape}, {feats[2].shape}"
579+
)
589580
torch.save(feats, out_fp)
590581

591582
synth_features_xy.append(feats[0])
@@ -595,8 +586,8 @@ def main(
595586
synth_features_xy = torch.vstack(synth_features_xy)
596587
synth_features_yz = torch.vstack(synth_features_yz)
597588
synth_features_zx = torch.vstack(synth_features_zx)
598-
print(
599-
f"[INFO] Synthetic feature shapes: {synth_features_xy.shape}, "
589+
logger.info(
590+
f"Synthetic feature shapes: {synth_features_xy.shape}, "
600591
f"{synth_features_yz.shape}, {synth_features_zx.shape}"
601592
)
602593

@@ -649,25 +640,27 @@ def main(
649640
synth_yz = torch.vstack(all_tensors_list[4])
650641
synth_zx = torch.vstack(all_tensors_list[5])
651642

652-
print(f"[INFO] Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}")
653-
print(f"[INFO] Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}")
643+
logger.info(
644+
f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}"
645+
)
646+
logger.info(
647+
f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}"
648+
)
654649

655650
fid = FIDMetric()
656-
print(f"\n[INFO] Computing FID for: {output_root0} | {output_root1}")
651+
logger.info(f"Computing FID for: {output_root0} | {output_root1}")
657652
fid_res_xy = fid(synth_xy, real_xy)
658653
fid_res_yz = fid(synth_yz, real_yz)
659654
fid_res_zx = fid(synth_zx, real_zx)
660655

661-
print(f" FID XY: {fid_res_xy}")
662-
print(f" FID YZ: {fid_res_yz}")
663-
print(f" FID ZX: {fid_res_zx}")
656+
logger.info(f"FID XY: {fid_res_xy}")
657+
logger.info(f"FID YZ: {fid_res_yz}")
658+
logger.info(f"FID ZX: {fid_res_zx}")
664659
fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0
665-
print(f" FID Avg: {fid_avg}")
660+
logger.info(f"FID Avg: {fid_avg}")
666661

667662
dist.destroy_process_group()
668663

669664

670665
if __name__ == "__main__":
671-
# Using python-fire for command-line interface.
672-
# e.g., python compute_fid2d_mgpu.py --model_name=radimagenet_resnet50 --num_images=100 ...
673666
fire.Fire(main)

0 commit comments

Comments
 (0)