
Commit 5e2cb30

lint
1 parent 0453dd1 commit 5e2cb30

1 file changed: +13 -6 lines changed

test/test_pallas_spmd.py

+13 -6
@@ -40,7 +40,7 @@ def _attention(self, q, k, v, *, attn_mask=None, ab=None):
       attn_weight = attn_weight.masked_fill(attn_mask,
                                             torch.finfo(attn_weight.dtype).min)
     if ab is not None:
-        attn_weight = attn_weight + ab
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
@@ -139,7 +139,7 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
         partition_spec=("data", None, None, None))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-          f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

     jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
     jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
@@ -175,12 +175,19 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, False, segment_ids, segment_ids, partition_spec=("data", None, None, None))
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
     loss = o.sum()
     loss.backward()
     q_grad = q.grad
     k_grad = k.grad
-      v_grad = v.grad
+    v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
@@ -192,10 +199,9 @@ def test_flash_attention_backward_segment_ids_spmd(self):
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-          f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     torch_xla.sync()

-
     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
@@ -220,6 +226,7 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update("jax_default_matmul_precision", "default")

+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
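
A note on the assertion strings that appear throughout this diff: the doubled braces in the f-strings are Python's escape for literal "{" and "}", so on a four-device mesh the expected sharding spec renders as "{devices=[4,1,1,1]0,1,2,3}". Below is a minimal sketch of that rendering; the expected_batch_sharding helper is hypothetical and not part of the test.

# Minimal sketch, not part of the diff: how the expected sharding-spec
# literals in the assertions are produced. expected_batch_sharding is a
# hypothetical helper used only for illustration.
def expected_batch_sharding(n_devices):
  # Doubled braces in an f-string emit literal '{' and '}'.
  device_ids = ",".join(str(i) for i in range(n_devices))
  return f"{{devices=[{n_devices},1,1,1]{device_ids}}}"

# Matches the literal the test expects when the leading (batch) dimension is
# sharded across four devices.
assert expected_batch_sharding(4) == "{devices=[4,1,1,1]0,1,2,3}"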
