@@ -282,7 +282,8 @@ def fa_custom_forward(
   block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                       k.shape[2])
   block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  k, k_pad_size = _pad_to_block_size(
+      k, max(block_k_major, block_k), 2, padding_minus_inf=True)
   if k_pad_size > 0:
     v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
     if ab is not None:
@@ -346,16 +347,23 @@ def fa_custom_forward(
   return tuple(outs)
 
 
-def _pad_to_block_size(tensor: torch.Tensor, block_size: int,
-                       dim: int) -> Tuple[torch.Tensor, int]:
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
   size = tensor.shape[dim]
   if size % block_size == 0:
     return tensor, 0
 
   pad_size = block_size - (size % block_size)
   pad_shape = list(tensor.shape)
   pad_shape[dim] = pad_size
-  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
   padded = torch.cat([tensor, padding], dim=dim)
   return padded, pad_size
 
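For context on what the patch changes: padding K with the dtype's minimum value (rather than zeros) drives the attention logits for padded key positions toward -inf, so softmax assigns them zero weight; V can remain zero-padded because those zero weights already nullify its padded rows. Below is a minimal standalone sketch of that effect, not part of the patch; the shapes and variable names are illustrative only.

import torch

# Illustrative shapes only: 4 queries, 5 real keys padded up to a block of 8.
q = torch.rand(4, 8)  # positive entries keep the q @ min_pad logits cleanly at -inf
k = torch.rand(5, 8)

zero_pad = torch.zeros(3, 8)
min_pad = torch.full((3, 8), torch.finfo(k.dtype).min)

# Zero-padded keys produce logits of 0, which softmax still weights.
w_zero = (q @ torch.cat([k, zero_pad]).T).softmax(dim=-1)
# Min-padded keys produce hugely negative logits that softmax masks out entirely.
w_min = (q @ torch.cat([k, min_pad]).T).softmax(dim=-1)

print(w_zero[:, 5:].sum())  # > 0: zero-padded keys wrongly attract attention
print(w_min[:, 5:].sum())   # 0.0: min-padded keys contribute nothing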