
Commit c00bf53

JackCaoG authored and rpsilva-aws committed
Support SegmentID when doing data parallel SPMD (pytorch#8425)
1 parent f3dadf3 commit c00bf53

File tree

2 files changed (+146 -7 lines)

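In short, this change lets the torch_xla flash_attention wrapper accept q/kv segment IDs while q, k, v are sharded along the batch ("data") dimension under SPMD; the wrapper derives the matching sharding for the 2-D segment-id tensors internally. A minimal usage sketch, modeled on the tests added below (the mesh helpers and the call signature are taken from those tests; shapes are illustrative, not canonical):

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
xs.set_global_mesh(xs.get_1d_mesh("data"))  # 1-D mesh over all devices

# q, k, v: [batch, num_heads, seq_len, head_dim]
q = torch.randn(4, 2, 128, 8).to("xla")
k = torch.randn(4, 2, 128, 8).to("xla")
v = torch.randn(4, 2, 128, 8).to("xla")

# Segment IDs: [batch, seq_len]; tokens attend only within the same segment.
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)

# Shard only the batch ("data") dimension of q/k/v; the 4th positional
# argument is `causal`, followed by q and kv segment IDs.
o = flash_attention(
    q, k, v, False, segment_ids, segment_ids,
    partition_spec=("data", None, None, None))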

test/test_pallas_spmd.py

+129 -1
@@ -3,6 +3,7 @@
 import unittest
 
 import torch
+import numpy as np
 from torch import nn as nn
 
 import torch_xla
@@ -22,8 +23,24 @@
 
 class PallasTest(unittest.TestCase):
 
-  def _attention(self, q, k, v):
+  # This creates a diagonal mask where only elements within the same segment
+  # can attend to each other. Since the mask marks the irrelevant parts to be
+  # masked out, we use != instead of ==.
+  def _make_attention_mask_from_segment_ids(self, q_segment_ids,
+                                            kv_segment_ids):
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])
+
+  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+      # Mask out the irrelevant parts.
+      attn_weight = attn_weight.masked_fill(attn_mask,
+                                            torch.finfo(attn_weight.dtype).min)
+    if ab is not None:
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
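To make the `!=` convention concrete, here is a small standalone check (not part of the diff) showing that the helper's broadcasted comparison flags exactly the cross-segment positions, which `_attention` then fills with the dtype's minimum before the softmax:

import torch

# Two batch rows, four tokens each: segments [0, 0, 1, 1].
segment_ids = torch.tensor([[0, 0, 1, 1],
                            [0, 0, 1, 1]])

# Same broadcast as _make_attention_mask_from_segment_ids:
# [B, 1, Tq, 1] != [B, 1, 1, Tkv] -> [B, 1, Tq, Tkv]
mask = segment_ids.view(2, 1, 4, 1) != segment_ids.view(2, 1, 1, 4)

print(mask[0, 0].int())
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]])
# True (1) marks cross-segment pairs, i.e. the entries to mask out.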
@@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_wrapper_segment_ids_spmd(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    q = torch.randn(3, 2, 128, 4)
+    k = torch.randn(3, 2, 128, 4)
+    v = torch.randn(3, 2, 128, 4)
+    zeros = torch.zeros(3, 32)
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    segment_ids_xla = segment_ids.to("xla")
+    # only shard data dimension
+    o = flash_attention(
+        q.to("xla"),
+        k.to("xla"),
+        v.to("xla"),
+        False,
+        segment_ids_xla,
+        segment_ids.to("xla"),
+        partition_spec=("data", None, None, None))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+
+    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
+    jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
+    jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
+    jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
+    expected_o = torch.from_numpy(
+        np.array(
+            jax_flash_attention(
+                jax_q,
+                jax_k,
+                jax_v,
+                segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
+            )))
+
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', "default")
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_backward_segment_ids_spmd(self):
+    jax.config.update("jax_default_matmul_precision", "highest")
+    from torch_xla.experimental.custom_kernel import flash_attention
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.get_1d_mesh("data"))
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
+    loss = o.sum()
+    loss.backward()
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(q_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(k_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(v_grad),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+    torch_xla.sync()
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(
+        q,
+        k,
+        v,
+        attn_mask=self._make_attention_mask_from_segment_ids(
+            segment_ids, segment_ids))
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+    jax.config.update("jax_default_matmul_precision", "default")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
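The tests build segment IDs by concatenating constant blocks, which corresponds to packing four equal-length segments per batch row. For variable-length packing, one way to derive the same [batch, seq_len] tensor from per-row segment lengths is sketched below (the helper name `make_segment_ids` is ours, not part of this change):

import torch

def make_segment_ids(segment_lengths, seq_len):
  """Build a [batch, seq_len] segment-id tensor from per-row segment lengths.

  Example: segment_lengths=[[3, 5]], seq_len=8 -> [[0, 0, 0, 1, 1, 1, 1, 1]].
  """
  rows = []
  for lengths in segment_lengths:
    assert sum(lengths) == seq_len, "segments must fill the whole row"
    ids = torch.cat([
        torch.full((n,), i, dtype=torch.float32) for i, n in enumerate(lengths)
    ])
    rows.append(ids)
  return torch.stack(rows)

# Equivalent to the tests' torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
# for a batch of 4 rows, each packed with four 32-token segments.
segment_ids = make_segment_ids([[32, 32, 32, 32]] * 4, seq_len=128)
assert segment_ids.shape == (4, 128)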

torch_xla/experimental/custom_kernel.py

+17 -6
@@ -266,7 +266,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       dtypes.append(torch.float32)
 
     with torch.no_grad():
-      segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
+      if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
+        # partition_spec applies to q, k, v with shape [batch, num_head, seq_len, head_dim];
+        # segment ids have shape [batch, seq_len], so the spec needs to be adjusted accordingly.
+        segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
+      segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
           q_segment_ids, kv_segment_ids)
       ctx.segment_ids = segment_ids
 
@@ -297,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       if ab is not None:
         args += [ab]
       if segment_ids is not None:
-        args += [q_segment_ids, kv_segment_ids]
+        args += [q_segment_ids_fa, kv_segment_ids_fa]
       o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
 
       if not save_residuals:
@@ -319,20 +327,23 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided,
+    # but that is fine because the backward pass will use the same partition_spec.
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
+                          kv_segment_ids_fa, full_ab)
     return o
 
   @staticmethod
   def backward(ctx, grad_output):
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv
 
-    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
+    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
     causal = ctx.causal
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
     full_shape = ctx.full_shape
+    # This segment_ids only reflects the local shape of the segment ids.
     segment_ids = ctx.segment_ids
     grad_q = grad_k = grad_v = grad_ab = None
 
@@ -398,7 +409,7 @@ def backward(ctx, grad_output):
       if ab is not None:
         args += [ab]
       if segment_ids is not None:
-        args += [q_segment_ids, kv_segment_ids]
+        args += [q_segment_ids_fa, kv_segment_ids_fa]
       args += [expanded_l, expanded_m, grad_output, expanded_grad_i]
 
       outputs = [q]
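The key sharding detail in this file is that q, k, v carry a 4-D partition_spec while the segment-id tensors are 2-D, so the wrapper keeps only the batch and sequence entries of the spec before calling enable_manual_sharding. A standalone sketch of that derivation (the function name is ours, for illustration only):

# partition_spec indexes the [batch, num_head, seq_len, head_dim] dims of q/k/v;
# segment ids are [batch, seq_len], so keep entries 0 (batch) and 2 (seq_len).
def derive_segment_id_spec(partition_spec):
  return (partition_spec[0], partition_spec[2])

assert derive_segment_id_spec(("data", None, None, None)) == ("data", None)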
