
Commit 4f0532d

move the padding position in custom_kernel.py
1 parent 41ac7ac commit 4f0532d

File tree

1 file changed: +12 -11 lines changed


torch_xla/experimental/custom_kernel.py

+12-11
@@ -248,6 +248,18 @@ def fa_custom_forward(
     full_ab = ab.clone()
   else:
     full_ab = None
+
+  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
+                      k.shape[2])
+  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
+  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
+  if k_pad_size > 0:
+    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
+    if ab is None:
+      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
+    ab, _ = _pad_to_block_size(
+        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
+
   if partition_spec is not None:
     q_full_shape = q.shape
     q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -279,17 +291,6 @@ def fa_custom_forward(
   segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
       q_segment_ids, kv_segment_ids)
 
-  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
-                      k.shape[2])
-  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
-  k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
-  if k_pad_size > 0:
-    v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
-    if ab is None:
-      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], q.shape[2]))
-    ab, _ = _pad_to_block_size(
-        ab, max(block_k_major, block_k), 3, padding_minus_inf=True)
-
   # We can't directly use flash_attention as we need to override the save_residuals flag which returns
   # l and m that is needed for the backward. Then we lose all the shape checks.
   # TODO: replicate the shape checks on flash_attention.
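The net effect of the commit is that k, v, and ab are padded to the block size before the `if partition_spec is not None:` block enables manual sharding, instead of after `prepare_segment_ids` as before. The padding relies on `_pad_to_block_size`, a helper in custom_kernel.py whose body is not part of this diff. As orientation only, here is a minimal sketch of the usual pad-to-multiple pattern, assuming the signature implied by the call sites above (tensor, block size, dim, optional `padding_minus_inf`); it is an illustration, not the repository's implementation.

# Hypothetical sketch, not the code in custom_kernel.py. Assumes a
# floating-point tensor and the signature seen at the call sites above.
import torch


def _pad_to_block_size(tensor: torch.Tensor,
                       block_size: int,
                       dim: int,
                       padding_minus_inf: bool = False):
  """Pads `tensor` along `dim` so its size is a multiple of `block_size`.

  Returns the (possibly padded) tensor and the number of padded positions.
  With `padding_minus_inf=True` the pad is filled with the dtype's most
  negative value, which is what an attention-bias tensor like `ab` needs so
  the padded key positions are effectively masked out by the softmax.
  """
  size = tensor.shape[dim]
  pad_size = (-size) % block_size  # 0 when size is already a multiple.
  if pad_size == 0:
    return tensor, 0
  pad_value = torch.finfo(tensor.dtype).min if padding_minus_inf else 0
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  pad = torch.full(pad_shape, pad_value, dtype=tensor.dtype,
                   device=tensor.device)
  return torch.cat([tensor, pad], dim=dim), pad_size

Under this reading, `k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)` pads the key-length dimension, and the returned `k_pad_size` tells the caller whether v and ab must be padded to match.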
