Skip to content

Commit 6b6f831

Browse files
committed
Mem long terme + chromadb
1 parent 68306f5 commit 6b6f831

7 files changed

Lines changed: 53 additions & 6 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
venv/
22
*.pyc
33
*.egg-info
4-
dist/
4+
dist/
5+
chroma_memory/

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch==2.4.0
22
transformers>=4.44.1,<5.0
3-
pytest
3+
pytest
4+
chromadb

src/.DS_Store

6 KB
Binary file not shown.

src/mem_db/__init__.py

Whitespace-only changes.

src/mem_db/vecto.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import chromadb

# Persistent vector store backing the chatbot's long-term memory.
# The data lives in ./chroma_memory (git-ignored, see .gitignore).
chroma_client = chromadb.PersistentClient(path="./chroma_memory")


def get_or_create_collection(collection_name):
    """Return the Chroma collection named *collection_name*, creating it if absent.

    Args:
        collection_name: Name of the collection to fetch or create.

    Returns:
        A chromadb Collection handle.
    """
    # Chroma's native get_or_create_collection is atomic, avoiding the
    # get-then-create race (and the version-dependent exception type) of
    # a try/except around get_collection().
    return chroma_client.get_or_create_collection(name=collection_name)

src/shortterm_memory/ChatbotMemory.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,47 @@
11
import torch
22
from transformers import BartTokenizer, BartForConditionalGeneration
33
from logs.logger import logging
4+
from datetime import date
5+
import uuid
6+
from mem_db.vecto import get_or_create_collection
7+
48

59
# Automatically pick the compute device: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global parameters
MAX_MEMORY_SIZE = 2000  # Maximum number of stored messages
MAX_TOKENS_IN_MEMORY = 1000  # Token count above which the memory is compressed
BATCH_SIZE = 5  # Batch size used when compressing the memory
1216

17+
class BartSingleton:
    """Process-wide singleton holding the BART tokenizer and model.

    Loading 'facebook/bart-large-cnn' is expensive, so this guarantees the
    weights are loaded at most once per process, no matter how many
    ChatbotMemory objects are created.
    """

    # Cached instance; None until the first successful construction.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            logging.info("Instanciation du modèle BART...")
            # Build the instance completely BEFORE publishing it: if either
            # from_pretrained() call raises, we must not leave a half-built
            # object cached in _instance for later callers to receive.
            instance = super().__new__(cls)
            instance.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
            instance.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)
            cls._instance = instance
        return cls._instance

    @classmethod
    def reset(cls):
        """Reset the singleton so the model can be reinitialized.

        This should only be used for testing purposes.
        """
        logging.warning("Reset du singleton BART.")
        cls._instance = None
36+
1337
class ChatbotMemory:
1438
def __init__(self, conv:list=None):
1539
self.conversation_history = conv or []
16-
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
17-
self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)
40+
bart = BartSingleton()
41+
self.tokenizer = bart.tokenizer
42+
self.model = bart.model
43+
44+
self.persistent_storage = get_or_create_collection("conv_memory")
1845

1946
def update_memory(self, user_input:str, bot_response:str)->None:
2047
"""
@@ -25,8 +52,15 @@ def update_memory(self, user_input:str, bot_response:str)->None:
2552
Returns:
2653
None """
2754
self.conversation_history.append({'user': user_input, 'bot': bot_response})
55+
date.today()
56+
57+
self.persistent_storage.add(
58+
documents=[f"user: {user_input} bot: {bot_response}"],
59+
ids=[str(uuid.uuid4())],
60+
metadatas=[{"type": "chat_entry"}]
61+
)
2862

29-
if self.memory_counter() > MAX_TOKENS_PER_MESSAGE:
63+
if self.memory_counter() > MAX_TOKENS_IN_MEMORY:
3064
self.conversation_history = self.compressed_memory()
3165
logging.info("Mémoire compressée.")
3266

0 commit comments

Comments
 (0)