
Commit 165da7b

rename segment_id for pallas kernel to make code less confusing
1 parent 5e2cb30 commit 165da7b

File tree

1 file changed: +6 -8 lines changed

torch_xla/experimental/custom_kernel.py

+6 -8

@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
           q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
       kv_segment_ids = xs.enable_manual_sharding(
           kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids
 
@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
 
     if not save_residuals:
@@ -327,17 +327,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
-    # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o
 
   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
 
-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +407,7 @@ def backward(ctx, grad_output):
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]
 
     outputs = [q]
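
Why the rename helps, in a minimal sketch: FlashAttention.prepare_segment_ids converts the torch segment-id tensors passed into forward() into the representation the Pallas flash-attention kernel consumes, so rebinding its outputs to the same q_segment_ids / kv_segment_ids names shadowed the original inputs. The snippet below is a hypothetical stand-in (the helper body, dtypes, and shapes are assumptions, not the torch_xla implementation) that only illustrates the naming pattern this commit introduces.

import torch


def prepare_segment_ids_sketch(q_segment_ids, kv_segment_ids):
  # Hypothetical stand-in for FlashAttention.prepare_segment_ids: the real helper
  # builds the kernel's segment-id structure; here we only produce converted
  # copies so the distinct output names are visible.
  segment_ids = (q_segment_ids, kv_segment_ids)
  q_segment_ids_fa = q_segment_ids.to(torch.float32)
  kv_segment_ids_fa = kv_segment_ids.to(torch.float32)
  return segment_ids, q_segment_ids_fa, kv_segment_ids_fa


q_segment_ids = torch.ones(2, 128, dtype=torch.int32)
kv_segment_ids = torch.ones(2, 128, dtype=torch.int32)

# After this commit the converted tensors carry the `_fa` suffix, so the original
# q_segment_ids / kv_segment_ids arguments remain available instead of being
# shadowed by the kernel-format tensors.
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = prepare_segment_ids_sketch(
    q_segment_ids, kv_segment_ids)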
