@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
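Note on the padding step added above (commentary, not part of the diff): the kv sequence length has to be a multiple of the Pallas block size, so k/v are zero-padded along dim 2 and the bias ab is padded with a very negative value along dim 3, which keeps the padded key positions out of the softmax. A minimal sketch of the size arithmetic, assuming a default block size of 512 and a kv length of 1000 (both hypothetical values for illustration):

    # Hypothetical kv sequence length and block size; not values taken from the diff.
    kv_len = 1000
    block_k_major = min(512, kv_len)   # 512
    block_k = min(512, kv_len)         # 512
    block = max(block_k_major, block_k)
    pad = 0 if kv_len % block == 0 else block - kv_len % block
    print(block, pad, kv_len + pad)    # 512 24 1024 -> k/v padded to 1024 along dim 2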
@@ -295,8 +307,8 @@ def fa_custom_forward(
       sm_scale,
       min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
       min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
-      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
-      min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
+      block_k_major,
+      block_k,
       False,
       static_argnums=range(5, 13),
       use_cache=True,
@@ -337,6 +349,27 @@ def fa_custom_forward(
   return tuple(outs)
 
 
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
+  size = tensor.shape[dim]
+  if size % block_size == 0:
+    return tensor, 0
+
+  pad_size = block_size - (size % block_size)
+  pad_shape = list(tensor.shape)
+  pad_shape[dim] = pad_size
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
+  padded = torch.cat([tensor, padding], dim=dim)
+  return padded, pad_size
+
+
 @custom_op("xla::fa_custom_backward", mutates_args=())
 def fa_custom_backward(
     grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
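A quick standalone sanity check of the new helper (a sketch only, with made-up shapes; it assumes the _pad_to_block_size definition from the hunk above is in scope):

    import torch

    # Hypothetical (batch, heads, kv_len, head_dim) shapes for illustration.
    k = torch.randn(1, 8, 1000, 128)
    k_padded, pad = _pad_to_block_size(k, 512, 2)
    assert pad == 24 and k_padded.shape[2] == 1024          # zero-padded keys

    ab = torch.zeros(1, 8, 1000, 1000)
    ab_padded, _ = _pad_to_block_size(ab, 512, 3, padding_minus_inf=True)
    # Padded bias columns hold the dtype's most negative value, so the padded
    # key positions contribute effectively nothing after the softmax.
    assert (ab_padded[..., 1000:] == torch.finfo(ab.dtype).min).all()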