Skip to content

Commit 81865e4

Browse files
committed
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16
Signed-off-by: John Zielke <[email protected]>
1 parent 8dcb9dc commit 81865e4

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

Diff for: monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

+14 -22
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
2424
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
2525
from monai.utils.type_conversion import convert_to_tensor
26+
from monai.utils.deprecate_utils import deprecated_arg
2627

2728
# Set up logging configuration
2829
logger = logging.getLogger(__name__)
@@ -33,7 +34,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
3334
torch.cuda.empty_cache()
3435
return
3536

36-
37+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
3738
class MaisiGroupNorm3D(nn.GroupNorm):
3839
"""
3940
Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +44,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
4344
num_channels: Number of channels for the group norm.
4445
eps: Epsilon value for numerical stability.
4546
affine: Whether to use learnable affine parameters, default to `True`.
46-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
47+
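norm_float16: Deprecated since 1.5.0 and scheduled for removal in 1.7.0. The flag is no longer stored; the normalized output is cast to the dtype of the input tensor.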
4748
print_info: Whether to print information, default to `False`.
4849
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
4950
"""
@@ -59,14 +60,15 @@ def __init__(
5960
save_mem: bool = True,
6061
):
6162
super().__init__(num_groups, num_channels, eps, affine)
62-
self.norm_float16 = norm_float16
6363
self.print_info = print_info
6464
self.save_mem = save_mem
6565

6666
def forward(self, input: torch.Tensor) -> torch.Tensor:
6767
if self.print_info:
6868
logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
6969

70+
target_dtype = input.dtype
71+
7072
if len(input.shape) != 5:
7173
raise ValueError("Expected a 5D tensor")
7274

@@ -75,13 +77,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
7577

7678
inputs = []
7779
for i in range(input.size(1)):
78-
array = input[:, i : i + 1, ...].to(dtype=torch.float32)
80+
array = input[:, i : i + 1, ...]
7981
mean = array.mean([2, 3, 4, 5], keepdim=True)
8082
std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
81-
if self.norm_float16:
82-
inputs.append(((array - mean) / std).to(dtype=torch.float16))
83-
else:
84-
inputs.append((array - mean) / std)
83+
inputs.append(((array - mean) / std).to(dtype=target_dtype))
8584

8685
del input
8786
_empty_cuda_cache(self.save_mem)
@@ -375,7 +374,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
375374
x = self.conv(x)
376375
return x
377376

378-
377+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
379378
class MaisiResBlock(nn.Module):
380379
"""
381380
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +416,6 @@ def __init__(
417416
num_channels=in_channels,
418417
eps=norm_eps,
419418
affine=True,
420-
norm_float16=norm_float16,
421419
print_info=print_info,
422420
save_mem=save_mem,
423421
)
@@ -439,7 +437,6 @@ def __init__(
439437
num_channels=out_channels,
440438
eps=norm_eps,
441439
affine=True,
442-
norm_float16=norm_float16,
443440
print_info=print_info,
444441
save_mem=save_mem,
445442
)
@@ -500,7 +497,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
500497
out_tensor: torch.Tensor = convert_to_tensor(out)
501498
return out_tensor
502499

503-
500+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
504501
class MaisiEncoder(nn.Module):
505502
"""
506503
Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +517,7 @@ class MaisiEncoder(nn.Module):
520517
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521518
num_splits: Number of splits for the input tensor.
522519
dim_split: Dimension of splitting for the input tensor.
523-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
520+
norm_float16: Deprecated since 1.5.0 and scheduled for removal in 1.7.0.
524521
print_info: Whether to print information, default to `False`.
525522
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526523
"""
@@ -591,7 +588,6 @@ def __init__(
591588
out_channels=output_channel,
592589
num_splits=num_splits,
593590
dim_split=dim_split,
594-
norm_float16=norm_float16,
595591
print_info=print_info,
596592
save_mem=save_mem,
597593
)
@@ -660,7 +656,6 @@ def __init__(
660656
num_channels=num_channels[-1],
661657
eps=norm_eps,
662658
affine=True,
663-
norm_float16=norm_float16,
664659
print_info=print_info,
665660
save_mem=save_mem,
666661
)
@@ -689,7 +684,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
689684
_empty_cuda_cache(self.save_mem)
690685
return x
691686

692-
687+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
693688
class MaisiDecoder(nn.Module):
694689
"""
695690
Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +705,7 @@ class MaisiDecoder(nn.Module):
710705
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
711706
num_splits: Number of splits for the input tensor.
712707
dim_split: Dimension of splitting for the input tensor.
713-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
708+
norm_float16: Deprecated since 1.5.0 and scheduled for removal in 1.7.0.
714709
print_info: Whether to print information, default to `False`.
715710
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
716711
"""
@@ -809,7 +804,6 @@ def __init__(
809804
out_channels=block_out_ch,
810805
num_splits=num_splits,
811806
dim_split=dim_split,
812-
norm_float16=norm_float16,
813807
print_info=print_info,
814808
save_mem=save_mem,
815809
)
@@ -848,7 +842,6 @@ def __init__(
848842
num_channels=block_in_ch,
849843
eps=norm_eps,
850844
affine=True,
851-
norm_float16=norm_float16,
852845
print_info=print_info,
853846
save_mem=save_mem,
854847
)
@@ -878,6 +871,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878871
return x
879872

880873

874+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
881875
class AutoencoderKlMaisi(AutoencoderKL):
882876
"""
883877
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +895,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
901895
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
902896
num_splits: Number of splits for the input tensor.
903897
dim_split: Dimension of splitting for the input tensor.
904-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
898+
norm_float16: Deprecated since 1.5.0 and scheduled for removal in 1.7.0.
905899
print_info: Whether to print information, default to `False`.
906900
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
907901
"""
@@ -964,7 +958,6 @@ def __init__(
964958
use_flash_attention=use_flash_attention,
965959
num_splits=num_splits,
966960
dim_split=dim_split,
967-
norm_float16=norm_float16,
968961
print_info=print_info,
969962
save_mem=save_mem,
970963
)
@@ -985,7 +978,6 @@ def __init__(
985978
use_convtranspose=use_convtranspose,
986979
num_splits=num_splits,
987980
dim_split=dim_split,
988-
norm_float16=norm_float16,
989981
print_info=print_info,
990982
save_mem=save_mem,
991983
)

Diff for: tests/test_autoencoderkl_maisi.py

+17 -8
Original file line number | Diff line number | Diff line change
@@ -75,28 +75,37 @@
7575
else:
7676
CASES = CASES_NO_ATTENTION
7777

78+
test_dtypes = [torch.float32, torch.float16]
79+
if device.type == "cuda":
80+
test_dtypes.append(torch.bfloat16)
81+
82+
DTYPE_CASES = []
83+
for dtype in test_dtypes:
84+
for case in CASES:
85+
DTYPE_CASES.append(case + [dtype])
86+
7887

7988
class TestAutoencoderKlMaisi(unittest.TestCase):
8089

81-
@parameterized.expand(CASES)
82-
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
83-
net = AutoencoderKlMaisi(**input_param).to(device)
90+
@parameterized.expand(DTYPE_CASES)
91+
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
92+
net = AutoencoderKlMaisi(**input_param).to(device=device,dtype=dtype)
8493
with eval_mode(net):
85-
result = net.forward(torch.randn(input_shape).to(device))
94+
result = net.forward(torch.randn(input_shape).to(device=device,dtype=dtype))
8695
self.assertEqual(result[0].shape, expected_shape)
8796
self.assertEqual(result[1].shape, expected_latent_shape)
8897
self.assertEqual(result[2].shape, expected_latent_shape)
8998

90-
@parameterized.expand(CASES)
99+
@parameterized.expand(DTYPE_CASES)
91100
@SkipIfBeforePyTorchVersion((1, 11))
92101
def test_shape_with_convtranspose_and_checkpointing(
93-
self, input_param, input_shape, expected_shape, expected_latent_shape
102+
self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
94103
):
95104
input_param = input_param.copy()
96105
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
97-
net = AutoencoderKlMaisi(**input_param).to(device)
106+
net = AutoencoderKlMaisi(**input_param).to(device=device,dtype=dtype)
98107
with eval_mode(net):
99-
result = net.forward(torch.randn(input_shape).to(device))
108+
result = net.forward(torch.randn(input_shape).to(device=device,dtype=dtype))
100109
self.assertEqual(result[0].shape, expected_shape)
101110
self.assertEqual(result[1].shape, expected_latent_shape)
102111
self.assertEqual(result[2].shape, expected_latent_shape)

0 commit comments

Comments
 (0)