Commit d484c56

Bugfix: Check cache keys as prefix to prompt tokens
1 parent: b75fa96

File tree

1 file changed: +26 -5 lines

Diff for: llama_cpp/llama.py (+26, -5)
@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque

 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""

     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None

     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]

     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None

     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
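
The lookup above is effectively a longest-prefix match: candidate keys are tried from longest to shortest, and a cached state is returned as soon as its key is a prefix of the requested token sequence. The following is a minimal, self-contained sketch of that behavior (not part of the commit), using plain ints in place of llama_cpp.llama_token and a string in place of LlamaState:

    from typing import Dict, Optional, Sequence, Tuple

    Token = int  # stand-in for llama_cpp.llama_token


    class PrefixCache:
        """Toy reimplementation of LlamaCache's prefix lookup."""

        def __init__(self) -> None:
            self.cache_state: Dict[Tuple[Token, ...], str] = dict()

        def _find_key(self, key: Tuple[Token, ...]) -> Optional[Tuple[Token, ...]]:
            # Try the longest cached keys first so the most specific prefix wins.
            for k in sorted(self.cache_state.keys(), key=len, reverse=True):
                if key[: len(k)] == k:
                    return k
            return None

        def __getitem__(self, key: Sequence[Token]) -> Optional[str]:
            _key = self._find_key(tuple(key))
            return None if _key is None else self.cache_state[_key]


    cache = PrefixCache()
    cache.cache_state[(1, 2, 3)] = "saved state"
    assert cache[[1, 2, 3, 4, 5]] == "saved state"  # cached key is a prefix of the prompt
    assert cache[[9, 1, 2]] is None                 # no cached key is a prefix

Before this change, __getitem__ required an exact key match, so a prompt that merely extended a previously cached prompt could never reuse the saved state.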
@@ -295,7 +314,7 @@ def generate(
         if (
             reset
             and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
@@ -438,6 +457,8 @@ def _create_completion(

             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()

             completion_tokens.append(token)
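
Taken together, these changes let a completion reuse a previously saved state whenever a cached key is a prefix of the new prompt tokens, and the new verbose branch makes cache misses visible alongside the existing "generate cache hit" message. A hedged usage sketch follows; the model path, the set_cache() call, and the package-level LlamaCache import are assumptions about the llama-cpp-python API of this era, not taken from this diff:

    from llama_cpp import Llama, LlamaCache

    # Assumed model path; replace with a real GGML model file.
    llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)
    llm.set_cache(LlamaCache())  # assumed set_cache() helper

    prompt = "Q: Name the planets in the solar system. A: "

    # First call: expected to log "cache miss" and save the evaluated state.
    first = llm(prompt, max_tokens=32)

    # Second call with the same (or an extended) prompt: the cached key is now
    # a prefix of the new prompt tokens, so "generate cache hit" should be
    # logged and the shared prefix should not be re-evaluated.
    second = llm(prompt, max_tokens=32)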
