@@ -89,19 +89,24 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
    def forward(self, x, mask = None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device

-        if n < seq_len:
-            padding = seq_len - n
-            mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
-            x = F.pad(x, (0, 0, 0, padding), value = 0)
-            mask = F.pad(x, (0, padding), value = False)
+        img_seq_len = img_size ** 2
+        text_len = seq_len + 1 - img_seq_len
+
+        # padding
+
+        padding = seq_len - n + 1
+        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
+
+        x = F.pad(x, (0, 0, 0, padding), value = 0)
+        mask = mask[:, :text_len]
+
+        # derive query / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        q *= self.scale

-        img_seq_len = img_size ** 2
-        text_len = seq_len - img_seq_len

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention
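The hunk above swaps the old conditional padding for unconditional padding out to `seq_len + 1`, and derives the text mask length from the image grid rather than from the incoming `n`. Below is a standalone sketch of the new bookkeeping, with illustrative values only: `image_size = 32` matches the constructor default, while `seq_len = 1151` and `n = 600` are made-up inputs.

```python
# Sketch of the padding arithmetic in the new code path (illustrative values only).
image_size = 32                       # constructor default above
seq_len = 1151                        # hypothetical full text + image sequence length
n = 600                               # hypothetical number of tokens actually received

img_seq_len = image_size ** 2         # 1024 image positions
text_len = seq_len + 1 - img_seq_len  # 128 text positions the mask is trimmed to
padding = seq_len - n + 1             # 552 zeros appended so x always spans seq_len + 1

assert (n + padding) == img_seq_len + text_len == seq_len + 1
```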
@@ -110,8 +115,8 @@ def forward(self, x, mask = None):
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        dots_text.masked_fill(mask, mask_value)
+        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = dots_text.softmax(dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
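This hunk renames the causal mask so it no longer clobbers the padding `mask` that the image block still needs. Note that the fill is in-place (`masked_fill_`); the out-of-place `masked_fill` returns a new tensor and would leave `dots_text` unmasked. A minimal sketch of the `triu_` trick, with toy sizes:

```python
import torch

# Toy demonstration of the strict upper-triangular causal mask (sizes are illustrative).
i, j = 3, 5  # j > i when keys include positions that precede the queries
text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
print(text_causal_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])
# True marks strictly-future keys: query q (0-indexed) may attend to keys 0..q + (j - i).

dots = torch.zeros(i, j)
dots.masked_fill_(text_causal_mask, float('-inf'))  # in-place, as above
```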
@@ -146,12 +151,17 @@ def forward(self, x, mask = None):
        # mask image attention

        q_img_indices = rearrange(img_seq, 'i -> () i ()')
-        mask = q_img_indices >= k_img_indices
+        causal_mask = q_img_indices < k_img_indices
+
+        # concat text mask with image causal mask
+
+        causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
+        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
+        mask = torch.cat((~mask, causal_mask), dim = -1)

        # image can attend to all of text

-        mask = F.pad(mask, (text_len, 0), value = True)
-        dots_image.masked_fill_(~mask, mask_value)
+        dots_image.masked_fill_(mask, mask_value)

        attn_image = dots_image.softmax(dim = -1)
        out_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
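Instead of `F.pad`-ing the old keep-mask with `True` columns for text, the new code builds one combined boolean mask in "True = masked out" convention: the inverted text padding mask for the text keys, concatenated with the image causal mask for the local image keys, so a single `masked_fill_` covers both. A self-contained sketch with toy shapes (the names and sizes are illustrative, not from the diff):

```python
import torch
from einops import repeat

# Toy shapes: b batches, h heads, i image queries, text_len text keys, j_img local image keys.
b, h, i, text_len, j_img = 2, 3, 5, 4, 6

text_mask = torch.ones(b, text_len).bool()       # True = real (non-padding) text token
causal_mask = torch.rand(b * h, i, j_img) > 0.5  # stand-in for q_img_indices < k_img_indices

text_mask = repeat(text_mask, 'b j -> (b h) i j', i = i, h = h)  # broadcast over heads and queries
combined = torch.cat((~text_mask, causal_mask), dim = -1)        # True = masked out

dots = torch.zeros(b * h, i, text_len + j_img)
dots.masked_fill_(combined, float('-inf'))
print(combined.shape)  # torch.Size([6, 5, 10])
```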
@@ -188,20 +198,24 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
    def forward(self, x, mask = None):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device

-        if n < seq_len:
-            padding = seq_len - n
-            mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
-            x = F.pad(x, (0, 0, 0, padding), value = 0)
-            mask = F.pad(x, (0, padding), value = False)
+        img_seq_len = img_size ** 2
+        text_len = seq_len + 1 - img_seq_len
+
+        # padding
+
+        padding = seq_len - n + 1
+        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
+
+        x = F.pad(x, (0, 0, 0, padding), value = 0)
+        mask = mask[:, :text_len]
+
+        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        q *= self.scale

-        img_seq_len = img_size ** 2
-        text_len = seq_len - img_seq_len
-
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention
@@ -210,8 +224,8 @@ def forward(self, x, mask = None):
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        dots_text.masked_fill(mask, mask_value)
+        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = dots_text.softmax(dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
@@ -237,13 +251,21 @@ def forward(self, x, mask = None):
        # mask so image has full attention to text, but causal along axis

-        i, j = dots_image.shape[-2:]
-        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        bh, i, j = dots_image.shape
+        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
+        causal_mask = repeat(causal_mask, 'i j -> b i j', b = bh)
+
+        mask = repeat(mask, 'b j -> (b r) i j', r = (bh // b), i = i)
+        mask = torch.cat((~mask, causal_mask), dim = -1)
+
        dots_image.masked_fill_(mask, mask_value)

        # attention

        attn_image = dots_image.softmax(dim = -1)
+
+        # aggregate
+
        out_image = einsum('b i j, b j d -> b i d', attn_image, v_img)

        # merge back axis
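The axial variant builds the same combined mask, but the causal part is an `i × img_size` rectangle: queries run along one image axis while the other axis (together with the heads) appears to have been folded into the batch, which is why the text mask is repeated by `r = bh // b`. A sketch of the shape arithmetic under that assumed layout:

```python
import torch
from einops import repeat

# Assumed layout: dots_image batch dim = b * heads * img_size after folding one axis.
b, heads, img_size, text_len = 2, 3, 4, 5
bh = b * heads * img_size
i = img_size  # queries along the attended axis

causal_mask = torch.ones(i, img_size).triu_(img_size - i + 1).bool()  # strict upper triangle when i == img_size
causal_mask = repeat(causal_mask, 'i j -> b i j', b = bh)

text_mask = torch.ones(b, text_len).bool()
text_mask = repeat(text_mask, 'b j -> (b r) i j', r = bh // b, i = i)

mask = torch.cat((~text_mask, causal_mask), dim = -1)
print(mask.shape)  # torch.Size([24, 4, 9])
```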