1
- #!/usr/bin/env python
2
1
# Copyright (c) MONAI Consortium
3
2
# Licensed under the Apache License, Version 2.0 (the "License");
4
3
# you may not use this file except in compliance with the License.
13
12
# and limitations under the License.
14
13
15
14
"""
16
- Compute 2.5D FID using distributed GPU processing, **without** external fid_utils dependencies .
15
+ Compute 2.5D FID using distributed GPU processing.
17
16
18
17
SHELL Usage Example:
19
18
-------------------
22
21
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
23
22
NUM_GPUS=7
24
23
25
- torchrun --nproc_per_node=${NUM_GPUS} compute_fid2p5d_ct .py \
24
+ torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct .py \
26
25
--model_name "radimagenet_resnet50" \
27
26
--data0_dataroot "path/to/datasetA" \
28
27
--data0_filelist "path/to/filelistA.txt" \
82
81
from monai .metrics .fid import FIDMetric
83
82
from monai .transforms import Compose
84
83
84
+ import logging
85
+
85
86
# ------------------------------------------------------------------------------
86
- # Below are the core utilities originally in fid_utils.py, now inlined here
87
- # to remove external dependency.
87
+ # Create logger
88
88
# ------------------------------------------------------------------------------
89
+ logger = logging .getLogger ("fid_2-5d_ct" )
90
+ if not logger .handlers :
91
+ # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
92
+ logging .basicConfig (stream = sys .stdout , level = logging .INFO )
93
+ logger .setLevel (logging .INFO )
89
94
90
95
91
96
def drop_empty_slice (slices , empty_threshold : float ):
@@ -111,7 +116,7 @@ def drop_empty_slice(slices, empty_threshold: float):
111
116
else :
112
117
outputs .append (True )
113
118
114
- print (f"Empty slice drop rate { round ((n_drop / len (slices ))* 100 ,1 )} %" )
119
+ logger . info (f"Empty slice drop rate { round ((n_drop / len (slices ))* 100 ,1 )} %" )
115
120
return outputs
116
121
117
122
@@ -183,7 +188,7 @@ def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = Fal
183
188
volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D).
184
189
norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean.
185
190
"""
186
- print (f"norm2d: { norm2d } " )
191
+ logger . info (f"norm2d: { norm2d } " )
187
192
dim = len (volume .shape )
188
193
# If norm2d is True, only meaningful for 4D data (B, C, H, W):
189
194
if dim == 4 and norm2d :
@@ -236,20 +241,18 @@ def get_features_2p5d(
236
241
Returns:
237
242
tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features).
238
243
"""
239
- print (f"center_slices: { center_slices } , ratio: { center_slices_ratio } " )
244
+ logger . info (f"center_slices: { center_slices } , ratio: { center_slices_ratio } " )
240
245
241
246
# If there's only 1 channel, replicate to 3 channels
242
247
if image .shape [1 ] == 1 :
243
248
image = image .repeat (1 , 3 , 1 , 1 , 1 )
244
249
245
- # Convert from 'RGB'→(R,G,B) to (B, G, R) ordering
250
+ # Convert from 'RGB'→(R,G,B) to (B,G,R)
246
251
image = image [:, [2 , 1 , 0 ], ...]
247
252
248
253
B , C , H , W , D = image .size ()
249
254
with torch .no_grad ():
250
- # ---------------------------------------------------------------------
251
- # 1) XY-plane slicing along D
252
- # ---------------------------------------------------------------------
255
+ # ---------------------- XY-plane slicing along D ----------------------
253
256
if center_slices :
254
257
start_d = int ((1.0 - center_slices_ratio ) / 2.0 * D )
255
258
end_d = int ((1.0 + center_slices_ratio ) / 2.0 * D )
@@ -268,13 +271,10 @@ def get_features_2p5d(
268
271
269
272
feature_image_xy = feature_network .forward (images_2d )
270
273
feature_image_xy = spatial_average (feature_image_xy , keepdim = False )
271
-
272
274
if xy_only :
273
275
return feature_image_xy , None , None
274
276
275
- # ---------------------------------------------------------------------
276
- # 2) YZ-plane slicing along H
277
- # ---------------------------------------------------------------------
277
+ # ---------------------- YZ-plane slicing along H ----------------------
278
278
if center_slices :
279
279
start_h = int ((1.0 - center_slices_ratio ) / 2.0 * H )
280
280
end_h = int ((1.0 + center_slices_ratio ) / 2.0 * H )
@@ -294,9 +294,7 @@ def get_features_2p5d(
294
294
feature_image_yz = feature_network .forward (images_2d )
295
295
feature_image_yz = spatial_average (feature_image_yz , keepdim = False )
296
296
297
- # ---------------------------------------------------------------------
298
- # 3) ZX-plane slicing along W
299
- # ---------------------------------------------------------------------
297
+ # ---------------------- ZX-plane slicing along W ----------------------
300
298
if center_slices :
301
299
start_w = int ((1.0 - center_slices_ratio ) / 2.0 * W )
302
300
end_w = int ((1.0 + center_slices_ratio ) / 2.0 * W )
@@ -319,11 +317,6 @@ def get_features_2p5d(
319
317
return feature_image_xy , feature_image_yz , feature_image_zx
320
318
321
319
322
- # ------------------------------------------------------------------------------
323
- # End inline fid_utils code
324
- # ------------------------------------------------------------------------------
325
-
326
-
327
320
def pad_to_max_size (tensor : torch .Tensor , max_size : int , padding_value : float = 0.0 ) -> torch .Tensor :
328
321
"""
329
322
Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size.
@@ -336,7 +329,6 @@ def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float =
336
329
Returns:
337
330
torch.Tensor: Padded tensor matching `max_size` along dim=0.
338
331
"""
339
- # For a shape (B, C, ...), we only pad the B dimension up to `max_size`.
340
332
pad_size = [0 , 0 ] * (len (tensor .shape ) - 1 ) + [0 , max_size - tensor .shape [0 ]]
341
333
return F .pad (tensor , pad_size , "constant" , padding_value )
342
334
@@ -395,11 +387,9 @@ def main(
395
387
world_size = int (dist .get_world_size ())
396
388
device = torch .device ("cuda" , local_rank )
397
389
torch .cuda .set_device (device )
398
- print (f"[INFO] Running process on { device } of total { world_size } ranks." )
390
+ logger . info (f"[INFO] Running process on { device } of total { world_size } ranks." )
399
391
400
- # -------------------------------------------------------------------------
401
392
# Convert potential string bools to actual bools (Fire sometimes passes strings)
402
- # -------------------------------------------------------------------------
403
393
if not isinstance (enable_center_slices , bool ):
404
394
enable_center_slices = enable_center_slices .lower () == "true"
405
395
if not isinstance (enable_padding , bool ):
@@ -413,46 +403,44 @@ def main(
413
403
414
404
# Print out some flags on rank 0
415
405
if local_rank == 0 :
416
- print (f"[INFO] enable_center_slices: { enable_center_slices } " )
417
- print (f"[INFO] enable_padding: { enable_padding } " )
418
- print (f"[INFO] enable_center_cropping: { enable_center_cropping } " )
419
- print (f"[INFO] enable_resampling: { enable_resampling } " )
420
- print (f"[INFO] ignore_existing: { ignore_existing } " )
406
+ logger . info (f"enable_center_slices: { enable_center_slices } " )
407
+ logger . info (f"enable_padding: { enable_padding } " )
408
+ logger . info (f"enable_center_cropping: { enable_center_cropping } " )
409
+ logger . info (f"enable_resampling: { enable_resampling } " )
410
+ logger . info (f"ignore_existing: { ignore_existing } " )
421
411
422
412
# -------------------------------------------------------------------------
423
413
# Load feature extraction model
424
414
# -------------------------------------------------------------------------
425
415
if model_name == "radimagenet_resnet50" :
426
- # Using a model from Warvito/radimagenet-models on Torch Hub
427
416
feature_network = torch .hub .load (
428
417
"Warvito/radimagenet-models" , model = "radimagenet_resnet50" , verbose = True , trust_repo = True
429
418
)
430
419
suffix = "radimagenet_resnet50"
431
420
else :
432
421
import torchvision
433
-
434
422
feature_network = torchvision .models .squeezenet1_1 (pretrained = True )
435
423
suffix = "squeezenet1_1"
436
424
437
425
feature_network .to (device )
438
426
feature_network .eval ()
439
427
440
428
# -------------------------------------------------------------------------
441
- # Parse shape/spacings from string
429
+ # Parse shape/spacings
442
430
# -------------------------------------------------------------------------
443
431
t_shape = [int (x ) for x in target_shape .split ("x" )]
444
432
target_shape_tuple = tuple (t_shape )
445
433
if enable_resampling :
446
434
rs_spacing = [float (x ) for x in enable_resampling_spacing .split ("x" )]
447
435
rs_spacing_tuple = tuple (rs_spacing )
448
436
if local_rank == 0 :
449
- print (f"[INFO] resampling spacing: { rs_spacing_tuple } " )
437
+ logger . info (f"resampling spacing: { rs_spacing_tuple } " )
450
438
else :
451
439
rs_spacing_tuple = (1.0 , 1.0 , 1.0 )
452
440
453
441
center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
454
442
if local_rank == 0 :
455
- print (f"[INFO] center_slices_ratio: { center_slices_ratio_final } " )
443
+ logger . info (f"center_slices_ratio: { center_slices_ratio_final } " )
456
444
457
445
# -------------------------------------------------------------------------
458
446
# Prepare dataset 0
@@ -490,25 +478,20 @@ def main(
490
478
monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
491
479
monai .transforms .Orientationd (keys = ["image" ], axcodes = "RAS" ),
492
480
]
493
-
494
481
if enable_resampling :
495
482
transform_list .append (monai .transforms .Spacingd (keys = ["image" ], pixdim = rs_spacing_tuple , mode = ["bilinear" ]))
496
483
if enable_padding :
497
484
transform_list .append (
498
- monai .transforms .SpatialPadd (
499
- keys = ["image" ], spatial_size = target_shape_tuple , mode = ["constant" ], value = - 1000
500
- )
485
+ monai .transforms .SpatialPadd (keys = ["image" ], spatial_size = target_shape_tuple , mode = "constant" , value = - 1000 )
501
486
)
502
487
if enable_center_cropping :
503
488
transform_list .append (monai .transforms .CenterSpatialCropd (keys = ["image" ], roi_size = target_shape_tuple ))
504
489
505
- # Intensity scaling to clamp between [-1000, 1000]
506
490
transform_list .append (
507
491
monai .transforms .ScaleIntensityRanged (
508
492
keys = ["image" ], a_min = - 1000 , a_max = 1000 , b_min = - 1000 , b_max = 1000 , clip = True
509
493
)
510
494
)
511
-
512
495
transforms = Compose (transform_list )
513
496
514
497
# -------------------------------------------------------------------------
@@ -527,7 +510,7 @@ def main(
527
510
for idx , batch_data in enumerate (real_loader , start = 1 ):
528
511
img = batch_data ["image" ].to (device )
529
512
fn = img .meta ["filename_or_obj" ][0 ]
530
- print (f"[Rank { local_rank } ] Real data { idx } /{ len (filenames0 )} : { fn } " )
513
+ logger . info (f"[Rank { local_rank } ] Real data { idx } /{ len (filenames0 )} : { fn } " )
531
514
532
515
out_fp = fn .replace (data0_dataroot , output_root0 ).replace (".nii.gz" , ".pt" )
533
516
out_fp = Path (out_fp )
@@ -537,17 +520,19 @@ def main(
537
520
feats = torch .load (out_fp )
538
521
else :
539
522
img_t = img .as_tensor ()
540
- print (f"[INFO] image shape: { tuple (img_t .shape )} " )
523
+ logger . info (f"image shape: { tuple (img_t .shape )} " )
541
524
542
- # Inline get_features_2p5d
543
525
feats = get_features_2p5d (
544
526
img_t ,
545
527
feature_network ,
546
528
center_slices = enable_center_slices ,
547
529
center_slices_ratio = center_slices_ratio_final ,
548
530
xy_only = False ,
549
531
)
550
- print (f"[INFO] feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
532
+ logger .info (
533
+ f"feats shapes: { feats [0 ].shape } , "
534
+ f"{ feats [1 ].shape } , { feats [2 ].shape } "
535
+ )
551
536
torch .save (feats , out_fp )
552
537
553
538
real_features_xy .append (feats [0 ])
@@ -557,7 +542,10 @@ def main(
557
542
real_features_xy = torch .vstack (real_features_xy )
558
543
real_features_yz = torch .vstack (real_features_yz )
559
544
real_features_zx = torch .vstack (real_features_zx )
560
- print (f"[INFO] Real feature shapes: { real_features_xy .shape } , { real_features_yz .shape } , { real_features_zx .shape } " )
545
+ logger .info (
546
+ f"Real feature shapes: { real_features_xy .shape } , "
547
+ f"{ real_features_yz .shape } , { real_features_zx .shape } "
548
+ )
561
549
562
550
# -------------------------------------------------------------------------
563
551
# Extract features for dataset 1
@@ -566,7 +554,7 @@ def main(
566
554
for idx , batch_data in enumerate (synt_loader , start = 1 ):
567
555
img = batch_data ["image" ].to (device )
568
556
fn = img .meta ["filename_or_obj" ][0 ]
569
- print (f"[Rank { local_rank } ] Synthetic data { idx } /{ len (filenames1 )} : { fn } " )
557
+ logger . info (f"[Rank { local_rank } ] Synthetic data { idx } /{ len (filenames1 )} : { fn } " )
570
558
571
559
out_fp = fn .replace (data1_dataroot , output_root1 ).replace (".nii.gz" , ".pt" )
572
560
out_fp = Path (out_fp )
@@ -576,7 +564,7 @@ def main(
576
564
feats = torch .load (out_fp )
577
565
else :
578
566
img_t = img .as_tensor ()
579
- print (f"[INFO] image shape: { tuple (img_t .shape )} " )
567
+ logger . info (f"image shape: { tuple (img_t .shape )} " )
580
568
581
569
feats = get_features_2p5d (
582
570
img_t ,
@@ -585,7 +573,10 @@ def main(
585
573
center_slices_ratio = center_slices_ratio_final ,
586
574
xy_only = False ,
587
575
)
588
- print (f"[INFO] feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
576
+ logger .info (
577
+ f"feats shapes: { feats [0 ].shape } , "
578
+ f"{ feats [1 ].shape } , { feats [2 ].shape } "
579
+ )
589
580
torch .save (feats , out_fp )
590
581
591
582
synth_features_xy .append (feats [0 ])
@@ -595,8 +586,8 @@ def main(
595
586
synth_features_xy = torch .vstack (synth_features_xy )
596
587
synth_features_yz = torch .vstack (synth_features_yz )
597
588
synth_features_zx = torch .vstack (synth_features_zx )
598
- print (
599
- f"[INFO] Synthetic feature shapes: { synth_features_xy .shape } , "
589
+ logger . info (
590
+ f"Synthetic feature shapes: { synth_features_xy .shape } , "
600
591
f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
601
592
)
602
593
@@ -649,25 +640,27 @@ def main(
649
640
synth_yz = torch .vstack (all_tensors_list [4 ])
650
641
synth_zx = torch .vstack (all_tensors_list [5 ])
651
642
652
- print (f"[INFO] Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } " )
653
- print (f"[INFO] Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } " )
643
+ logger .info (
644
+ f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } "
645
+ )
646
+ logger .info (
647
+ f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } "
648
+ )
654
649
655
650
fid = FIDMetric ()
656
- print (f"\n [INFO] Computing FID for: { output_root0 } | { output_root1 } " )
651
+ logger . info (f"Computing FID for: { output_root0 } | { output_root1 } " )
657
652
fid_res_xy = fid (synth_xy , real_xy )
658
653
fid_res_yz = fid (synth_yz , real_yz )
659
654
fid_res_zx = fid (synth_zx , real_zx )
660
655
661
- print (f" FID XY: { fid_res_xy } " )
662
- print (f" FID YZ: { fid_res_yz } " )
663
- print (f" FID ZX: { fid_res_zx } " )
656
+ logger . info (f"FID XY: { fid_res_xy } " )
657
+ logger . info (f"FID YZ: { fid_res_yz } " )
658
+ logger . info (f"FID ZX: { fid_res_zx } " )
664
659
fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx ) / 3.0
665
- print (f" FID Avg: { fid_avg } " )
660
+ logger . info (f"FID Avg: { fid_avg } " )
666
661
667
662
dist .destroy_process_group ()
668
663
669
664
670
665
if __name__ == "__main__" :
671
- # Using python-fire for command-line interface.
672
- # e.g., python compute_fid2d_mgpu.py --model_name=radimagenet_resnet50 --num_images=100 ...
673
666
fire .Fire (main )
0 commit comments