@@ -601,6 +601,7 @@ def __init__(
         *,
         image_embed_dim = 1024,
         text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
+        num_resnet_blocks = 1,
         cond_dim = None,
         num_image_tokens = 4,
         num_time_tokens = 2,
@@ -706,6 +707,7 @@ def __init__(

         # resnet block klass

+        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
         resnet_groups = cast_tuple(resnet_groups, len(in_out))

         assert len(resnet_groups) == len(in_out)
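
The new `num_resnet_blocks` argument accepts either a single int or a per-stage tuple; `cast_tuple` broadcasts it to one entry per resolution stage (`len(in_out)`), mirroring how `resnet_groups` is handled on the next line. A minimal sketch of the broadcasting behaviour assumed here (the real helper lives elsewhere in the repo and is not part of this diff):

```python
# Minimal sketch of the assumed cast_tuple behaviour, for illustration only.
def cast_tuple(val, length = None):
    # a tuple passes through unchanged; a scalar is repeated `length` times
    if isinstance(val, tuple):
        return val
    return (val,) * (length if length is not None else 1)

# with four resolution stages (len(in_out) == 4):
print(cast_tuple(1, 4))             # (1, 1, 1, 1)   - same count at every stage
print(cast_tuple((2, 4, 8, 8), 4))  # (2, 4, 8, 8)   - a per-stage schedule
```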
@@ -722,15 +724,15 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups) in enumerate(zip(in_out, num_resnet_blocks, resnet_groups)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None

             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))

@@ -740,14 +742,14 @@ def __init__(
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups) in enumerate(zip(reversed(in_out[1:]), reversed(num_resnet_blocks), reversed(resnet_groups))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None

             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 Upsample(dim_in)
             ]))

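
With each entry of `self.downs` and `self.ups` now carrying an `nn.ModuleList` of ResnetBlocks instead of a single second block, the block count can differ per resolution. A hypothetical construction, assuming the `Unet` constructor's usual `dim`/`dim_mults` arguments (not shown in this diff); the values are illustrative only:

```python
# Hypothetical usage sketch; dim / dim_mults are assumed constructor arguments.
# Four resolution stages -> four per-stage block counts.
unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8)  # or a single int to use the same count everywhere
)
```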
@@ -891,10 +893,13 @@ def forward(

        hiddens = []

-        for block1, sparse_attn, block2, downsample in self.downs:
-            x = block1(x, c, t)
+        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             hiddens.append(x)
             x = downsample(x)

@@ -905,11 +910,14 @@ def forward(

        x = self.mid_block2(x, mid_c, t)

-        for block1, sparse_attn, block2, upsample in self.ups:
+        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim = 1)
-            x = block1(x, c, t)
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             x = upsample(x)

        return self.final_conv(x)
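
In the forward pass each stage now runs its initial block, the optional sparse attention, and then every block in the per-stage `nn.ModuleList`; the skip connection is still recorded once per stage (after the last block), so the up path pops exactly one hidden per stage as before. A standalone sketch of the nested-ModuleList pattern with stand-in modules, illustrating why `nn.ModuleList` (rather than a plain Python list) is used so the extra blocks stay registered as parameters:

```python
# Standalone sketch with stand-in modules, not the real ResnetBlock / attention.
import torch
from torch import nn

stage = nn.ModuleList([
    nn.Conv2d(8, 8, 3, padding = 1),                                     # stand-in for the init block
    nn.Identity(),                                                       # stand-in for sparse attention
    nn.ModuleList([nn.Conv2d(8, 8, 3, padding = 1) for _ in range(3)]),  # per-stage resnet blocks
    nn.Identity()                                                        # stand-in for the downsample
])

x = torch.randn(1, 8, 16, 16)
init_block, sparse_attn, resnet_blocks, downsample = stage
x = init_block(x)
x = sparse_attn(x)
for resnet_block in resnet_blocks:   # every parameter here is visible to .parameters()
    x = resnet_block(x)
x = downsample(x)
```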
@@ -962,7 +970,7 @@ def __init__(
        # get text encoder

        self.text_encoder_name = text_encoder_name
-        text_embed_dim = get_encoded_dim(text_encoder_name)
+        self.text_embed_dim = get_encoded_dim(text_encoder_name)

        # construct unets

@@ -977,7 +985,7 @@ def __init__(
            one_unet = one_unet.cast_model_parameters(
                lowres_cond = not is_first,
                cond_on_text = self.condition_on_text,
-                text_embed_dim = text_embed_dim if self.condition_on_text else None,
+                text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
                channels = self.channels,
                channels_out = unet_channels_out
            )
@@ -1211,6 +1219,7 @@ def sample(

        assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
+        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        img = None
        is_cuda = next(self.parameters()).is_cuda
@@ -1282,6 +1291,8 @@ def forward(
        assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'

+        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
+
        lowres_cond_img = lowres_aug_times = None
        if exists(prev_image_size):
            lowres_cond_img = resize_image_to(image, prev_image_size)
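
With the encoder dimension stored as `self.text_embed_dim`, the assertions added to `sample` and `forward` catch text embeddings whose last dimension does not match what `get_encoded_dim(text_encoder_name)` reports, failing fast with a readable message instead of a shape error deep inside the unet. A minimal standalone reproduction of the guard; the `exists` helper shown here and the expected dimension of 1024 are assumptions for illustration:

```python
# Minimal standalone sketch of the new dimension guard; values are illustrative.
import torch

def exists(val):
    return val is not None

expected_text_embed_dim = 1024          # stand-in for get_encoded_dim(text_encoder_name)
text_embeds = torch.randn(2, 256, 768)  # deliberately the wrong last dimension

assert not (exists(text_embeds) and text_embeds.shape[-1] != expected_text_embed_dim), \
    f'invalid text embedding dimension being passed in (should be {expected_text_embed_dim})'
# raises AssertionError here, i.e. the mismatch is reported up front at call time
```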