
Commit 657f466

use sdpa and exportable functions in transformer multi head attention (#1760)
1 parent: c7b0300 · commit: 657f466

File tree

python/mlx/nn/layers/transformer.py
python/tests/test_nn.py

2 files changed: +14 -14 lines

python/mlx/nn/layers/transformer.py

Lines changed: 8 additions & 14 deletions
@@ -82,21 +82,15 @@ def __call__(self, queries, keys, values, mask=None):
         values = self.value_proj(values)
 
         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
+        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
+        return self.out_proj(output)
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
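For readers skimming the diff, the new forward path can be read as a small standalone function. The sketch below is a minimal illustration, not the library source: sdpa_forward is a hypothetical helper name, the layer's query/key/value/output projections are omitted, and it assumes an MLX build that provides mx.unflatten and mx.fast.scaled_dot_product_attention, both of which appear in the diff above.

import math

import mlx.core as mx


def sdpa_forward(queries, keys, values, num_heads, mask=None):
    # Split the last dimension into heads: (B, L, D) -> (B, num_heads, L, D // num_heads).
    queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
    keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
    values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)

    # The fused kernel replaces the manual softmax((Q * scale) @ K^T) @ V computation.
    scale = math.sqrt(1 / queries.shape[-1])
    out = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )

    # Merge the heads back: (B, num_heads, L, D // num_heads) -> (B, L, D).
    return out.transpose(0, 2, 1, 3).flatten(-2, -1)


x = mx.random.normal(shape=(2, 5, 32))
print(sdpa_forward(x, x, x, num_heads=4).shape)  # (2, 5, 32)

Note that unflatten and flatten work on the array's own last axis rather than on B, L, and S values captured from .shape, which is presumably what the commit title means by exportable functions.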

python/tests/test_nn.py

Lines changed: 6 additions & 0 deletions
@@ -1835,6 +1835,12 @@ def test_causal_mask(self):
         self.assertFalse(mx.any(mx.isnan(mask)))
         self.assertTrue(mask[0, -1].item() < 0)
 
+    def test_attention(self):
+        attn = nn.MultiHeadAttention(32, 4)
+        x = mx.random.normal(shape=(2, 5, 32))
+        out = attn(x, x, x)
+        self.assertEqual(out.shape, x.shape)
+
 
 if __name__ == "__main__":
     unittest.main()
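The added check can also be run as a quick standalone script. The snippet below mirrors the new test_attention case and assumes only the usual mlx.core and mlx.nn imports.

import mlx.core as mx
import mlx.nn as nn

# Same setup as the added test: 32 model dims, 4 heads, self-attention on a (2, 5, 32) batch.
attn = nn.MultiHeadAttention(32, 4)
x = mx.random.normal(shape=(2, 5, 32))
out = attn(x, x, x)
assert out.shape == x.shape  # the SDPA-based path preserves (batch, sequence, dims)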
