4
4
import time
5
5
import math
6
6
import multiprocessing
7
- from typing import List , Optional , Union , Generator , Sequence , Iterator , Deque
7
+ from typing import List , Optional , Union , Generator , Sequence , Iterator , Deque , Tuple
8
8
from collections import deque
9
9
10
10
from . import llama_cpp
class LlamaCache:
    """Cache for a llama.cpp model, keyed by prompt-token prefixes.

    Lookup is by *prefix*: a stored key matches a query when the query's
    leading tokens equal the stored key.  When several stored keys are
    prefixes of the query, the longest one wins.
    """

    def __init__(self):
        # Maps a tuple of prompt tokens to the model state captured after
        # evaluating exactly those tokens.  Annotation kept as a string so it
        # is not evaluated at runtime (PEP 526 evaluates annotations on
        # attribute targets).
        self.cache_state: "Dict[Tuple[llama_cpp.llama_token, ...], LlamaState]" = dict()

    def _sorted_keys(self) -> "List[Tuple[llama_cpp.llama_token, ...]]":
        # Longest keys first, so prefix search below finds the most specific
        # (longest) stored prefix before any shorter one.
        decorated = sorted(((len(k), k) for k in self.cache_state), reverse=True)
        return [k for _, k in decorated]

    def _find_key(
        self, key: "Tuple[llama_cpp.llama_token, ...]"
    ) -> "Optional[Tuple[llama_cpp.llama_token, ...]]":
        """Return the longest stored key that is a prefix of ``key``, or None."""
        for candidate in self._sorted_keys():
            if key[: len(candidate)] == candidate:
                return candidate
        return None

    def __getitem__(
        self, key: "Sequence[llama_cpp.llama_token]"
    ) -> "Optional[LlamaState]":
        # NOTE: unlike a normal mapping this returns None rather than raising
        # KeyError when no stored key is a prefix of `key`.
        match = self._find_key(tuple(key))
        return None if match is None else self.cache_state[match]

    def __contains__(self, key: "Sequence[llama_cpp.llama_token]") -> bool:
        # Membership mirrors __getitem__: true when any stored prefix matches.
        return self._find_key(tuple(key)) is not None
27
46
28
47
def __setitem__ (self , key : Sequence [llama_cpp .llama_token ], value : "LlamaState" ):
29
48
self .cache_state = dict () # NOTE: Currently limit to one cache entry.
@@ -295,7 +314,7 @@ def generate(
295
314
if (
296
315
reset
297
316
and len (self .eval_tokens ) > 0
298
- and self .eval_tokens == tokens [: len (self .eval_tokens )]
317
+ and tuple ( self .eval_tokens ) == tuple ( tokens [: len (self .eval_tokens )])
299
318
):
300
319
if self .verbose :
301
320
print ("generate cache hit" , file = sys .stderr )
@@ -438,6 +457,8 @@ def _create_completion(
438
457
439
458
if self .cache and len (completion_tokens ) == 0 :
440
459
if prompt_tokens not in self .cache :
460
+ if self .verbose :
461
+ print ("cache miss" , file = sys .stderr )
441
462
self .cache [prompt_tokens ] = self .save_state ()
442
463
443
464
completion_tokens .append (token )
0 commit comments