 from typing import Any
 from .embedding import Embedding
 from .quantize import dequantize_linear
+from .kv_cache import TinyKvCache


 class Qwen2MultiHeadAttention:
@@ -48,6 +49,7 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
         B, L, _ = x.shape
         orig_dtype = x.dtype
@@ -66,19 +68,12 @@ def __call__(
             .reshape(B, L, self.num_kv_heads, self.head_dim)
             .astype(mx.float32)
         )
-        # offset = cache.offset
         projection_q = self.rope(projection_q, offset=slice(offset, offset + L))
         projection_k = self.rope(projection_k, offset=slice(offset, offset + L))
         projection_q = projection_q.transpose(0, 2, 1, 3)
         projection_k = projection_k.transpose(0, 2, 1, 3)
         projection_v = projection_v.transpose(0, 2, 1, 3)
-        # TODO: it is possible to get a sensible result without using a kv-cache? Otherwise we have to include kv-cache in week 1.
-        # mlx-lm's KvCache seems to do more than just caching, we could extract something out of it.
-        # projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
-        assert (
-            projection_k.dtype == mx.float32
-        )  # TODO: can we use float16? also a test framework to ensure all data types are casted correctly.
-        assert projection_v.dtype == mx.float32
+        projection_k, projection_v = cache.update_and_fetch(projection_k, projection_v)
         x = scaled_dot_product_attention_grouped(
             projection_q,
             projection_k,
@@ -157,8 +152,9 @@ def __call__(
         self,
         x: mx.array,
         offset: int,
+        cache: TinyKvCache,
     ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), offset)
+        r = self.self_attn(self.input_layernorm(x), offset, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
@@ -230,9 +226,10 @@ def __call__(
         self,
         inputs: mx.array,
         offset: int,
+        cache: list[TinyKvCache],
     ) -> mx.array:
         h = self.embedding(inputs)
         for layer in range(self.num_hidden_layers):
-            h = self.layers_inner[layer](h, offset)
+            h = self.layers_inner[layer](h, offset, cache[layer])
         h = self.norm(h)
         return linear(h, self.w_lm_head)
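
For context, the diff assumes a TinyKvCache (imported from .kv_cache) that exposes update_and_fetch(key, value). The sketch below shows one way such a cache could work; the internals (concatenating along the sequence axis of (B, num_kv_heads, S, head_dim) tensors and keeping an offset counter) are assumptions for illustration, not necessarily the repository's actual kv_cache.py.

import mlx.core as mx


class TinyKvCache:
    """Sketch of a per-layer key/value cache (interface assumed from the diff)."""

    def __init__(self) -> None:
        self.keys: mx.array | None = None
        self.values: mx.array | None = None
        self.offset = 0  # number of tokens cached so far

    def update_and_fetch(
        self, key: mx.array, value: mx.array
    ) -> tuple[mx.array, mx.array]:
        # key/value arrive as (B, num_kv_heads, L, head_dim); append them to the
        # history along the sequence axis and return the full history so the
        # attention call can attend over every token seen so far.
        if self.keys is None:
            self.keys, self.values = key, value
        else:
            self.keys = mx.concatenate([self.keys, key], axis=2)
            self.values = mx.concatenate([self.values, value], axis=2)
        self.offset += key.shape[2]
        return self.keys, self.values

The model-level __call__ now takes one cache per decoder layer (cache: list[TinyKvCache]), so a caller builds the list once and advances offset by the number of tokens fed in each step. A rough decode loop under those assumptions (model, prompt, and max_new_tokens are hypothetical names):

# Hypothetical decode loop: one TinyKvCache per transformer layer.
caches = [TinyKvCache() for _ in range(model.num_hidden_layers)]
offset = 0
tokens = prompt  # shape (1, L0): the whole prompt is prefilled in one pass
for _ in range(max_new_tokens):
    logits = model(tokens, offset, caches)
    offset += tokens.shape[1]
    tokens = mx.argmax(logits[:, -1, :], axis=-1)[:, None]  # next token, shape (1, 1)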