Commit b5d1b8f

Test proposal for segment ids that fails.
1 parent 8b09601 commit b5d1b8f

File tree

1 file changed: +37 -1 lines changed

test/test_pallas_spmd.py

+37 -1
@@ -22,8 +22,14 @@
 
 class PallasTest(unittest.TestCase):
 
-  def _attention(self, q, k, v):
+  def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
+    if attn_mask is not None:
+      # Mask out the irrelevant parts.
+      attn_weight = attn_weight.masked_fill(attn_mask,
+                                            torch.finfo(attn_weight.dtype).min)
+    if ab is not None:
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
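
For context, a minimal standalone sketch of what the extended _attention reference computes; the shapes and driver values below are illustrative and not taken from the test. masked_fill expects a boolean mask in which True marks positions to discard, and ab is an additive bias broadcast onto the attention scores before the softmax.

import torch
import torch.nn as nn


def _attention(q, k, v, *, attn_mask=None, ab=None):
  # q: [B, H, T_q, D]; k, v: [B, H, T_kv, D]
  attn_weight = q @ k.transpose(-2, -1)
  if attn_mask is not None:
    # Boolean mask: True = discard this (query, key) position.
    attn_weight = attn_weight.masked_fill(attn_mask,
                                          torch.finfo(attn_weight.dtype).min)
  if ab is not None:
    # Additive bias, broadcastable to the score shape [B, H, T_q, T_kv].
    attn_weight = attn_weight + ab
  attn_weight = nn.functional.softmax(attn_weight, dim=-1)
  return attn_weight @ v


q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 10, 8)
v = torch.randn(2, 4, 10, 8)
mask = torch.zeros(2, 1, 1, 10, dtype=torch.bool)
mask[..., 6:] = True  # discard the last four key positions in every batch entry
out = _attention(q, k, v, attn_mask=mask)  # [2, 4, 16, 8]
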
@@ -98,6 +104,36 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', "default")
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
+    jax.config.update('jax_default_matmul_precision', "highest")
+    n_devices = xr.global_runtime_device_count()
+    xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
+
+    q = torch.randn(16, 32, 2048, 64).to("xla")
+    k = torch.randn(16, 32, 128, 64).to("xla")
+    v = torch.randn(16, 32, 128, 64).to("xla")
+    q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
+    kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
+    # convert mask into a bias that can be added to attention scores:
+    # (keep = +0, discard = -10000.0)
+    kv_segment_ids[:8, :, 30:] = -10000.0
+    kv_segment_ids[8:, :, 60:] = -10000.0
+
+    o = flash_attention(
+        q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(n_devices))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(o),
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+
+    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
+    attention_mask = attention_mask.view(16, 32, 1, 128)
+
+    expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', "default")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
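
One detail worth noting about the reference check: torch.Tensor.masked_fill expects a boolean mask, while kv_segment_ids above is a float additive bias (0 = keep, -10000.0 = discard) that is reshaped and passed to _attention as attn_mask, which is a plausible reason this proposed test fails. Below is a minimal sketch, assuming the same shapes as the test, of how that keep/discard pattern could be expressed as a boolean mask; the helper name _bias_to_bool_mask is hypothetical and not part of this commit.

import torch


def _bias_to_bool_mask(kv_bias, n_heads):
  # kv_bias: [B, 1, T_kv] float additive bias (0 = keep, large negative = discard).
  # Returns a [B, n_heads, 1, T_kv] boolean mask, True where keys are discarded.
  mask = kv_bias < 0
  return mask.unsqueeze(1).expand(-1, n_heads, -1, -1)


kv_segment_ids = torch.zeros(16, 1, 128)
kv_segment_ids[:8, :, 30:] = -10000.0
kv_segment_ids[8:, :, 60:] = -10000.0
attention_mask = _bias_to_bool_mask(kv_segment_ids, n_heads=32)  # [16, 32, 1, 128]

This marks the same positions as the repeat_interleave/view reshaping in the test, but yields a dtype that masked_fill accepts.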
