Skip to content

Commit 356d2d2

Browse files
authored
4757 update patch merging (#4758)
* update patch merging
  Signed-off-by: Wenqi Li <[email protected]>
* fixes unit tests
  Signed-off-by: Wenqi Li <[email protected]>
1 parent 178e973 commit 356d2d2

File tree

4 files changed

+72
-11
lines changed

4 files changed

+72
-11
lines changed

Diff for: monai/networks/nets/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
seresnext50,
8282
seresnext101,
8383
)
84-
from .swin_unetr import SwinUNETR
84+
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
8585
from .torchvision_fc import TorchVisionFCModel
8686
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
8787
from .unet import UNet, Unet

Diff for: monai/networks/nets/swin_unetr.py

+62-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Sequence, Tuple, Type, Union
12+
from typing import Optional, Sequence, Tuple, Type, Union
1313

1414
import numpy as np
1515
import torch
@@ -21,10 +21,23 @@
2121
from monai.networks.blocks import MLPBlock as Mlp
2222
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
2323
from monai.networks.layers import DropPath, trunc_normal_
24-
from monai.utils import ensure_tuple_rep, optional_import
24+
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
2525

2626
rearrange, _ = optional_import("einops", name="rearrange")
2727

28+
__all__ = [
29+
"SwinUNETR",
30+
"window_partition",
31+
"window_reverse",
32+
"WindowAttention",
33+
"SwinTransformerBlock",
34+
"PatchMerging",
35+
"PatchMergingV2",
36+
"MERGING_MODE",
37+
"BasicLayer",
38+
"SwinTransformer",
39+
]
40+
2841

