fix a padding bug in custom_kernel.py
zhangp365 committed Feb 8, 2025
1 parent 0a639e8 commit cb2157f
Showing 1 changed file with 12 additions and 4 deletions.
torch_xla/experimental/custom_kernel.py: 12 additions & 4 deletions

@@ -282,7 +282,8 @@ def fa_custom_forward(
   block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                       k.shape[2])
   block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  k, k_pad_size = _pad_to_block_size(
+      k, max(block_k_major, block_k), 2, padding_minus_inf=True)
   if k_pad_size > 0:
     v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
     if ab is not None:
@@ -346,16 +347,23 @@ def fa_custom_forward(
   return tuple(outs)


-def _pad_to_block_size(tensor: torch.Tensor, block_size: int,
-                       dim: int) -> Tuple[torch.Tensor, int]:
+def _pad_to_block_size(
+    tensor: torch.Tensor,
+    block_size: int,
+    dim: int,
+    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
   size = tensor.shape[dim]
   if size % block_size == 0:
     return tensor, 0

   pad_size = block_size - (size % block_size)
   pad_shape = list(tensor.shape)
   pad_shape[dim] = pad_size
-  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
+  padding = torch.full(
+      pad_shape,
+      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
+      dtype=tensor.dtype,
+      device=tensor.device)
   padded = torch.cat([tensor, padding], dim=dim)
   return padded, pad_size
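
For context, here is a minimal runnable sketch of the behavior this commit fixes; it is not part of the commit. It reproduces the patched _pad_to_block_size from the diff above; the toy scores tensor is illustrative, and applying softmax to raw scores is a simplification of the Pallas kernel's actual data path. The point: zero-padding the key dimension lets padded positions collect real attention weight after softmax, while padding with the dtype minimum drives that weight to zero.

import torch
from typing import Tuple


def _pad_to_block_size(
    tensor: torch.Tensor,
    block_size: int,
    dim: int,
    padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
  # Same logic as the patched helper above: pad `dim` up to a multiple
  # of block_size, filling with the dtype minimum when requested.
  size = tensor.shape[dim]
  if size % block_size == 0:
    return tensor, 0
  pad_size = block_size - (size % block_size)
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.full(
      pad_shape,
      torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
      dtype=tensor.dtype,
      device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size


# Illustrative attention scores for 3 real key positions, padded to a block of 4.
scores = torch.tensor([[2.0, 1.0, 0.5]])
zero_padded, _ = _pad_to_block_size(scores, 4, 1)
min_padded, _ = _pad_to_block_size(scores, 4, 1, padding_minus_inf=True)

print(torch.softmax(zero_padded, dim=1))  # padded slot keeps ~0.08 of the weight
print(torch.softmax(min_padded, dim=1))   # padded slot's weight underflows to 0.0

This reading presumably also explains why v is still zero-padded in the forward hunk: once the padded keys carry ~zero attention weight, zero-valued entries of v add nothing to the weighted sum. Using torch.finfo(dtype).min rather than float('-inf') keeps the padding finite, which likely avoids inf - inf NaNs in intermediate arithmetic.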
