@@ -419,6 +419,7 @@ def main(
419
419
suffix = "radimagenet_resnet50"
420
420
else :
421
421
import torchvision
422
+
422
423
feature_network = torchvision .models .squeezenet1_1 (pretrained = True )
423
424
suffix = "squeezenet1_1"
424
425
@@ -529,10 +530,7 @@ def main(
529
530
center_slices_ratio = center_slices_ratio_final ,
530
531
xy_only = False ,
531
532
)
532
- logger .info (
533
- f"feats shapes: { feats [0 ].shape } , "
534
- f"{ feats [1 ].shape } , { feats [2 ].shape } "
535
- )
533
+ logger .info (f"feats shapes: { feats [0 ].shape } , " f"{ feats [1 ].shape } , { feats [2 ].shape } " )
536
534
torch .save (feats , out_fp )
537
535
538
536
real_features_xy .append (feats [0 ])
@@ -543,8 +541,7 @@ def main(
543
541
real_features_yz = torch .vstack (real_features_yz )
544
542
real_features_zx = torch .vstack (real_features_zx )
545
543
logger .info (
546
- f"Real feature shapes: { real_features_xy .shape } , "
547
- f"{ real_features_yz .shape } , { real_features_zx .shape } "
544
+ f"Real feature shapes: { real_features_xy .shape } , " f"{ real_features_yz .shape } , { real_features_zx .shape } "
548
545
)
549
546
550
547
# -------------------------------------------------------------------------
@@ -573,10 +570,7 @@ def main(
573
570
center_slices_ratio = center_slices_ratio_final ,
574
571
xy_only = False ,
575
572
)
576
- logger .info (
577
- f"feats shapes: { feats [0 ].shape } , "
578
- f"{ feats [1 ].shape } , { feats [2 ].shape } "
579
- )
573
+ logger .info (f"feats shapes: { feats [0 ].shape } , " f"{ feats [1 ].shape } , { feats [2 ].shape } " )
580
574
torch .save (feats , out_fp )
581
575
582
576
synth_features_xy .append (feats [0 ])
@@ -587,8 +581,7 @@ def main(
587
581
synth_features_yz = torch .vstack (synth_features_yz )
588
582
synth_features_zx = torch .vstack (synth_features_zx )
589
583
logger .info (
590
- f"Synthetic feature shapes: { synth_features_xy .shape } , "
591
- f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
584
+ f"Synthetic feature shapes: { synth_features_xy .shape } , " f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
592
585
)
593
586
594
587
# -------------------------------------------------------------------------
@@ -640,12 +633,8 @@ def main(
640
633
synth_yz = torch .vstack (all_tensors_list [4 ])
641
634
synth_zx = torch .vstack (all_tensors_list [5 ])
642
635
643
- logger .info (
644
- f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } "
645
- )
646
- logger .info (
647
- f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } "
648
- )
636
+ logger .info (f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } " )
637
+ logger .info (f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } " )
649
638
650
639
fid = FIDMetric ()
651
640
logger .info (f"Computing FID for: { output_root0 } | { output_root1 } " )
0 commit comments