
Commit c4db598

rename segment_id for pallas kernel to make code less confusing

1 parent 5e2cb30
File tree

1 file changed: +6 −6 lines


torch_xla/experimental/custom_kernel.py (+6 −6)
@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
           q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
       kv_segment_ids = xs.enable_manual_sharding(
           kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids

@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

     if not save_residuals:
@@ -329,15 +329,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,

     # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
     # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o

   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +409,7 @@ def backward(ctx, grad_output):
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

     outputs = [q]
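
For context (not part of the commit): FlashAttention.prepare_segment_ids takes the user-supplied q_segment_ids / kv_segment_ids and returns tensors in the form the Pallas flash-attention kernel consumes. Before this change the returned tensors reused the names of the incoming arguments, so later references (the custom-call args and ctx.save_for_backward) could mean either the raw inputs or the converted tensors. The sketch below illustrates only that naming pattern; prepare_segment_ids_sketch and the shapes are hypothetical stand-ins, not the torch_xla implementation.

import torch

def prepare_segment_ids_sketch(q_segment_ids, kv_segment_ids):
  # Hypothetical stand-in for FlashAttention.prepare_segment_ids: the real
  # helper converts the ids into the layout the Pallas kernel expects. The
  # point here is only that its outputs are new tensors, distinct from the
  # raw inputs that were passed in.
  q_fa = q_segment_ids.to(torch.float32)
  kv_fa = kv_segment_ids.to(torch.float32)
  segment_ids = (q_fa, kv_fa)
  return segment_ids, q_fa, kv_fa

q_segment_ids = torch.ones(2, 128, dtype=torch.int32)
kv_segment_ids = torch.ones(2, 128, dtype=torch.int32)

# Before this commit the converted tensors shadowed the argument names:
#   segment_ids, q_segment_ids, kv_segment_ids = prepare_segment_ids_sketch(...)
# After the rename, the raw inputs and the kernel-ready tensors stay separately
# addressable, so the later uses (custom-call args, ctx.save_for_backward) read
# unambiguously.
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = prepare_segment_ids_sketch(
    q_segment_ids, kv_segment_ids)

args = [q_segment_ids_fa, kv_segment_ids_fa] if segment_ids is not None else []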
