Skip to content

Commit

Permalink
adjust the code format in test_pallas.py and test_pallas_spmd.py by y…
Browse files Browse the repository at this point in the history
…apf 0.30.0.
  • Loading branch information
zhangp365 committed Feb 7, 2025
1 parent 5d583a6 commit 0a639e8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase):
# therefore we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
1) != kv_segment_ids.view(kv_segment_ids.shape[0],
1, 1,
kv_segment_ids.shape[1])
return q_segment_ids.view(q_segment_ids.shape[0], 1,
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
kv_segment_ids.shape[0], 1, 1,
kv_segment_ids.shape[1])

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
Expand Down
8 changes: 4 additions & 4 deletions test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase):
# therefore we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
1) != kv_segment_ids.view(kv_segment_ids.shape[0],
1, 1,
kv_segment_ids.shape[1])
return q_segment_ids.view(q_segment_ids.shape[0], 1,
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
kv_segment_ids.shape[0], 1, 1,
kv_segment_ids.shape[1])

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
Expand Down

0 comments on commit 0a639e8

Please sign in to comment.