Skip to content

Commit

Permalink
Support SegmentID when doing data prallel SPMD
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Nov 27, 2024
1 parent 20f5166 commit 0453dd1
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 1 deletion.
123 changes: 122 additions & 1 deletion test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch
import numpy as np
from torch import nn as nn

import torch_xla
Expand All @@ -22,8 +23,24 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
# This is to create a diagonal mask where only elements within the same segment
# can attend to each other. Since the mask is to mask out the unrelevant parts,
# therefore we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1,
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
kv_segment_ids.shape[0], 1, 1,
kv_segment_ids.shape[1])

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Masked out the unrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
Expand Down Expand Up @@ -98,6 +115,110 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_spmd(self):
from torch_xla.experimental.custom_kernel import flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
xs.set_global_mesh(xs.get_1d_mesh("data"))

q = torch.randn(3, 2, 128, 4)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)
zeros = torch.zeros(3, 32)
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
segment_ids_xla = segment_ids.to("xla")
# only shard data dimension
o = flash_attention(
q.to("xla"),
k.to("xla"),
v.to("xla"),
False,
segment_ids_xla,
segment_ids.to("xla"),
partition_spec=("data", None, None, None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
expected_o = torch.from_numpy(
np.array(
jax_flash_attention(
jax_q,
jax_k,
jax_v,
segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
)))

self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_backward_segment_ids_spmd(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.get_1d_mesh("data"))

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(q, k, v, False, segment_ids, segment_ids, partition_spec=("data", None, None, None))
loss = o.sum()
loss.backward()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(q_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(k_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
torch_xla.sync()


torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = self._attention(
q,
k,
v,
attn_mask=self._make_attention_mask_from_segment_ids(
segment_ids, segment_ids))
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
dtypes.append(torch.float32)

with torch.no_grad():
if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
# partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id
# is of shape [batch, seq_len], hence we need to tweak it a bit
segment_id_partition_spec = (partition_spec[0], partition_spec[2])
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
ctx.segment_ids = segment_ids
Expand Down Expand Up @@ -319,6 +327,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
m = xs.disable_manual_sharding(
m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

# q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
# but it should be OK as the backward will use the same partition_spec
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
kv_segment_ids, full_ab)
return o
Expand All @@ -333,6 +343,7 @@ def backward(ctx, grad_output):
partition_spec = ctx.partition_spec
mesh = ctx.mesh
full_shape = ctx.full_shape
# this segment_ids only reflects the local shape of segment_ids
segment_ids = ctx.segment_ids
grad_q = grad_k = grad_v = grad_ab = None

Expand Down

0 comments on commit 0453dd1

Please sign in to comment.