@@ -770,6 +770,7 @@ def __init__(
770770 threshold_ema_dead_code = 0 ,
771771 channel_last = True ,
772772 accept_image_fmap = False ,
773+ accept_3d_fmap = False ,
773774 commitment_weight = 1. ,
774775 commitment_use_cross_entropy_loss = False ,
775776 orthogonal_reg_weight = 0. ,
@@ -906,6 +907,7 @@ def __init__(
906907 self .codebook_size = codebook_size
907908
908909 self .accept_image_fmap = accept_image_fmap
910+ self .accept_3d_fmap = accept_3d_fmap
909911 self .channel_last = channel_last
910912
911913 self .register_buffer ('zero' , tensor (0. ), persistent = False )
@@ -953,7 +955,7 @@ def get_codes_from_indices(self, indices):
953955 codes = rearrange (codes , 'b h n d -> b n (h d)' )
954956 codes = unpack_one (codes , 'b * d' )
955957
956- if not self .channel_last :
958+ if not self .channel_last or self . accept_image_fmap or self . accept_3d_fmap :
957959 codes = rearrange (codes , 'b ... d -> b d ...' )
958960
959961 return codes
@@ -1000,7 +1002,11 @@ def update_ema_indices(self, x, indices, mask = None):
10001002 height , width = x .shape [- 2 :]
10011003 x = rearrange (x , 'b c h w -> b (h w) c' )
10021004
1003- if not self .channel_last and not self .accept_image_fmap :
1005+ if self .accept_3d_fmap :
1006+ assert not exists (mask )
1007+ x = rearrange (x , 'b c d h w -> b (d h w) c' )
1008+
1009+ if not self .channel_last and not self .accept_image_fmap and not self .accept_3d_fmap :
10041010 x = rearrange (x , 'b d n -> b n d' )
10051011
10061012 x = self .project_in (x )
@@ -1016,6 +1022,9 @@ def update_ema_indices(self, x, indices, mask = None):
10161022 if self .accept_image_fmap :
10171023 indices = rearrange (indices , 'b h w ... -> b (h w) ...' )
10181024
1025+ if self .accept_3d_fmap :
1026+ indices = rearrange (indices , 'b d h w ... -> b (d h w) ...' )
1027+
10191028 if x .ndim == 2 : # only one token
10201029 indices = rearrange (indices , 'b ... -> b 1 ...' )
10211030
@@ -1059,7 +1068,7 @@ def forward(
10591068
10601069 shape , dtype , device , heads , is_multiheaded , codebook_size , return_loss = x .shape , x .dtype , x .device , self .heads , self .heads > 1 , self .codebook_size , exists (indices )
10611070
1062- need_transpose = not self .channel_last and not self .accept_image_fmap
1071+ need_transpose = not self .channel_last and not self .accept_image_fmap and not self . accept_3d_fmap
10631072 should_inplace_optimize = exists (self .in_place_codebook_optimizer )
10641073
10651074 # rearrange inputs
@@ -1069,6 +1078,11 @@ def forward(
10691078 height , width = x .shape [- 2 :]
10701079 x = rearrange (x , 'b c h w -> b (h w) c' )
10711080
1081+ if self .accept_3d_fmap :
1082+ assert not exists (mask )
1083+ depth , height , width = x .shape [- 3 :]
1084+ x = rearrange (x , 'b c d h w -> b (d h w) c' )
1085+
10721086 if need_transpose :
10731087 x = rearrange (x , 'b d n -> b n d' )
10741088
@@ -1196,6 +1210,9 @@ def calculate_ce_loss(codes):
11961210 if self .accept_image_fmap :
11971211 embed_ind = rearrange (embed_ind , 'b (h w) ... -> b h w ...' , h = height , w = width )
11981212
1213+ if self .accept_3d_fmap :
1214+ embed_ind = rearrange (embed_ind , 'b (d h w) ... -> b d h w ...' , d = depth , h = height , w = width )
1215+
11991216 if only_one :
12001217 embed_ind = rearrange (embed_ind , 'b 1 ... -> b ...' )
12011218
@@ -1289,6 +1306,9 @@ def calculate_ce_loss(codes):
12891306 if self .accept_image_fmap :
12901307 quantize = rearrange (quantize , 'b (h w) c -> b c h w' , h = height , w = width )
12911308
1309+ if self .accept_3d_fmap :
1310+ quantize = rearrange (quantize , "b (d h w) c -> b c d h w" , d = depth , h = height , w = width )
1311+
12921312 if only_one :
12931313 quantize = rearrange (quantize , 'b 1 d -> b d' )
12941314
0 commit comments