@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -279,17 +291,6 @@ def fa_custom_forward(
     segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
        q_segment_ids, kv_segment_ids)
 
-  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
-                      k.shape[2])
-  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
-  if k_pad_size > 0:
-    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
-    if ab is None:
-      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
-    ab, _ = _pad_to_block_size(
-        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
-
   # We can't directly use flash_attention as we need to override the save_residuals flag which returns
   # l and m that is needed for the backward. Then we lose all the shape checks.
   # TODO: replicate the shape checks on flash_attention.
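Note: the two hunks relocate the same block-size padding of `k`, `v`, and `ab` from after `FlashAttention.prepare_segment_ids` to before the `if partition_spec is not None:` manual-sharding section. The block relies on the `_pad_to_block_size` helper defined elsewhere in `torch_xla/experimental/custom_kernel.py`; the snippet below is only a minimal sketch of a helper consistent with the call sites in this diff (signature `(tensor, block_size, dim, padding_minus_inf) -> (padded, pad_size)` is inferred, not quoted from the source).

```python
from typing import Tuple

import torch


def _pad_to_block_size(
    tensor: torch.Tensor,
    block_size: int,
    dim: int,
    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
  """Sketch: pad `dim` of `tensor` up to a multiple of `block_size`."""
  size = tensor.shape[dim]
  if size % block_size == 0:
    # Already block-aligned; nothing to pad.
    return tensor, 0

  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  # For the attention bias (ab), pad with a very negative value so the padded
  # key positions are suppressed by the softmax; pad k/v with zeros.
  pad_value = torch.finfo(tensor.dtype).min if padding_minus_inf else 0
  padding = torch.full(
      pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size
```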