23
23
from monai .networks .blocks .spatialattention import SpatialAttentionBlock
24
24
from monai .networks .nets .autoencoderkl import AEKLResBlock , AutoencoderKL
25
25
from monai .utils .type_conversion import convert_to_tensor
26
+ from monai .utils .deprecate_utils import deprecated_arg
26
27
27
28
# Set up logging configuration
28
29
logger = logging .getLogger (__name__ )
@@ -33,7 +34,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
33
34
torch .cuda .empty_cache ()
34
35
return
35
36
36
-
37
+ @ deprecated_arg ( "norm_float16" , since = "1.5.0" , removed = "1.7.0" )
37
38
class MaisiGroupNorm3D (nn .GroupNorm ):
38
39
"""
39
40
Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +44,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
43
44
num_channels: Number of channels for the group norm.
44
45
eps: Epsilon value for numerical stability.
45
46
affine: Whether to use learnable affine parameters, default to `True`.
46
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
47
+ norm_float16: Deprecated since 1.5.0, to be removed in 1.7.0. The normalized values are now cast to the dtype of the input tensor.
47
48
print_info: Whether to print information, default to `False`.
48
49
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
49
50
"""
@@ -59,14 +60,15 @@ def __init__(
59
60
save_mem : bool = True ,
60
61
):
61
62
super ().__init__ (num_groups , num_channels , eps , affine )
62
- self .norm_float16 = norm_float16
63
63
self .print_info = print_info
64
64
self .save_mem = save_mem
65
65
66
66
def forward (self , input : torch .Tensor ) -> torch .Tensor :
67
67
if self .print_info :
68
68
logger .info (f"MaisiGroupNorm3D with input size: { input .size ()} " )
69
69
70
+ target_dtype = input .dtype
71
+
70
72
if len (input .shape ) != 5 :
71
73
raise ValueError ("Expected a 5D tensor" )
72
74
@@ -75,13 +77,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
75
77
76
78
inputs = []
77
79
for i in range (input .size (1 )):
78
- array = input [:, i : i + 1 , ...]. to ( dtype = torch . float32 )
80
+ array = input [:, i : i + 1 , ...]
79
81
mean = array .mean ([2 , 3 , 4 , 5 ], keepdim = True )
80
82
std = array .var ([2 , 3 , 4 , 5 ], unbiased = False , keepdim = True ).add_ (self .eps ).sqrt_ ()
81
- if self .norm_float16 :
82
- inputs .append (((array - mean ) / std ).to (dtype = torch .float16 ))
83
- else :
84
- inputs .append ((array - mean ) / std )
83
+ inputs .append (((array - mean ) / std ).to (dtype = target_dtype ))
85
84
86
85
del input
87
86
_empty_cuda_cache (self .save_mem )
@@ -375,7 +374,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
375
374
x = self .conv (x )
376
375
return x
377
376
378
-
377
+ @ deprecated_arg ( "norm_float16" , since = "1.5.0" , removed = "1.7.0" )
379
378
class MaisiResBlock (nn .Module ):
380
379
"""
381
380
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +416,6 @@ def __init__(
417
416
num_channels = in_channels ,
418
417
eps = norm_eps ,
419
418
affine = True ,
420
- norm_float16 = norm_float16 ,
421
419
print_info = print_info ,
422
420
save_mem = save_mem ,
423
421
)
@@ -439,7 +437,6 @@ def __init__(
439
437
num_channels = out_channels ,
440
438
eps = norm_eps ,
441
439
affine = True ,
442
- norm_float16 = norm_float16 ,
443
440
print_info = print_info ,
444
441
save_mem = save_mem ,
445
442
)
@@ -500,7 +497,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
500
497
out_tensor : torch .Tensor = convert_to_tensor (out )
501
498
return out_tensor
502
499
503
-
500
+ @ deprecated_arg ( "norm_float16" , since = "1.5.0" , removed = "1.7.0" )
504
501
class MaisiEncoder (nn .Module ):
505
502
"""
506
503
Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +517,7 @@ class MaisiEncoder(nn.Module):
520
517
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521
518
num_splits: Number of splits for the input tensor.
522
519
dim_split: Dimension of splitting for the input tensor.
523
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
520
+ norm_float16: Deprecated since 1.5.0, to be removed in 1.7.0.
524
521
print_info: Whether to print information, default to `False`.
525
522
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526
523
"""
@@ -591,7 +588,6 @@ def __init__(
591
588
out_channels = output_channel ,
592
589
num_splits = num_splits ,
593
590
dim_split = dim_split ,
594
- norm_float16 = norm_float16 ,
595
591
print_info = print_info ,
596
592
save_mem = save_mem ,
597
593
)
@@ -660,7 +656,6 @@ def __init__(
660
656
num_channels = num_channels [- 1 ],
661
657
eps = norm_eps ,
662
658
affine = True ,
663
- norm_float16 = norm_float16 ,
664
659
print_info = print_info ,
665
660
save_mem = save_mem ,
666
661
)
@@ -689,7 +684,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
689
684
_empty_cuda_cache (self .save_mem )
690
685
return x
691
686
692
-
687
+ @ deprecated_arg ( "norm_float16" , since = "1.5.0" , removed = "1.7.0" )
693
688
class MaisiDecoder (nn .Module ):
694
689
"""
695
690
Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +705,7 @@ class MaisiDecoder(nn.Module):
710
705
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
711
706
num_splits: Number of splits for the input tensor.
712
707
dim_split: Dimension of splitting for the input tensor.
713
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
708
+ norm_float16: Deprecated since 1.5.0, to be removed in 1.7.0.
714
709
print_info: Whether to print information, default to `False`.
715
710
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
716
711
"""
@@ -809,7 +804,6 @@ def __init__(
809
804
out_channels = block_out_ch ,
810
805
num_splits = num_splits ,
811
806
dim_split = dim_split ,
812
- norm_float16 = norm_float16 ,
813
807
print_info = print_info ,
814
808
save_mem = save_mem ,
815
809
)
@@ -848,7 +842,6 @@ def __init__(
848
842
num_channels = block_in_ch ,
849
843
eps = norm_eps ,
850
844
affine = True ,
851
- norm_float16 = norm_float16 ,
852
845
print_info = print_info ,
853
846
save_mem = save_mem ,
854
847
)
@@ -878,6 +871,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878
871
return x
879
872
880
873
874
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
881
875
class AutoencoderKlMaisi (AutoencoderKL ):
882
876
"""
883
877
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +895,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
901
895
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
902
896
num_splits: Number of splits for the input tensor.
903
897
dim_split: Dimension of splitting for the input tensor.
904
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
898
+ norm_float16: Deprecated since 1.5.0, to be removed in 1.7.0.
905
899
print_info: Whether to print information, default to `False`.
906
900
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
907
901
"""
@@ -964,7 +958,6 @@ def __init__(
964
958
use_flash_attention = use_flash_attention ,
965
959
num_splits = num_splits ,
966
960
dim_split = dim_split ,
967
- norm_float16 = norm_float16 ,
968
961
print_info = print_info ,
969
962
save_mem = save_mem ,
970
963
)
@@ -985,7 +978,6 @@ def __init__(
985
978
use_convtranspose = use_convtranspose ,
986
979
num_splits = num_splits ,
987
980
dim_split = dim_split ,
988
- norm_float16 = norm_float16 ,
989
981
print_info = print_info ,
990
982
save_mem = save_mem ,
991
983
)
0 commit comments