@@ -782,17 +782,13 @@ def __init__(
        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
        n, device = x.shape[1], x.device

        attn_bias = self.rel_pos_bias(n, n + 1, device = device)

        for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        out = self.norm(x)
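(One note on this hunk: the `mask` keyword and its TODO comment disappear because the commit adopts the DALL-E 1 strategy the comment itself floated. Once padding slots are filled with learned token embeddings, as the later hunks do, every key position carries meaning, so the transformer can simply attend to all positions and no boolean mask is needed.)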
@@ -806,7 +802,7 @@ def __init__(
        num_time_embeds = 1,
        num_image_embeds = 1,
        num_text_embeds = 1,
-        attend_all_text_encodings = True,
+        max_text_len = 256,
        **kwargs
    ):
        super().__init__()
@@ -832,7 +828,10 @@ def __init__(
        self.learned_query = nn.Parameter(torch.randn(dim))
        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

-        self.attend_all_text_encodings = attend_all_text_encodings
+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

    def forward_with_cond_scale(
        self,
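For intuition, the two new attributes are just a table of position-specific learned padding vectors. A minimal sketch of this piece alone, using illustrative sizes (`max_text_len = 256` matches the new default; `dim = 512` is an assumption):

```python
import torch
from torch import nn

dim, max_text_len = 512, 256

# one learned embedding per text position; the leading 1 broadcasts across the batch
null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

assert null_text_embed.shape == (1, max_text_len, dim)
```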
@@ -872,11 +871,28 @@ def forward(

        if not exists(text_encodings):
            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
+
+        mask = torch.any(text_encodings != 0., dim = -1)

-        if self.attend_all_text_encodings:
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
-        else:
-            mask = torch.any(text_encodings != 0., dim = -1)
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1'),
+            text_encodings,
+            null_text_embeds
+        )

        # classifier free guidance

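Taken together, the forward-pass changes truncate or pad the text encodings to `max_text_len`, mark all-zero positions as padding, and overwrite those positions with the learned null embeddings rather than masking them out. A self-contained sketch of the same flow on toy tensors (the shapes and the all-zeros convention for incoming padding are assumptions for illustration):

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

batch, max_text_len, dim = 2, 6, 4
null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

# toy encodings: 3 real tokens per sample; all-zero rows would count as padding
text_encodings = torch.randn(batch, 3, dim)
mask = torch.any(text_encodings != 0., dim = -1)  # True where a real token lives

# truncate to max_text_len, then right-pad encodings with zeros and the mask with False
text_encodings = text_encodings[:, :max_text_len]
mask = mask[:, :max_text_len]
remainder = max_text_len - text_encodings.shape[-2]

if remainder > 0:
    text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
    mask = F.pad(mask, (0, remainder), value = False)

# every padding slot gets its own position-specific learned embedding
text_encodings = torch.where(
    rearrange(mask, 'b n -> b n 1'),
    text_encodings,
    null_text_embed.to(text_encodings.dtype)
)

assert text_encodings.shape == (batch, max_text_len, dim)
```

The upshot is that downstream attention no longer needs a boolean mask, which is exactly why the `mask` keyword also disappears from the `causal_transformer` call in the final hunk below.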
@@ -910,7 +926,7 @@ def forward(

        # attend

-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)

        # get learned query, which should predict the image embedding (per DDPM timestep)