Skip to content

Commit 0a639e8

Browse files
committed
adjust the code format in test_pallas.py and test_pallas_spmd.py by yapf 0.30.0.
1 parent 5d583a6 commit 0a639e8

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

test/test_pallas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase):
4141
# therefore we use != instead of ==.
4242
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
4343
kv_segment_ids):
44-
return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
45-
1) != kv_segment_ids.view(kv_segment_ids.shape[0],
46-
1, 1,
47-
kv_segment_ids.shape[1])
44+
return q_segment_ids.view(q_segment_ids.shape[0], 1,
45+
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
46+
kv_segment_ids.shape[0], 1, 1,
47+
kv_segment_ids.shape[1])
4848

4949
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
5050
attn_weight = q @ k.transpose(-2, -1)

test/test_pallas_spmd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase):
4141
# therefore we use != instead of ==.
4242
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
4343
kv_segment_ids):
44-
return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
45-
1) != kv_segment_ids.view(kv_segment_ids.shape[0],
46-
1, 1,
47-
kv_segment_ids.shape[1])
44+
return q_segment_ids.view(q_segment_ids.shape[0], 1,
45+
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
46+
kv_segment_ids.shape[0], 1, 1,
47+
kv_segment_ids.shape[1])
4848

4949
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
5050
attn_weight = q @ k.transpose(-2, -1)

0 commit comments

Comments
 (0)