@@ -6,6 +6,7 @@
 from typing import Any
 from .embedding import Embedding
 from .quantize import dequantize_linear
+from .kv_cache import TinyKvCache


 class Qwen2MultiHeadAttention:
@@ -48,6 +49,7 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
         B, L, _ = x.shape
         orig_dtype = x.dtype
@@ -66,19 +68,12 @@ def __call__(
             .reshape(B, L, self.num_kv_heads, self.head_dim)
             .astype(mx.float32)
         )
-        # offset = cache.offset
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
         projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
         projection_q = projection_q.transpose(0, 2, 1, 3)
         projection_k = projection_k.transpose(0, 2, 1, 3)
         projection_v = projection_v.transpose(0, 2, 1, 3)
-        # TODO: it is possible to get a sensible result without using a kv-cache? Otherwise we have to include kv-cache in week 1.
-        # mlx-lm's KvCache seems to do more than just caching, we could extract something out of it.
-        # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
-        assert (
-            projection_k.dtype == mx.float32
-        )  # TODO: can we use float16? also a test framework to ensure all data types are casted correctly.
-        assert projection_v.dtype == mx.float32
+        projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
         x = scaled_dot_product_attention_grouped(
             projection_q,
             projection_k,
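The .kv_cache module itself is not part of this hunk, so as a reading aid, here is a minimal sketch of the contract the attention layer now depends on: a per-layer cache whose update_and_fetch appends the new keys/values along the sequence axis and returns the full history. Everything beyond the update_and_fetch/offset names is an assumption, not the actual TinyKvCache implementation.

import mlx.core as mx

class NaiveKvCacheSketch:
    """Illustrative cache; keys/values are (B, num_kv_heads, L, head_dim)."""

    def __init__(self):
        self.keys: mx.array | None = None
        self.values: mx.array | None = None
        self.offset = 0  # number of tokens already cached

    def update_and_fetch(self, keys: mx.array, values: mx.array) -> tuple[mx.array, mx.array]:
        # Append the new entries on the sequence axis (axis 2) and return the
        # concatenated history so attention can span all previous tokens.
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        self.offset += keys.shape[2]
        return self.keys, self.values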
@@ -157,8 +152,9 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), offset)
+        r = self.self_attn(self.input_layernorm(x), offset, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
@@ -230,9 +226,10 @@ def __call__(
         self,
         inputs: mx.array,
         offset: int,
+        cache: list[TinyKvCache],
     ) -> mx.array:
         h = self.embedding(inputs)
         for layer in range(self.num_hidden_layers):
-            h = self.layers_inner[layer](h, offset)
+            h = self.layers_inner[layer](h, offset, cache[layer])
         h = self.norm(h)
         return linear(h, self.w_lm_head)
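Since the model now takes one cache per decoder layer, a caller is expected to allocate the list once and reuse it across decoding steps, advancing offset by the number of tokens already processed so the RoPE slice stays aligned. A rough usage sketch, reusing the NaiveKvCacheSketch class above; model, tokens, and max_new_tokens are illustrative stand-ins, not names defined by this commit.

import mlx.core as mx

max_new_tokens = 32
caches = [NaiveKvCacheSketch() for _ in range(model.num_hidden_layers)]

# Prefill: run the whole prompt once starting at position 0.
logits = model(tokens[None, :], offset=0, cache=caches)
offset = tokens.shape[0]

# Decode: feed one token at a time; the cached keys/values cover the rest.
for _ in range(max_new_tokens):
    next_token = mx.argmax(logits[:, -1, :], axis=-1)
    logits = model(next_token[:, None], offset=offset, cache=caches)
    offset += 1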