
Commit 8b09601

custom_kernel: fix shape mismatch by sharding segment_ids in flash attn.
When sharding support was added to this module, segment_ids were not taken into account, which caused a shape-mismatch failure when they were used with sharded flash attention.
1 parent 51575db commit 8b09601
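For context, the failing configuration looks roughly like the sketch below: flash attention is called with both segment ids and a partition_spec, so q/k/v get manually sharded while the segment-id tensors previously kept their full shape. This is a minimal illustrative sketch, not a test from the repo; the mesh layout, shapes, and axis name are made up, and it assumes the flash_attention wrapper exported by this module together with its partition_spec/mesh arguments.

# Hypothetical repro sketch (illustrative shapes/mesh; requires a TPU with SPMD enabled).
import torch
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
device = xm.xla_device()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_devices), (n_devices,), ('data',))

# [batch, num_heads, seq_len, head_dim]
q = torch.randn(8, 2, 1024, 128).to(device)
k = torch.randn(8, 2, 1024, 128).to(device)
v = torch.randn(8, 2, 1024, 128).to(device)
# Segment ids are per-token labels of shape [batch, seq_len]; all-ones here
# just to exercise the code path.
q_segment_ids = torch.ones(8, 1024, dtype=torch.float32).to(device)
kv_segment_ids = torch.ones(8, 1024, dtype=torch.float32).to(device)

# Before this commit, combining segment ids with a partition_spec hit a
# shape mismatch: q/k/v were manually sharded over the 'data' axis, but the
# segment-id tensors still carried the full batch dimension.
o = flash_attention(
    q, k, v,
    q_segment_ids=q_segment_ids,
    kv_segment_ids=kv_segment_ids,
    partition_spec=('data', None, None, None),
    mesh=mesh)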

1 file changed: +21 −2 lines changed


torch_xla/experimental/custom_kernel.py

+21 −2

@@ -246,6 +246,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     full_k = k
     full_v = v
     full_ab = ab
+    _, full_q_segment_ids, full_kv_segment_ids = FlashAttention.prepare_segment_ids(
+        q_segment_ids, kv_segment_ids)
+
     if partition_spec is not None:
       ctx.full_shape = q.shape
       q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -254,6 +257,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     # It computes the shape and type of o, l, m.
     shapes = [q.shape]
@@ -319,8 +330,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, full_q_segment_ids,
+                          full_kv_segment_ids, full_ab)
     return o
 
   @staticmethod
@@ -363,6 +374,14 @@ def backward(ctx, grad_output):
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     if ctx.needs_input_grad[0]:
       payload, _ = trace_pallas(
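Two details of the change are worth spelling out. First, the full (unsharded) segment ids are captured via FlashAttention.prepare_segment_ids before manual sharding kicks in, and those full tensors are what ctx.save_for_backward now stores, mirroring how full_q/full_k/full_v are handled. Second, the segment-id tensors are lower-rank than q/k/v, so the partition spec is truncated to their rank before they are sharded in both forward and backward. A minimal sketch of that truncation, assuming the usual 4-D q and 2-D segment-id layout (values illustrative):

# q/k/v are [batch, num_heads, seq_len, head_dim]; one spec entry per dim.
partition_spec = ('data', None, None, None)

# Segment ids are [batch, seq_len], i.e. rank 2, so only the leading entries
# of the spec can apply to them.
q_segment_ids_ndim = 2
segment_id_spec = partition_spec[:q_segment_ids_ndim]
# -> ('data', None): shard the batch dimension the same way as q/k/v,
#    replicate along the sequence dimension.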
