@@ -82,21 +82,15 @@ def __call__(self, queries, keys, values, mask=None):
         values = self.value_proj(values)
 
         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
+        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
+        return self.out_proj(output)
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
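The diff swaps the hand-written attention (split heads, scaled matmul, masked softmax, matmul) for the fused mx.fast.scaled_dot_product_attention kernel. Below is a minimal sketch, outside the diff, that compares the two paths on random inputs; the sizes, variable names, and tolerance are illustrative assumptions, not part of the patch.

import math
import mlx.core as mx

# Made-up sizes: batch, heads, sequence length, per-head dim.
B, H, L, Dh = 2, 4, 8, 16
q = mx.random.normal((B, H, L, Dh))
k = mx.random.normal((B, H, L, Dh))
v = mx.random.normal((B, H, L, Dh))
scale = math.sqrt(1 / Dh)

# Manual path, mirroring the removed lines (keys transposed for the matmul).
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
manual = mx.softmax(scores, axis=-1) @ v

# Fused path, mirroring the added lines (no mask in this sketch).
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

print(mx.allclose(manual, fused, atol=1e-5))  # expected: True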