@@ -274,7 +274,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
274
274
q_segment_ids , segment_id_partition_spec , mesh = mesh ).global_tensor
275
275
kv_segment_ids = xs .enable_manual_sharding (
276
276
kv_segment_ids , segment_id_partition_spec , mesh = mesh ).global_tensor
277
- segment_ids , q_segment_ids , kv_segment_ids = FlashAttention .prepare_segment_ids (
277
+ segment_ids , q_segment_ids_fa , kv_segment_ids_fa = FlashAttention .prepare_segment_ids (
278
278
q_segment_ids , kv_segment_ids )
279
279
ctx .segment_ids = segment_ids
280
280
@@ -305,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
305
305
if ab is not None :
306
306
args += [ab ]
307
307
if segment_ids is not None :
308
- args += [q_segment_ids , kv_segment_ids ]
308
+ args += [q_segment_ids_fa , kv_segment_ids_fa ]
309
309
o = torch_xla ._XLAC ._xla_tpu_custom_call (args , payload , shapes , dtypes )
310
310
311
311
if not save_residuals :
@@ -327,17 +327,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
327
327
m = xs .disable_manual_sharding (
328
328
m , partition_spec [0 :3 ], ctx .full_shape [0 :3 ], mesh = mesh ).global_tensor
329
329
330
- # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
331
- # but it should be OK as the backward will use the same partition_spec
332
- ctx .save_for_backward (full_q , full_k , full_v , o , l , m , q_segment_ids ,
333
- kv_segment_ids , full_ab )
330
+ ctx .save_for_backward (full_q , full_k , full_v , o , l , m , q_segment_ids_fa ,
331
+ kv_segment_ids_fa , full_ab )
334
332
return o
335
333
336
334
@staticmethod
337
335
def backward (ctx , grad_output ):
338
336
from jax .experimental .pallas .ops .tpu .flash_attention import _flash_attention_bwd_dq , _flash_attention_bwd_dkv
339
337
340
- q , k , v , o , l , m , q_segment_ids , kv_segment_ids , ab = ctx .saved_tensors
338
+ q , k , v , o , l , m , q_segment_ids_fa , kv_segment_ids_fa , ab = ctx .saved_tensors
341
339
causal = ctx .causal
342
340
sm_scale = ctx .sm_scale
343
341
partition_spec = ctx .partition_spec
@@ -409,7 +407,7 @@ def backward(ctx, grad_output):
409
407
if ab is not None :
410
408
args += [ab ]
411
409
if segment_ids is not None :
412
- args += [q_segment_ids , kv_segment_ids ]
410
+ args += [q_segment_ids_fa , kv_segment_ids_fa ]
413
411
args += [expanded_l , expanded_m , grad_output , expanded_grad_i ]
414
412
415
413
outputs = [q ]
0 commit comments