Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom_kernel: fix shape mismatch by sharding segment_ids in flash attn. #8333

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Masked out the unrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
Expand Down Expand Up @@ -98,6 +104,36 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
jax.config.update('jax_default_matmul_precision', "highest")
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

q = torch.randn(16, 32, 2048, 64).to("xla")
k = torch.randn(16, 32, 128, 64).to("xla")
v = torch.randn(16, 32, 128, 64).to("xla")
q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
kv_segment_ids[:8, :, 30:] = -10000.0
kv_segment_ids[8:, :, 60:] = -10000.0

o = flash_attention(
q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")

attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
attention_mask = attention_mask.view(16, 32, 1, 128)

expected_o = self._attention(q, k, v, attn_mask=attention_mask)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
23 changes: 21 additions & 2 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
full_k = k
full_v = v
full_ab = ab
_, full_q_segment_ids, full_kv_segment_ids = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)

if partition_spec is not None:
ctx.full_shape = q.shape
q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
Expand All @@ -254,6 +257,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
if ab:
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor
if q_segment_ids is not None:
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, partition_spec[:q_segment_ids.ndim],
mesh=mesh).global_tensor
if kv_segment_ids is not None:
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
mesh=mesh).global_tensor

# It computes the shape and type of o, l, m.
shapes = [q.shape]
Expand Down Expand Up @@ -319,8 +330,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
m = xs.disable_manual_sharding(
m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
kv_segment_ids, full_ab)
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, full_q_segment_ids,
full_kv_segment_ids, full_ab)
return o

@staticmethod
Expand Down Expand Up @@ -363,6 +374,14 @@ def backward(ctx, grad_output):
if ab:
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor
if q_segment_ids is not None:
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, partition_spec[:q_segment_ids.ndim],
mesh=mesh).global_tensor
if kv_segment_ids is not None:
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
mesh=mesh).global_tensor

if ctx.needs_input_grad[0]:
payload, _ = trace_pallas(
Expand Down