-
Notifications
You must be signed in to change notification settings - Fork 279
/
Copy pathtoken_memory.py
60 lines (49 loc) · 1.71 KB
/
token_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import asyncio
import math
import sys
import traceback
from beeai_framework.adapters.ollama import OllamaChatModel
from beeai_framework.backend import Role, SystemMessage, UserMessage
from beeai_framework.errors import FrameworkError
from beeai_framework.memory import TokenMemory
# Chat model used for generation and for inferring the token limit.
llm = OllamaChatModel()


def _select_removal(messages):
    """Pick the first non-system message to evict; fall back to the oldest."""
    for candidate in messages:
        if candidate.role != Role.SYSTEM:
            return candidate
    return messages[0]


def _estimate_tokens(message):
    """Rough token estimate: ~4 characters per token over role + text."""
    return math.ceil((len(message.role) + len(message.text)) / 4)


# Token-aware memory; max_tokens=None means the limit is inferred from the LLM.
memory = TokenMemory(
    llm=llm,
    max_tokens=None,  # Will be inferred from LLM
    capacity_threshold=0.75,
    sync_threshold=0.25,
    handlers={
        "removal_selector": _select_removal,
        "estimate": _estimate_tokens,
    },
)
async def main() -> None:
    """Demonstrate TokenMemory: add two messages, inspect state, then sync."""

    def report_state(header: str) -> None:
        # Show whether token counts are stale and how many tokens are tracked.
        print(header)
        print(f"Is Dirty: {memory.is_dirty}")
        print(f"Tokens Used: {memory.tokens_used}")

    # Seed the memory with a system prompt.
    system_message = SystemMessage("You are a helpful assistant.")
    await memory.add(system_message)
    print(f"Added system message (hash: {hash(system_message)})")

    # Follow with a user turn.
    user_message = UserMessage("Hello world!")
    await memory.add(user_message)
    print(f"Added user message (hash: {hash(user_message)})")

    report_state("\nInitial state:")

    # Reconcile the token counts with the model.
    await memory.sync()
    report_state("\nAfter sync:")

    # Dump every stored message.
    print("\nMessages in memory:")
    for stored in memory.messages:
        print(f"{stored.role}: {stored.text} (hash: {hash(stored)})")
if __name__ == "__main__":
    # Run the demo; on a framework failure print the full traceback and
    # exit with the framework's explanatory message.
    try:
        asyncio.run(main())
    except FrameworkError as err:
        traceback.print_exc()
        sys.exit(err.explain())