@@ -164,10 +164,12 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         # image attention

         effective_kernel_size = (kernel_size - 1) * dilation + 1
-        padding = effective_kernel_size // 2
+        same_padding = effective_kernel_size // 2
+        causal_padding = (same_padding * 2, 0, same_padding * 2, 0)

         k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
-        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
+        k_img, v_img = map(lambda t: F.pad(t, causal_padding), (k_img, v_img))
+        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, dilation = dilation), (k_img, v_img))
         k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))

         # let image attend to all of text
@@ -180,20 +182,19 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         i, j = dots_image.shape[-2:]
         img_seq = torch.arange(img_seq_len, device = device)
         k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
-        k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to
+        k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
         k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
         k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

         # mask image attention

-        q_img_indices = rearrange(img_seq, 'i -> () i ()')
-        causal_mask = q_img_indices < k_img_indices
+        padding_mask = k_img_indices == img_seq_len

         # concat text mask with image causal mask

-        causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
+        padding_mask = repeat(padding_mask, '() i j -> b i j', b = b * h)
         mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
-        mask = torch.cat((~mask, causal_mask), dim = -1)
+        mask = torch.cat((~mask, padding_mask), dim = -1)

         # image can attend to all of text

0 commit comments