@@ -246,6 +246,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
246
246
full_k = k
247
247
full_v = v
248
248
full_ab = ab
249
+ _ , full_q_segment_ids , full_kv_segment_ids = FlashAttention .prepare_segment_ids (
250
+ q_segment_ids , kv_segment_ids )
251
+
249
252
if partition_spec is not None :
250
253
ctx .full_shape = q .shape
251
254
q = xs .enable_manual_sharding (q , partition_spec , mesh = mesh ).global_tensor
@@ -254,6 +257,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
254
257
if ab :
255
258
ab = xs .enable_manual_sharding (
256
259
ab , partition_spec , mesh = mesh ).global_tensor
260
+ if q_segment_ids is not None :
261
+ q_segment_ids = xs .enable_manual_sharding (
262
+ q_segment_ids , partition_spec [:q_segment_ids .ndim ],
263
+ mesh = mesh ).global_tensor
264
+ if kv_segment_ids is not None :
265
+ kv_segment_ids = xs .enable_manual_sharding (
266
+ kv_segment_ids , partition_spec [:kv_segment_ids .ndim ],
267
+ mesh = mesh ).global_tensor
257
268
258
269
# It computes the shape and type of o, l, m.
259
270
shapes = [q .shape ]
@@ -319,8 +330,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
319
330
m = xs .disable_manual_sharding (
320
331
m , partition_spec [0 :3 ], ctx .full_shape [0 :3 ], mesh = mesh ).global_tensor
321
332
322
- ctx .save_for_backward (full_q , full_k , full_v , o , l , m , q_segment_ids ,
323
- kv_segment_ids , full_ab )
333
+ ctx .save_for_backward (full_q , full_k , full_v , o , l , m , full_q_segment_ids ,
334
+ full_kv_segment_ids , full_ab )
324
335
return o
325
336
326
337
@staticmethod
@@ -363,6 +374,14 @@ def backward(ctx, grad_output):
363
374
if ab :
364
375
ab = xs .enable_manual_sharding (
365
376
ab , partition_spec , mesh = mesh ).global_tensor
377
+ if q_segment_ids is not None :
378
+ q_segment_ids = xs .enable_manual_sharding (
379
+ q_segment_ids , partition_spec [:q_segment_ids .ndim ],
380
+ mesh = mesh ).global_tensor
381
+ if kv_segment_ids is not None :
382
+ kv_segment_ids = xs .enable_manual_sharding (
383
+ kv_segment_ids , partition_spec [:kv_segment_ids .ndim ],
384
+ mesh = mesh ).global_tensor
366
385
367
386
if ctx .needs_input_grad [0 ]:
368
387
payload , _ = trace_pallas (
0 commit comments