From cb2157f05f39d246d3d9845dc8afac6aae77750d Mon Sep 17 00:00:00 2001
From: zhangsongbo
Date: Sat, 8 Feb 2025 15:25:01 +0800
Subject: [PATCH] fix a padding bug in custom_kernel.py

---
 torch_xla/experimental/custom_kernel.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index d36f0c1a9d4..c6ec8d8773f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -282,7 +282,8 @@ def fa_custom_forward(
   block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                       k.shape[2])
   block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  k, k_pad_size = _pad_to_block_size(
+      k, max(block_k_major, block_k), 2, padding_minus_inf=True)
   if k_pad_size > 0:
     v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
     if ab is not None:
@@ -346,8 +347,11 @@ def fa_custom_forward(
   return tuple(outs)
 
 
-def _pad_to_block_size(tensor: torch.Tensor, block_size: int,
-                       dim: int) -> Tuple[torch.Tensor, int]:
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
   size = tensor.shape[dim]
   if size % block_size == 0:
     return tensor, 0
@@ -355,7 +359,11 @@ def _pad_to_block_size(tensor: torch.Tensor, block_size: int,
   pad_size = block_size - (size % block_size)
   pad_shape = list(tensor.shape)
   pad_shape[dim] = pad_size
-  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
   padded = torch.cat([tensor, padding], dim=dim)
   return padded, pad_size
 
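
Note for reviewers: the sketch below shows the padding helper as it reads after this patch, together with a small usage check. It is a standalone illustration, not the torch_xla module itself; the example tensor shapes, the assumed block size of 512, and the __main__ harness are assumptions added here. With padding_minus_inf=True the pad region is filled with the dtype's minimum value rather than zeros, so padded key positions are pushed toward strongly negative attention logits instead of logits of zero.

from typing import Tuple

import torch


def _pad_to_block_size(
    tensor: torch.Tensor,
    block_size: int,
    dim: int,
    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
  """Pads `dim` of `tensor` up to the next multiple of `block_size`.

  When padding_minus_inf=True, the pad is filled with the minimum value of
  the dtype instead of zeros (the behavior this patch adds for the keys).
  """
  size = tensor.shape[dim]
  if size % block_size == 0:
    return tensor, 0

  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.full(
      pad_shape,
      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
      dtype=tensor.dtype,
      device=tensor.device)
  padded = torch.cat([tensor, padding], dim=dim)
  return padded, pad_size


if __name__ == "__main__":
  # Hypothetical key tensor: [batch, num_heads, kv_seq_len, head_dim].
  k = torch.randn(1, 8, 1000, 128)
  # Pad the kv-sequence dim (2) up to a multiple of an assumed block size 512.
  k_padded, k_pad_size = _pad_to_block_size(k, 512, 2, padding_minus_inf=True)
  assert k_padded.shape[2] == 1024 and k_pad_size == 24
  # The padded tail holds the dtype minimum, not zeros.
  assert bool(
      (k_padded[:, :, -k_pad_size:, :] == torch.finfo(k.dtype).min).all())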