Skip to content

Commit cb2157f

Browse files
committed
fix a padding bug in custom_kernel.py
1 parent 0a639e8 commit cb2157f

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torch_xla/experimental/custom_kernel.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,8 @@ def fa_custom_forward(
282282
block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
283283
k.shape[2])
284284
block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
285-
k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
285+
k, k_pad_size = _pad_to_block_size(
286+
k, max(block_k_major, block_k), 2, padding_minus_inf=True)
286287
if k_pad_size > 0:
287288
v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
288289
if ab is not None:
@@ -346,16 +347,23 @@ def fa_custom_forward(
346347
return tuple(outs)
347348

348349

349-
def _pad_to_block_size(tensor: torch.Tensor, block_size: int,
350-
dim: int) -> Tuple[torch.Tensor, int]:
350+
def _pad_to_block_size(
351+
tensor: torch.Tensor,
352+
block_size: int,
353+
dim: int,
354+
padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
351355
size = tensor.shape[dim]
352356
if size % block_size == 0:
353357
return tensor, 0
354358

355359
pad_size = block_size - (size % block_size)
356360
pad_shape = list(tensor.shape)
357361
pad_shape[dim] = pad_size
358-
padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
362+
padding = torch.full(
363+
pad_shape,
364+
torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
365+
dtype=tensor.dtype,
366+
device=tensor.device)
359367
padded = torch.cat([tensor, padding], dim=dim)
360368
return padded, pad_size
361369

0 commit comments

Comments (0)