rename segment_id for pallas kernel to make code less confusing
JackCaoG committed Nov 28, 2024
1 parent 5e2cb30 commit 165da7b
Showing 1 changed file with 6 additions and 8 deletions.
torch_xla/experimental/custom_kernel.py: 6 additions & 8 deletions
@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
           q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
       kv_segment_ids = xs.enable_manual_sharding(
           kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids

@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

     if not save_residuals:
@@ -327,17 +327,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

-    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
-    # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o

   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +407,7 @@ def backward(ctx, grad_output):
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

     outputs = [q]
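
For readers skimming the diff, the sketch below illustrates the naming pattern this commit moves to. It is a minimal, self-contained example, not the torch_xla implementation: the prepare_segment_ids stand-in is hypothetical, and the real helper builds the segment-id structure the Pallas flash-attention kernel consumes. The point is only that reusing the argument names for the converted tensors shadows the caller-supplied ids, while the *_fa names keep the two sets of tensors distinct.

# Minimal sketch, assuming a hypothetical prepare_segment_ids stand-in.
def prepare_segment_ids(q_segment_ids, kv_segment_ids):
  # Stand-in conversion: in the real wrapper this produces the kernel-side
  # segment-id inputs; here we just bundle the two tensors.
  segment_ids = (q_segment_ids, kv_segment_ids)
  return segment_ids, q_segment_ids, kv_segment_ids


def forward_old_style(q_segment_ids, kv_segment_ids):
  # Before the rename: the converted tensors reuse (and shadow) the argument
  # names, so it is easy to confuse them with the caller-supplied ids.
  segment_ids, q_segment_ids, kv_segment_ids = prepare_segment_ids(
      q_segment_ids, kv_segment_ids)
  return segment_ids, q_segment_ids, kv_segment_ids


def forward_new_style(q_segment_ids, kv_segment_ids):
  # After the rename: the *_fa names mark the tensors prepared for the
  # flash-attention kernel, and the original arguments keep their names.
  segment_ids, q_segment_ids_fa, kv_segment_ids_fa = prepare_segment_ids(
      q_segment_ids, kv_segment_ids)
  return segment_ids, q_segment_ids_fa, kv_segment_ids_fa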
