11import torch
22from transformers import BartTokenizer , BartForConditionalGeneration
33from logs .logger import logging
4+ from datetime import date
5+ import uuid
6+ from mem_db .vecto import get_or_create_collection
7+
48
# Automatic device detection: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global parameters
MAX_MEMORY_SIZE = 2000  # Limit on the number of messages
MAX_TOKENS_IN_MEMORY = 1000  # Limit above which the memory gets compressed
BATCH_SIZE = 5  # Batch size used for compression
1216
class BartSingleton:
    """Process-wide holder for one shared BART tokenizer and model.

    The first ``BartSingleton()`` call loads ``facebook/bart-large-cnn``
    (tokenizer and model) and moves the model to ``device``. Every later
    call returns that same instance, which exposes the ``tokenizer`` and
    ``model`` attributes.
    """

    # Single source of truth for the checkpoint used by tokenizer and model.
    MODEL_NAME = 'facebook/bart-large-cnn'
    _instance = None

    def __new__(cls):
        """Return the shared instance, loading the tokenizer and model on first use.

        Raises:
            Whatever ``from_pretrained`` / ``.to(device)`` raise. In that case
            no instance is cached, so the next call retries the load.
        """
        if cls._instance is None:
            logging.info("Instanciation du modèle BART...")
            instance = super().__new__(cls)
            instance.tokenizer = BartTokenizer.from_pretrained(cls.MODEL_NAME)
            instance.model = BartForConditionalGeneration.from_pretrained(
                cls.MODEL_NAME
            ).to(device)
            # Publish only once both loads succeeded; caching earlier would
            # leave a half-built singleton (no tokenizer/model) after a failure.
            cls._instance = instance
        return cls._instance

    @classmethod
    def reset(cls):
        """
        Resets the BART singleton, allowing it to be reinitialized.
        This should only be used for testing purposes.
        """
        logging.warning("Reset du singleton BART.")
        cls._instance = None
36+
class ChatbotMemory:
    """Conversation memory: an in-process history list plus a stored copy of each exchange.

    NOTE(review): ``memory_counter()`` and ``compressed_memory()`` are called
    by ``update_memory`` but are defined outside the visible excerpt.
    """

    def __init__(self, conv: list = None):
        """Set up the history, the shared BART components and the storage collection.

        Args:
            conv: Existing conversation history to start from. Any falsy value
                (``None`` or an empty list) starts a new empty history.
        """
        self.conversation_history = conv or []

        # Reuse the shared tokenizer/model rather than loading BART per instance.
        bart = BartSingleton()
        self.tokenizer = bart.tokenizer
        self.model = bart.model

        self.persistent_storage = get_or_create_collection("conv_memory")

    def update_memory(self, user_input: str, bot_response: str) -> None:
        """Record one exchange, store it, and compress the history when it is too large.

        The exchange is appended to ``conversation_history`` and added to
        ``persistent_storage`` as a single document with a random UUID id.
        When ``memory_counter()`` exceeds ``MAX_TOKENS_IN_MEMORY`` the history
        is replaced by the result of ``compressed_memory()``.

        Args:
            user_input: The user's message.
            bot_response: The bot's reply to that message.

        Returns:
            None
        """
        self.conversation_history.append({'user': user_input, 'bot': bot_response})

        self.persistent_storage.add(
            documents=[f"user: {user_input} bot: {bot_response}"],
            ids=[str(uuid.uuid4())],
            metadatas=[{"type": "chat_entry"}],
        )

        if self.memory_counter() > MAX_TOKENS_IN_MEMORY:
            self.conversation_history = self.compressed_memory()
            logging.info("Mémoire compressée.")
3266
0 commit comments