Skip to content

Commit 8888a48

Browse files
committed
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16
Signed-off-by: John Zielke <[email protected]>
1 parent 8dcb9dc commit 8888a48

File tree

2 files changed

+32
-26
lines changed

2 files changed

+32
-26
lines changed

Diff for: monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from monai.networks.blocks import Convolution
2323
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
2424
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
25+
from monai.utils.deprecate_utils import deprecated_arg
2526
from monai.utils.type_conversion import convert_to_tensor
2627

2728
# Set up logging configuration
@@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
3435
return
3536

3637

38+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
3739
class MaisiGroupNorm3D(nn.GroupNorm):
3840
"""
3941
Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
4345
num_channels: Number of channels for the group norm.
4446
eps: Epsilon value for numerical stability.
4547
affine: Whether to use learnable affine parameters, default to `True`.
46-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
48+
norm_float16: Deprecated argument.
4749
print_info: Whether to print information, default to `False`.
4850
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
4951
"""
@@ -59,14 +61,15 @@ def __init__(
5961
save_mem: bool = True,
6062
):
6163
super().__init__(num_groups, num_channels, eps, affine)
62-
self.norm_float16 = norm_float16
6364
self.print_info = print_info
6465
self.save_mem = save_mem
6566

6667
def forward(self, input: torch.Tensor) -> torch.Tensor:
6768
if self.print_info:
6869
logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
6970

71+
target_dtype = input.dtype
72+
7073
if len(input.shape) != 5:
7174
raise ValueError("Expected a 5D tensor")
7275

@@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
7578

7679
inputs = []
7780
for i in range(input.size(1)):
78-
array = input[:, i : i + 1, ...].to(dtype=torch.float32)
81+
array = input[:, i : i + 1, ...]
7982
mean = array.mean([2, 3, 4, 5], keepdim=True)
8083
std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
81-
if self.norm_float16:
82-
inputs.append(((array - mean) / std).to(dtype=torch.float16))
83-
else:
84-
inputs.append((array - mean) / std)
84+
inputs.append(((array - mean) / std).to(dtype=target_dtype))
8585

8686
del input
8787
_empty_cuda_cache(self.save_mem)
@@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
376376
return x
377377

378378

379+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
379380
class MaisiResBlock(nn.Module):
380381
"""
381382
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +418,6 @@ def __init__(
417418
num_channels=in_channels,
418419
eps=norm_eps,
419420
affine=True,
420-
norm_float16=norm_float16,
421421
print_info=print_info,
422422
save_mem=save_mem,
423423
)
@@ -439,7 +439,6 @@ def __init__(
439439
num_channels=out_channels,
440440
eps=norm_eps,
441441
affine=True,
442-
norm_float16=norm_float16,
443442
print_info=print_info,
444443
save_mem=save_mem,
445444
)
@@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
501500
return out_tensor
502501

503502

503+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
504504
class MaisiEncoder(nn.Module):
505505
"""
506506
Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module):
520520
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521521
num_splits: Number of splits for the input tensor.
522522
dim_split: Dimension of splitting for the input tensor.
523-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
523+
norm_float16: Deprecated argument.
524524
print_info: Whether to print information, default to `False`.
525525
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526526
"""
@@ -591,7 +591,6 @@ def __init__(
591591
out_channels=output_channel,
592592
num_splits=num_splits,
593593
dim_split=dim_split,
594-
norm_float16=norm_float16,
595594
print_info=print_info,
596595
save_mem=save_mem,
597596
)
@@ -660,7 +659,6 @@ def __init__(
660659
num_channels=num_channels[-1],
661660
eps=norm_eps,
662661
affine=True,
663-
norm_float16=norm_float16,
664662
print_info=print_info,
665663
save_mem=save_mem,
666664
)
@@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
690688
return x
691689

692690

691+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
693692
class MaisiDecoder(nn.Module):
694693
"""
695694
Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module):
710709
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
711710
num_splits: Number of splits for the input tensor.
712711
dim_split: Dimension of splitting for the input tensor.
713-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
712+
norm_float16: Deprecated argument.
714713
print_info: Whether to print information, default to `False`.
715714
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
716715
"""
@@ -809,7 +808,6 @@ def __init__(
809808
out_channels=block_out_ch,
810809
num_splits=num_splits,
811810
dim_split=dim_split,
812-
norm_float16=norm_float16,
813811
print_info=print_info,
814812
save_mem=save_mem,
815813
)
@@ -848,7 +846,6 @@ def __init__(
848846
num_channels=block_in_ch,
849847
eps=norm_eps,
850848
affine=True,
851-
norm_float16=norm_float16,
852849
print_info=print_info,
853850
save_mem=save_mem,
854851
)
@@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878875
return x
879876

880877

878+
@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
881879
class AutoencoderKlMaisi(AutoencoderKL):
882880
"""
883881
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
901899
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
902900
num_splits: Number of splits for the input tensor.
903901
dim_split: Dimension of splitting for the input tensor.
904-
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
902+
norm_float16: Deprecated argument.
905903
print_info: Whether to print information, default to `False`.
906904
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
907905
"""
@@ -964,7 +962,6 @@ def __init__(
964962
use_flash_attention=use_flash_attention,
965963
num_splits=num_splits,
966964
dim_split=dim_split,
967-
norm_float16=norm_float16,
968965
print_info=print_info,
969966
save_mem=save_mem,
970967
)
@@ -985,7 +982,6 @@ def __init__(
985982
use_convtranspose=use_convtranspose,
986983
num_splits=num_splits,
987984
dim_split=dim_split,
988-
norm_float16=norm_float16,
989985
print_info=print_info,
990986
save_mem=save_mem,
991987
)

Diff for: tests/test_autoencoderkl_maisi.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,38 @@
7575
else:
7676
CASES = CASES_NO_ATTENTION
7777

78+
test_dtypes = [torch.float32]
79+
if device.type == "cuda":
80+
test_dtypes.append(torch.bfloat16)
81+
test_dtypes.append(torch.float16)
82+
83+
DTYPE_CASES = []
84+
for dtype in test_dtypes:
85+
for case in CASES:
86+
DTYPE_CASES.append(case + [dtype])
87+
7888

7989
class TestAutoencoderKlMaisi(unittest.TestCase):
8090

81-
@parameterized.expand(CASES)
82-
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
83-
net = AutoencoderKlMaisi(**input_param).to(device)
91+
@parameterized.expand(DTYPE_CASES)
92+
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
93+
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
8494
with eval_mode(net):
85-
result = net.forward(torch.randn(input_shape).to(device))
95+
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
8696
self.assertEqual(result[0].shape, expected_shape)
8797
self.assertEqual(result[1].shape, expected_latent_shape)
8898
self.assertEqual(result[2].shape, expected_latent_shape)
8999

90-
@parameterized.expand(CASES)
100+
@parameterized.expand(DTYPE_CASES)
91101
@SkipIfBeforePyTorchVersion((1, 11))
92102
def test_shape_with_convtranspose_and_checkpointing(
93-
self, input_param, input_shape, expected_shape, expected_latent_shape
103+
self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
94104
):
95105
input_param = input_param.copy()
96106
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
97-
net = AutoencoderKlMaisi(**input_param).to(device)
107+
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
98108
with eval_mode(net):
99-
result = net.forward(torch.randn(input_shape).to(device))
109+
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
100110
self.assertEqual(result[0].shape, expected_shape)
101111
self.assertEqual(result[1].shape, expected_latent_shape)
102112
self.assertEqual(result[2].shape, expected_latent_shape)

0 commit comments

Comments
 (0)