22
22
from monai .networks .blocks import Convolution
23
23
from monai .networks .blocks .spatialattention import SpatialAttentionBlock
24
24
from monai .networks .nets .autoencoderkl import AEKLResBlock , AutoencoderKL
25
+ from monai .utils .deprecate_utils import deprecated_arg
25
26
from monai .utils .type_conversion import convert_to_tensor
26
27
27
28
# Set up logging configuration
@@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
34
35
return
35
36
36
37
38
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
37
39
class MaisiGroupNorm3D (nn .GroupNorm ):
38
40
"""
39
41
Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
43
45
num_channels: Number of channels for the group norm.
44
46
eps: Epsilon value for numerical stability.
45
47
affine: Whether to use learnable affine parameters, default to `True`.
46
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
48
+ norm_float16: Deprecated argument .
47
49
print_info: Whether to print information, default to `False`.
48
50
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
49
51
"""
@@ -59,14 +61,15 @@ def __init__(
59
61
save_mem : bool = True ,
60
62
):
61
63
super ().__init__ (num_groups , num_channels , eps , affine )
62
- self .norm_float16 = norm_float16
63
64
self .print_info = print_info
64
65
self .save_mem = save_mem
65
66
66
67
def forward (self , input : torch .Tensor ) -> torch .Tensor :
67
68
if self .print_info :
68
69
logger .info (f"MaisiGroupNorm3D with input size: { input .size ()} " )
69
70
71
+ target_dtype = input .dtype
72
+
70
73
if len (input .shape ) != 5 :
71
74
raise ValueError ("Expected a 5D tensor" )
72
75
@@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
75
78
76
79
inputs = []
77
80
for i in range (input .size (1 )):
78
- array = input [:, i : i + 1 , ...]. to ( dtype = torch . float32 )
81
+ array = input [:, i : i + 1 , ...]
79
82
mean = array .mean ([2 , 3 , 4 , 5 ], keepdim = True )
80
83
std = array .var ([2 , 3 , 4 , 5 ], unbiased = False , keepdim = True ).add_ (self .eps ).sqrt_ ()
81
- if self .norm_float16 :
82
- inputs .append (((array - mean ) / std ).to (dtype = torch .float16 ))
83
- else :
84
- inputs .append ((array - mean ) / std )
84
+ inputs .append (((array - mean ) / std ).to (dtype = target_dtype ))
85
85
86
86
del input
87
87
_empty_cuda_cache (self .save_mem )
@@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
376
376
return x
377
377
378
378
379
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
379
380
class MaisiResBlock (nn .Module ):
380
381
"""
381
382
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +418,6 @@ def __init__(
417
418
num_channels = in_channels ,
418
419
eps = norm_eps ,
419
420
affine = True ,
420
- norm_float16 = norm_float16 ,
421
421
print_info = print_info ,
422
422
save_mem = save_mem ,
423
423
)
@@ -439,7 +439,6 @@ def __init__(
439
439
num_channels = out_channels ,
440
440
eps = norm_eps ,
441
441
affine = True ,
442
- norm_float16 = norm_float16 ,
443
442
print_info = print_info ,
444
443
save_mem = save_mem ,
445
444
)
@@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
501
500
return out_tensor
502
501
503
502
503
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
504
504
class MaisiEncoder (nn .Module ):
505
505
"""
506
506
Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module):
520
520
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521
521
num_splits: Number of splits for the input tensor.
522
522
dim_split: Dimension of splitting for the input tensor.
523
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
523
+ norm_float16: Deprecated argument .
524
524
print_info: Whether to print information, default to `False`.
525
525
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526
526
"""
@@ -591,7 +591,6 @@ def __init__(
591
591
out_channels = output_channel ,
592
592
num_splits = num_splits ,
593
593
dim_split = dim_split ,
594
- norm_float16 = norm_float16 ,
595
594
print_info = print_info ,
596
595
save_mem = save_mem ,
597
596
)
@@ -660,7 +659,6 @@ def __init__(
660
659
num_channels = num_channels [- 1 ],
661
660
eps = norm_eps ,
662
661
affine = True ,
663
- norm_float16 = norm_float16 ,
664
662
print_info = print_info ,
665
663
save_mem = save_mem ,
666
664
)
@@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
690
688
return x
691
689
692
690
691
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
693
692
class MaisiDecoder (nn .Module ):
694
693
"""
695
694
Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module):
710
709
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
711
710
num_splits: Number of splits for the input tensor.
712
711
dim_split: Dimension of splitting for the input tensor.
713
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
712
+ norm_float16: Deprecated argument .
714
713
print_info: Whether to print information, default to `False`.
715
714
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
716
715
"""
@@ -809,7 +808,6 @@ def __init__(
809
808
out_channels = block_out_ch ,
810
809
num_splits = num_splits ,
811
810
dim_split = dim_split ,
812
- norm_float16 = norm_float16 ,
813
811
print_info = print_info ,
814
812
save_mem = save_mem ,
815
813
)
@@ -848,7 +846,6 @@ def __init__(
848
846
num_channels = block_in_ch ,
849
847
eps = norm_eps ,
850
848
affine = True ,
851
- norm_float16 = norm_float16 ,
852
849
print_info = print_info ,
853
850
save_mem = save_mem ,
854
851
)
@@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878
875
return x
879
876
880
877
878
+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
881
879
class AutoencoderKlMaisi (AutoencoderKL ):
882
880
"""
883
881
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
901
899
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
902
900
num_splits: Number of splits for the input tensor.
903
901
dim_split: Dimension of splitting for the input tensor.
904
- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
902
+ norm_float16: Deprecated argument .
905
903
print_info: Whether to print information, default to `False`.
906
904
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
907
905
"""
@@ -964,7 +962,6 @@ def __init__(
964
962
use_flash_attention = use_flash_attention ,
965
963
num_splits = num_splits ,
966
964
dim_split = dim_split ,
967
- norm_float16 = norm_float16 ,
968
965
print_info = print_info ,
969
966
save_mem = save_mem ,
970
967
)
@@ -985,7 +982,6 @@ def __init__(
985
982
use_convtranspose = use_convtranspose ,
986
983
num_splits = num_splits ,
987
984
dim_split = dim_split ,
988
- norm_float16 = norm_float16 ,
989
985
print_info = print_info ,
990
986
save_mem = save_mem ,
991
987
)
0 commit comments