
Commit 5d583a6

adjust the code format in test_pallas.py and test_pallas_spmd.py by yapf.
zhangp365 committed Feb 7, 2025
1 parent cb583e6 commit 5d583a6
Showing 2 changed files with 10 additions and 10 deletions.
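
As a side note (not part of the commit itself): a formatting-only change like this can be reproduced or checked locally by running yapf over the touched files. The sketch below is a minimal example under the assumption that yapf is installed (for instance via pip install yapf) and invoked from the repository root; it picks up whatever yapf configuration the checkout provides, or yapf's defaults otherwise.

import subprocess

# The two test files touched by this commit.
TEST_FILES = ["test/test_pallas.py", "test/test_pallas_spmd.py"]

# `yapf --diff` prints the reformatting it would apply without modifying the
# files, which makes it easy to review a formatting-only change first.
diff = subprocess.run(["yapf", "--diff", *TEST_FILES],
                      capture_output=True,
                      text=True)

if diff.stdout:
  print(diff.stdout)
  # `yapf --in-place` rewrites the files on disk with the same formatting.
  subprocess.run(["yapf", "--in-place", *TEST_FILES], check=True)

Running yapf --in-place directly is equivalent; the --diff step only confirms that nothing but formatting changes.
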
10 changes: 5 additions & 5 deletions test/test_pallas.py
@@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase):
   # therefore we use != instead of ==.
   def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                              kv_segment_ids):
-    return q_segment_ids.view(q_segment_ids.shape[0], 1,
-                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
-                                  kv_segment_ids.shape[0], 1, 1,
-                                  kv_segment_ids.shape[1])
+    return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
+                              1) != kv_segment_ids.view(kv_segment_ids.shape[0],
+                                                        1, 1,
+                                                        kv_segment_ids.shape[1])
 
   def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
@@ -251,7 +251,7 @@ def test_flash_attention_wrapper_kv_and_ab_padding(self):
     q = torch.randn(1, 2, 513, 4).to("xla")
     k = torch.randn(1, 2, 513, 4).to("xla")
     v = torch.randn(1, 2, 513, 4).to("xla")
-    ab = torch.randn(1,2, 513, 513).to("xla")
+    ab = torch.randn(1, 2, 513, 513).to("xla")
 
     o = flash_attention(q, k, v, ab=ab)
     expected_o = self._attention(q, k, v, ab=ab)
10 changes: 5 additions & 5 deletions test/test_pallas_spmd.py
@@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase):
   # therefore we use != instead of ==.
   def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                              kv_segment_ids):
-    return q_segment_ids.view(q_segment_ids.shape[0], 1,
-                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
-                                  kv_segment_ids.shape[0], 1, 1,
-                                  kv_segment_ids.shape[1])
+    return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
+                              1) != kv_segment_ids.view(kv_segment_ids.shape[0],
+                                                        1, 1,
+                                                        kv_segment_ids.shape[1])
 
   def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)
@@ -90,7 +90,7 @@ def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
     ab = torch.randn(4, 2, 513, 513).to("xla")
 
     o = flash_attention(q, k, v, ab=ab, partition_spec=range(n_devices))
-    self.assertEqual(
+    self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
 
