@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Sequence, Tuple, Type, Union
+from typing import Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -21,10 +21,23 @@
 from monai.networks.blocks import MLPBlock as Mlp
 from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
 from monai.networks.layers import DropPath, trunc_normal_
-from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils import ensure_tuple_rep, look_up_option, optional_import
 
 rearrange, _ = optional_import("einops", name="rearrange")
 
+__all__ = [
+    "SwinUNETR",
+    "window_partition",
+    "window_reverse",
+    "WindowAttention",
+    "SwinTransformerBlock",
+    "PatchMerging",
+    "PatchMergingV2",
+    "MERGING_MODE",
+    "BasicLayer",
+    "SwinTransformer",
+]
+
 
 class SwinUNETR(nn.Module):
     """
@@ -48,6 +61,7 @@ def __init__(
         normalize: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
+        downsample="merging",
     ) -> None:
         """
         Args:
@@ -64,6 +78,9 @@ def __init__(
             normalize: normalize output intermediate features in each stage.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
+            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+                The default is currently `"merging"` (the original version defined in v0.9.0).
 
         Examples::
 
@@ -121,6 +138,7 @@ def __init__(
             norm_layer=nn.LayerNorm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
+            downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
         )
 
         self.encoder1 = UnetrBasicBlock(
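
A minimal usage sketch for the new option (not part of the patch; it assumes the usual `monai.networks.nets` import path and default constructor arguments):

```python
# Minimal usage sketch for the new `downsample` option (not part of the patch).
import torch
from monai.networks.nets import SwinUNETR

# "merging" keeps the v0.9.0 behaviour; "mergingv2" selects the corrected slicing below.
net = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, downsample="mergingv2")
out = net(torch.zeros(1, 1, 96, 96, 96))  # -> (1, 2, 96, 96, 96)
```
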
@@ -657,7 +675,7 @@ def forward(self, x, mask_matrix):
         return x
 
 
-class PatchMerging(nn.Module):
+class PatchMergingV2(nn.Module):
     """
     Patch merging layer based on: "Liu et al.,
     Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -695,8 +713,8 @@ def forward(self, x):
             x2 = x[:, 0::2, 1::2, 0::2, :]
             x3 = x[:, 0::2, 0::2, 1::2, :]
             x4 = x[:, 1::2, 0::2, 1::2, :]
-            x5 = x[:, 0::2, 1::2, 0::2, :]
-            x6 = x[:, 0::2, 0::2, 1::2, :]
+            x5 = x[:, 1::2, 1::2, 0::2, :]
+            x6 = x[:, 0::2, 1::2, 1::2, :]
             x7 = x[:, 1::2, 1::2, 1::2, :]
             x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
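
Worth noting for reviewers: the old `x5`/`x6` duplicated the `x2`/`x3` offsets, so two of the eight 2x2x2 octants were never sampled. A standalone check (not part of the patch) that the corrected offsets visit every voxel of a 2x2x2 block exactly once:

```python
# Sanity check (not part of the patch): the eight strided slices should cover
# all parity combinations of a 2x2x2 neighbourhood exactly once.
import itertools
import torch

x = torch.arange(8.0).reshape(1, 2, 2, 2, 1)  # one 2x2x2 block, one channel
slices = [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product((0, 1), repeat=3)]
merged = torch.cat(slices, -1)
assert sorted(merged.flatten().tolist()) == list(range(8))  # each voxel used once
```

With the old duplicated offsets, the `(1, 1, 0)` and `(0, 1, 1)` octants were dropped and this check fails.
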
@@ -716,6 +734,36 @@ def forward(self, x):
         return x
 
 
+class PatchMerging(PatchMergingV2):
+    """The `PatchMerging` module previously defined in v0.9.0."""
+
+    def forward(self, x):
+        x_shape = x.size()
+        if len(x_shape) == 4:
+            return super().forward(x)
+        if len(x_shape) != 5:
+            raise ValueError(f"expecting 5D x, got {x.shape}.")
+        b, d, h, w, c = x_shape
+        pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
+        x0 = x[:, 0::2, 0::2, 0::2, :]
+        x1 = x[:, 1::2, 0::2, 0::2, :]
+        x2 = x[:, 0::2, 1::2, 0::2, :]
+        x3 = x[:, 0::2, 0::2, 1::2, :]
+        x4 = x[:, 1::2, 0::2, 1::2, :]
+        x5 = x[:, 0::2, 1::2, 0::2, :]
+        x6 = x[:, 0::2, 0::2, 1::2, :]
+        x7 = x[:, 1::2, 1::2, 1::2, :]
+        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
+        x = self.norm(x)
+        x = self.reduction(x)
+        return x
+
+
+MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
+
+
 def compute_mask(dims, window_size, shift_size, device):
     """Computing region masks based on: "Liu et al.,
     Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
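
A shape sketch for the merging modules (not part of the patch; the import path is assumed from this file's location): merging halves each spatial dimension and doubles the channel dimension.

```python
# Shape sketch (not part of the patch).
import torch
from monai.networks.nets.swin_unetr import PatchMergingV2

merge = PatchMergingV2(dim=48, spatial_dims=3)
x = torch.zeros(1, 8, 8, 8, 48)  # channels-last (B, D, H, W, C), as forward() expects
print(merge(x).shape)            # torch.Size([1, 4, 4, 4, 96])
```
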
@@ -776,7 +824,7 @@ def __init__(
         drop: float = 0.0,
         attn_drop: float = 0.0,
         norm_layer: Type[LayerNorm] = nn.LayerNorm,
-        downsample: isinstance = None,  # type: ignore
+        downsample: Optional[nn.Module] = None,
         use_checkpoint: bool = False,
     ) -> None:
         """
@@ -791,7 +839,7 @@ def __init__(
             drop: dropout rate.
             attn_drop: attention dropout rate.
             norm_layer: normalization layer.
-            downsample: downsample layer at the end of the layer.
+            downsample: an optional downsampling layer at the end of the layer.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
         """
 
@@ -820,7 +868,7 @@ def __init__(
             ]
         )
         self.downsample = downsample
-        if self.downsample is not None:
+        if callable(self.downsample):
             self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
 
     def forward(self, x):
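
Note the guard change: `callable(self.downsample)` accepts a class or factory rather than only a ready-made instance, and `callable(None)` is still `False`, so the old behaviour is preserved. A sketch of passing a class straight to `BasicLayer` (not part of the patch; argument values are illustrative):

```python
# Sketch (not part of the patch): since the guard is callable(...), any class or
# factory exposing a (dim, norm_layer, spatial_dims) signature can be plugged in.
layer = BasicLayer(
    dim=48,
    depth=2,
    num_heads=3,
    window_size=(7, 7, 7),
    drop_path=[0.0, 0.0],
    downsample=PatchMergingV2,  # instantiated inside as downsample(dim=..., norm_layer=..., spatial_dims=...)
)
```
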
@@ -881,6 +929,7 @@ def __init__(
         patch_norm: bool = False,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
+        downsample="merging",
     ) -> None:
         """
         Args:
@@ -899,6 +948,9 @@ def __init__(
             patch_norm: add normalization after patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: spatial dimension.
+            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+                The default is currently `"merging"` (the original version defined in v0.9.0).
         """
 
         super().__init__()
@@ -920,6 +972,7 @@ def __init__(
         self.layers2 = nn.ModuleList()
         self.layers3 = nn.ModuleList()
         self.layers4 = nn.ModuleList()
+        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
         for i_layer in range(self.num_layers):
             layer = BasicLayer(
                 dim=int(embed_dim * 2**i_layer),
@@ -932,7 +985,7 @@ def __init__(
                 drop=drop_rate,
                 attn_drop=attn_drop_rate,
                 norm_layer=norm_layer,
-                downsample=PatchMerging,
+                downsample=down_sample_mod,
                 use_checkpoint=use_checkpoint,
             )
             if i_layer == 0:
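
Finally, a sketch of how the string option resolves (not part of the patch): `look_up_option` is the existing `monai.utils` helper, and it raises a `ValueError` listing the supported keys for anything outside `MERGING_MODE`, so a typo fails fast instead of silently falling back.

```python
# Sketch (not part of the patch): resolving the string option to a class.
from monai.networks.nets.swin_unetr import MERGING_MODE, PatchMerging, PatchMergingV2
from monai.utils import look_up_option

assert look_up_option("merging", MERGING_MODE) is PatchMerging      # v0.9.0 default
assert look_up_option("mergingv2", MERGING_MODE) is PatchMergingV2  # corrected slicing
# Unknown keys raise a ValueError listing the supported options.
```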