move the padding position in custom_kernel.py
zhangp365 committed Feb 8, 2025
1 parent 41ac7ac commit 4f0532d
Showing 1 changed file with 12 additions and 11 deletions.
torch_xla/experimental/custom_kernel.py (12 additions, 11 deletions)
@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -279,17 +291,6 @@ def fa_custom_forward(
     segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
 
-  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
-                      k.shape[2])
-  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
-  if k_pad_size > 0:
-    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
-    if ab is None:
-      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
-    ab, _ = _pad_to_block_size(
-        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
-
   # We can't directly use flash_attention as we need to override the save_residuals flag which returns
   # l and m that is needed for the backward. Then we lose all the shape checks.
   # TODO: replicate the shape checks on flash_attention.
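The hunks above move the `_pad_to_block_size` calls ahead of the `enable_manual_sharding` block. For readers unfamiliar with the helper, here is a minimal sketch of how a padding utility with this call signature could behave, inferred only from the call sites in the diff; the actual `_pad_to_block_size` in custom_kernel.py may differ in details.

import torch


def _pad_to_block_size_sketch(tensor, block_size, dim, padding_minus_inf=False):
  # Hypothetical re-implementation inferred from the call sites above, not the
  # real helper: pad `tensor` along `dim` up to the next multiple of `block_size`.
  size = tensor.shape[dim]
  pad_size = (-size) % block_size  # 0 when size is already a multiple
  if pad_size == 0:
    return tensor, 0
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  # Bias-like tensors (ab) are padded with -inf so the extra keys are masked out
  # by softmax; k and v are padded with zeros.
  fill = float("-inf") if padding_minus_inf else 0.0
  pad = torch.full(pad_shape, fill, dtype=tensor.dtype, device=tensor.device)
  return torch.cat([tensor, pad], dim=dim), pad_size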

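Continuing with the sketch above, a hypothetical usage example showing the shapes involved; the tensor sizes are made up for illustration and do not come from the commit.

# Example only: batch=1, heads=4, kv length=1000, head_dim=128, block size 512.
k = torch.randn(1, 4, 1000, 128)
k_padded, k_pad_size = _pad_to_block_size_sketch(k, 512, 2)
print(k_padded.shape, k_pad_size)  # torch.Size([1, 4, 1024, 128]) 24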