@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
         q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
     kv_segment_ids = xs.enable_manual_sharding(
         kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
-    segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
         q_segment_ids, kv_segment_ids)
     ctx.segment_ids = segment_ids

@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

     if not save_residuals:
@@ -329,15 +329,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,

     # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
     # but it should be OK as the backward will use the same partition_spec
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o

   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
@@ -409,7 +409,7 @@ def backward(ctx, grad_output):
     if ab is not None:
       args += [ab]
     if segment_ids is not None:
-      args += [q_segment_ids, kv_segment_ids]
+      args += [q_segment_ids_fa, kv_segment_ids_fa]
     args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

     outputs = [q]
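
Why the rename: FlashAttention.prepare_segment_ids returns kernel-ready segment ids, and the old code rebound them over the incoming q_segment_ids/kv_segment_ids names. The _fa suffix keeps the two forms distinct, so the tensors passed to the TPU custom call and saved for backward can no longer be confused with the originals. Below is a minimal sketch of the rebinding hazard; the stand-in prepare_segment_ids is illustrative only, not the real torch_xla implementation:

def prepare_segment_ids(q_segment_ids, kv_segment_ids):
  # Stand-in for FlashAttention.prepare_segment_ids: returns a combined
  # record plus kernel-ready variants of each input (here, just copies).
  segment_ids = (q_segment_ids, kv_segment_ids)
  return segment_ids, list(q_segment_ids), list(kv_segment_ids)

q_segment_ids, kv_segment_ids = [0, 0, 1], [0, 1, 1]

# Before this commit: the original names are rebound, so any later use of
# q_segment_ids/kv_segment_ids silently sees the kernel-ready variants.
segment_ids, q_segment_ids, kv_segment_ids = prepare_segment_ids(
    q_segment_ids, kv_segment_ids)

# After this commit: distinct names keep both forms addressable.
q_segment_ids, kv_segment_ids = [0, 0, 1], [0, 1, 1]
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = prepare_segment_ids(
    q_segment_ids, kv_segment_ids)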