2942
class SwinUNETR(nn.Module):
3043
"""
@@ -48,6 +61,7 @@ def __init__(
4861
normalize: bool = True,
4962
use_checkpoint: bool = False,
5063
spatial_dims: int = 3,
64+
downsample="merging",
5165
) -> None:
5266
"""
5367
Args:
@@ -64,6 +78,9 @@ def __init__(
6478
normalize: normalize output intermediate features in each stage.
6579
use_checkpoint: use gradient checkpointing for reduced memory usage.
6680
spatial_dims: number of spatial dims.
81+
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
82+
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
83+
The default is currently `"merging"` (the original version defined in v0.9.0).
6784
6885
Examples::
6986
@@ -121,6 +138,7 @@ def __init__(
121138
norm_layer=nn.LayerNorm,
122139
use_checkpoint=use_checkpoint,
123140
spatial_dims=spatial_dims,
141+
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
124142
)
125143

126144
self.encoder1 = UnetrBasicBlock(
@@ -657,7 +675,7 @@ def forward(self, x, mask_matrix):
657675
return x
658676

659677

660-
class PatchMerging(nn.Module):
678+
class PatchMergingV2(nn.Module):
661679
"""
662680
Patch merging layer based on: "Liu et al.,
663681
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -695,8 +713,8 @@ def forward(self, x):
695713
x2 = x[:, 0::2, 1::2, 0::2, :]
696714
x3 = x[:, 0::2, 0::2, 1::2, :]
697715
x4 = x[:, 1::2, 0::2, 1::2, :]
698-
x5 = x[:, 0::2, 1::2, 0::2, :]
699-
x6 = x[:, 0::2, 0::2, 1::2, :]
716+
x5 = x[:, 1::2, 1::2, 0::2, :]
717+
x6 = x[:, 0::2, 1::2, 1::2, :]
700718
x7 = x[:, 1::2, 1::2, 1::2, :]
701719
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
702720

@@ -716,6 +734,36 @@ def forward(self, x):
716734
return x
717735

718736

737+
class PatchMerging(PatchMergingV2):
738+
"""The `PatchMerging` module previously defined in v0.9.0."""
739+
740+
def forward(self, x):
741+
x_shape = x.size()
742+
if len(x_shape) == 4:
743+
return super().forward(x)
744+
if len(x_shape) != 5:
745+
raise ValueError(f"expecting 5D x, got {x.shape}.")
746+
b, d, h, w, c = x_shape
747+
pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
748+
if pad_input:
749+
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
750+
x0 = x[:, 0::2, 0::2, 0::2, :]
751+
x1 = x[:, 1::2, 0::2, 0::2, :]
752+
x2 = x[:, 0::2, 1::2, 0::2, :]
753+
x3 = x[:, 0::2, 0::2, 1::2, :]
754+
x4 = x[:, 1::2, 0::2, 1::2, :]
755+
x5 = x[:, 0::2, 1::2, 0::2, :]
756+
x6 = x[:, 0::2, 0::2, 1::2, :]
757+
x7 = x[:, 1::2, 1::2, 1::2, :]
758+
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
759+
x = self.norm(x)
760+
x = self.reduction(x)
761+
return x
762+
763+
764+
MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
765+
766+
719767
def compute_mask(dims, window_size, shift_size, device):
720768
"""Computing region masks based on: "Liu et al.,
721769
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -776,7 +824,7 @@ def __init__(
776824
drop: float = 0.0,
777825
attn_drop: float = 0.0,
778826
norm_layer: Type[LayerNorm] = nn.LayerNorm,
779-
downsample: isinstance = None, # type: ignore
827+
downsample: Optional[nn.Module] = None,
780828
use_checkpoint: bool = False,
781829
) -> None:
782830
"""
@@ -791,7 +839,7 @@ def __init__(
791839
drop: dropout rate.
792840
attn_drop: attention dropout rate.
793841
norm_layer: normalization layer.
794-
downsample: downsample layer at the end of the layer.
842+
downsample: an optional downsampling layer at the end of the layer.
795843
use_checkpoint: use gradient checkpointing for reduced memory usage.
796844
"""
797845

@@ -820,7 +868,7 @@ def __init__(
820868
]
821869
)
822870
self.downsample = downsample
823-
if self.downsample is not None:
871+
if callable(self.downsample):
824872
self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
825873

826874
def forward(self, x):
@@ -881,6 +929,7 @@ def __init__(
881929
patch_norm: bool = False,
882930
use_checkpoint: bool = False,
883931
spatial_dims: int = 3,
932+
downsample="merging",
884933
) -> None:
885934
"""
886935
Args:
@@ -899,6 +948,9 @@ def __init__(
899948
patch_norm: add normalization after patch embedding.
900949
use_checkpoint: use gradient checkpointing for reduced memory usage.
901950
spatial_dims: spatial dimension.
951+
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
952+
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
953+
The default is currently `"merging"` (the original version defined in v0.9.0).
902954
"""
903955

904956
super().__init__()
@@ -920,6 +972,7 @@ def __init__(
920972
self.layers2 = nn.ModuleList()
921973
self.layers3 = nn.ModuleList()
922974
self.layers4 = nn.ModuleList()
975+
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
923976
for i_layer in range(self.num_layers):
924977
layer = BasicLayer(
925978
dim=int(embed_dim * 2**i_layer),
@@ -932,7 +985,7 @@ def __init__(
932985
drop=drop_rate,
933986
attn_drop=attn_drop_rate,
934987
norm_layer=norm_layer,
935-
downsample=PatchMerging,
988+
downsample=down_sample_mod,
936989
use_checkpoint=use_checkpoint,
937990
)
938991
if i_layer == 0:

Diff for: tests/test_swin_unetr.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
from parameterized import parameterized
1717

1818
from monai.networks import eval_mode
19-
from monai.networks.nets.swin_unetr import PatchMerging, SwinUNETR
19+
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
2020
from monai.utils import optional_import
2121

2222
einops, has_einops = optional_import("einops")
2323

2424
TEST_CASE_SWIN_UNETR = []
25+
case_idx = 0
26+
test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2]
2527
for attn_drop_rate in [0.4]:
2628
for in_channels in [1]:
2729
for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]:
@@ -39,10 +41,12 @@
3941
"depths": depth,
4042
"norm_name": norm_name,
4143
"attn_drop_rate": attn_drop_rate,
44+
"downsample": test_merging_mode[case_idx % 4],
4245
},
4346
(2, in_channels, *img_size),
4447
(2, out_channels, *img_size),
4548
]
49+
case_idx += 1
4650
TEST_CASE_SWIN_UNETR.append(test_case)
4751

4852

Diff for: tests/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import operator
1818
import os
1919
import queue
20+
import ssl
2021
import sys
2122
import tempfile
2223
import time
@@ -123,6 +124,9 @@ def skip_if_downloading_fails():
123124
yield
124125
except (ContentTooShortError, HTTPError, ConnectionError) as e:
125126
raise unittest.SkipTest(f"error while downloading: {e}") from e
127+
except ssl.SSLError as ssl_e:
128+
if "decryption failed" in str(ssl_e):
129+
raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e
126130
except RuntimeError as rt_e:
127131
if "unexpected EOF" in str(rt_e):
128132
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download

0 commit comments

Comments
 (0)