-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathcache.py
488 lines (401 loc) · 15.9 KB
/
cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use when cgeneration.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = dict(tree_flatten([c.state for c in cache]))
cache_info = [c.meta_state for c in cache]
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = dict(tree_flatten([cache_info, metadata, cache_classes]))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 8):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
v_head_dim = values.shape[-1]
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
el_per_int = 8 * mx.uint32.size // self.bits
new_steps = (self.step + num_steps - 1) // self.step * self.step
shape = (B, n_kv_heads, new_steps)
def init_quant(dim):
return (
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
)
def expand_quant(x):
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
return mx.concatenate([x, new_x], axis=-2)
if self.keys is not None:
if prev % self.step != 0:
self.keys, self.values = tree_map(
lambda x: x[..., :prev, :], (self.keys, self.values)
)
self.keys, self.values = tree_map(
expand_quant, (self.keys, self.values)
)
else:
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
self.offset += num_steps
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
@property
def state(self):
if self.offset == self.keys[0].shape[2]:
return self.keys, self.values
else:
return tree_map(
lambda x: x[..., : self.offset, :], (self.keys, self.values)
)
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
@meta_state.setter
def meta_state(self, v):
self.step, self.offset, self.group_size, self.bits = map(int, v)
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
quant_cache.offset = self.offset
if self.keys is not None:
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(
self.values, group_size=group_size, bits=bits
)
return quant_cache
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v
class ChunkedKVCache(KVCache):
def __init__(self, chunk_size=None):
super().__init__()
self.chunk_size = chunk_size
self.start_position = 0
def maybe_trim_front(self):
# Maintain the cache below the chunk size
if self.keys is not None and self.keys.shape[2] >= self.chunk_size:
self.keys = self.keys[..., self.chunk_size :, :]
self.values = self.values[..., self.chunk_size :, :]
self.start_position += self.chunk_size
def update_and_fetch(self, keys, values):
prev = self.offset - self.start_position
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
end = self.offset - self.start_position
self.keys[..., prev:end, :] = keys
self.values[..., prev:end, :] = values
return self.keys[..., :end, :], self.values[..., :end, :]
def trim(self, n):
n = min(self.offset - self.start_position, n)
self.offset -= n
return n
@property
def meta_state(self):
return tuple(map(str, (self.chunk_size, self.start_position)))
@meta_state.setter
def meta_state(self, v):
self.chunk_size, self.start_position = map(int, v)