From 165da7b03a41b4fcb55d4b99d62f04668479580c Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Thu, 28 Nov 2024 00:01:58 +0000
Subject: [PATCH] rename segment_id for pallas kernel to make code less confusing

---
 torch_xla/experimental/custom_kernel.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 6fd171ea6dab..9ffa99f1772c 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
             q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
         kv_segment_ids = xs.enable_manual_sharding(
             kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids
 
@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
 
     if not save_residuals:
@@ -327,17 +327,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
         m = xs.disable_manual_sharding(
             m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
-    # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o
 
   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
 
-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +407,7 @@ def backward(ctx, grad_output):
     if ab is not None:
      args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]
 
     outputs = [q]
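
Note (not part of the patch): the sketch below only illustrates the naming convention the commit adopts. The tensors returned by FlashAttention.prepare_segment_ids get an `_fa` suffix so they are no longer shadowing the user-supplied q_segment_ids / kv_segment_ids arguments. The prepare_segment_ids stand-in here is hypothetical and only mimics the shape of the real helper; it is not the torch_xla implementation.

# Minimal sketch, assuming a simplified prepare_segment_ids that bundles the
# user-facing segment-id tensors into a kernel-ready form.
from typing import NamedTuple, Optional, Tuple

import torch


class SegmentIds(NamedTuple):
  q: torch.Tensor  # query segment ids in the layout the kernel consumes
  kv: torch.Tensor  # key/value segment ids in the layout the kernel consumes


def prepare_segment_ids(
    q_segment_ids: Optional[torch.Tensor],
    kv_segment_ids: Optional[torch.Tensor]
) -> Tuple[Optional[SegmentIds], Optional[torch.Tensor],
           Optional[torch.Tensor]]:
  # Hypothetical stand-in: the real helper converts the inputs for the Pallas
  # kernel; here we just cast and bundle them to show the return structure.
  if q_segment_ids is None and kv_segment_ids is None:
    return None, None, None
  segment_ids = SegmentIds(
      q=q_segment_ids.to(torch.float32), kv=kv_segment_ids.to(torch.float32))
  return segment_ids, segment_ids.q, segment_ids.kv


# The `_fa` suffix marks the kernel-ready outputs, keeping them distinct from
# the original q_segment_ids / kv_segment_ids that the surrounding sharding
# logic in forward/backward still refers to.
q_segment_ids = torch.zeros(1, 128, dtype=torch.int32)
kv_segment_ids = torch.zeros(1, 128, dtype=torch.int32)
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = prepare_segment_ids(
    q_segment_ids, kv_segment_ids)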