
Commit 4ba9067

fdf
1 parent a3fb4e1 commit 4ba9067

File tree

1 file changed: +17 −20 lines changed


test/test_pallas_spmd.py (+17 −20)
@@ -112,34 +112,31 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
 
-    q = torch.randn(4, 2, 128, 4).to("xla")
-    k = torch.randn(4, 2, 128, 4).to("xla")
-    v = torch.randn(4, 2, 128, 4).to("xla")
-    q_segment_ids = torch.ones(4, 128, device=q.device, dtype=torch.float32).to("xla")
-    kv_segment_ids = torch.rand(4, 128).to("xla")
+    q = torch.randn(16, 32, 2048, 64).to("xla")
+    k = torch.randn(16, 32, 128, 64).to("xla")
+    v = torch.randn(16, 32, 128, 64).to("xla")
+    q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
+    kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
+    kv_segment_ids[:8, :, 30:] = -10000.0
+    kv_segment_ids[8:, :, 60:] = -10000.0
 
     o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")
 
-    attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
-    attention_mask = attention_mask.repeat_interleave(2, dim=0)
-    attention_mask = attention_mask.view(4, 2, 128, 128)
-    # attention_mask = torch.ones(4, 2, 128, 128).to("xla")
-    # head_size = self.heads
-    # current_length: int = attention_mask.shape[-1]
-    # if current_length != target_length:
-    #   attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-
-    # if attention_mask.shape[0] < 4 * head_size:
-    #   attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
-    #
-    # attention_mask = attention_mask.view(
-    #     batch_size, attn.heads, -1, attention_mask.shape[-1]
-    # )
+    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
+    attention_mask = attention_mask.view(16, 32, 1, 128)
 
     expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    # expected_o = F.scaled_dot_product_attention(
+    #     q,
+    #     k,
+    #     v,
+    #     attn_mask=attention_mask,
+    #     dropout_p=0.0,
+    #     is_causal=False,
+    # )
     diff = (expected_o - o).abs()
     # z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)
 
