1+ import asyncio
2+ from langgraph .store .base import BaseStore , SearchItem
3+ from langgraph .store .memory import InMemoryStore
4+ from langgraph .store .postgres .aio import AsyncPostgresStore
5+
6+ from src .services .db import get_store_in_memory
7+ from src .schemas .entities import SearchFilter
8+ from src .constants import THREAD_SNAPSHOT_MESSAGE_COUNT
9+ from src .repos .base_repo import BaseRepo
10+ from src .schemas .entities .store import ThreadSnapshot
11+ from src .utils .logger import logger
12+ from src .utils .format import format_xml_thread
13+ from src .utils .messages import from_message_to_dict
14+
15+
16+ FIELDS = ["messages" , "files" ]
17+
18+ class ThreadRepo (BaseRepo ):
19+ def __init__ (self , user_id : str , store : BaseStore = get_store_in_memory (fields = FIELDS )):
20+ ## Add fields to the store (if supported)
21+ self .user_id = user_id
22+ self .store : BaseStore = store
23+
24+ try :
25+ self .store .fields = FIELDS
26+ except AttributeError :
27+ pass
28+ super ().__init__ (user_id = user_id , store = store , entity_type = "threads" )
29+
30+
31+ async def search (
32+ self ,
33+ search_filter : SearchFilter ,
34+ ) -> list [dict ]:
35+ try :
36+ max_retries = 3
37+ retry_delay = 1 # seconds
38+
39+ for attempt in range (max_retries ):
40+ try :
41+ async with self .store as store :
42+ if search_filter .query :
43+ queried_threads : list [SearchItem ] = await store .asearch (
44+ self ._get_namespace (),
45+ limit = search_filter .limit ,
46+ filter = search_filter .filter ,
47+ query = search_filter .query ,
48+ )
49+ return [
50+ ThreadSnapshot (
51+ id = thread .key ,
52+ messages = thread .value ["messages" ],
53+ files = thread .value ["files" ],
54+ score = thread .score ,
55+ updated_at = thread .updated_at
56+ ).model_dump (exclude_none = True ) for thread in queried_threads
57+ ]
58+ threads = await store .asearch (
59+ self ._get_namespace (),
60+ limit = search_filter .limit ,
61+ filter = search_filter .filter ,
62+ )
63+ return sorted (
64+ [thread .dict () for thread in threads ],
65+ key = lambda x : x .get ("updated_at" ),
66+ reverse = True ,
67+ )
68+ except Exception as e :
69+ error_msg = str (e ).lower ()
70+ if "connection" in error_msg and "closed" in error_msg :
71+ logger .warning (
72+ f"Store connection closed on attempt { attempt + 1 } /{ max_retries } : { e } "
73+ )
74+ if attempt < max_retries - 1 :
75+ await asyncio .sleep (
76+ retry_delay * (2 ** attempt )
77+ ) # Exponential backoff
78+ continue
79+ raise e
80+ except Exception as e :
81+ logger .error (f"Error searching threads: { e } " )
82+ return []
83+
84+ async def update (self , thread_id : str , data : dict ):
85+
86+ # Extract last human message for storage
87+ messages = data .get ("messages" , [])
88+ messages = from_message_to_dict (messages , include_tool_calls = False )
89+ recent_messages = (
90+ messages [- THREAD_SNAPSHOT_MESSAGE_COUNT :]
91+ if len (messages ) > THREAD_SNAPSHOT_MESSAGE_COUNT
92+ else messages
93+ )
94+
95+ data ["messages" ] = recent_messages
96+
97+ await self .store .aput (
98+ namespace = self ._get_namespace (), key = thread_id , value = data
99+ )
100+
101+ return True
102+
103+
104+ async def get (self , thread_id : str ) -> dict :
105+ return await self ._get (thread_id )
106+
107+ async def delete (self , thread_id : str ) -> bool :
108+ try :
109+ await self ._delete (thread_id )
110+ logger .info (f"Thread { thread_id } deleted successfully" )
111+ return True
112+ except Exception as e :
113+ logger .error (f"Error deleting thread: { e } " )
114+ return False
115+
116+ async def _upsert_snapshot (self , thread_id : str , messages : list ) -> bool :
117+ """Create or update a thread snapshot with recent messages.
118+
119+ Note: messages should already be filtered to recent messages before calling this method.
120+ """
121+ try :
122+ # Extract recent messages for snapshot (last N messages)
123+ recent_messages = (
124+ messages [- THREAD_SNAPSHOT_MESSAGE_COUNT :]
125+ if len (messages ) > THREAD_SNAPSHOT_MESSAGE_COUNT
126+ else messages
127+ )
128+
129+ # Format messages as "Role: content" pairs
130+ page_content = format_xml_thread (recent_messages , include_tool_calls = False )
131+
132+ # Create snapshot with metadata
133+ snapshot = ThreadSnapshot (
134+ thread_id = thread_id ,
135+ page_content = page_content ,
136+ metadata = {
137+ "thread_id" : thread_id ,
138+ "message_count" : len (messages ),
139+ }
140+ )
141+
142+ await self ._set (thread_id , snapshot )
143+ return True
144+
145+ except Exception as e :
146+ logger .error (f"Failed to upsert thread snapshot for { thread_id } : { e } " )
147+ return False
148+
0 commit comments