forked from jenkinsci/resources-ai-chatbot-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmemory.py
More file actions
316 lines (246 loc) · 8.64 KB
/
memory.py
File metadata and controls
316 lines (246 loc) · 8.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""
Handles in-memory chat session state (conversation memory).
Provides utility functions for session lifecycle.
"""
import asyncio
import uuid
from datetime import datetime, timedelta
from threading import Lock
from typing import Optional
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from api.config.loader import CONFIG
from api.services.sessionmanager import(
delete_session_file,
load_session,
session_exists_in_json,
append_message,
get_persisted_session_ids
)
# In-memory session registry, keyed by session id. Each value has the shape
#   {"memory": ConversationBufferMemory, "last_accessed": datetime}
_sessions: dict[str, dict] = {}
# Serializes every access to _sessions across threads.
_lock = Lock()
# Maps a lower-cased persisted role name to the LangChain message class used
# to rebuild it. Both LangChain-style ("human"/"ai") and chat-API-style
# ("user"/"assistant") role names are accepted.
_ROLE_TO_MESSAGE_CLASS = dict(
    human=HumanMessage,
    user=HumanMessage,
    ai=AIMessage,
    assistant=AIMessage,
    system=SystemMessage,
)
def init_session() -> str:
    """
    Create a brand-new chat session backed by an empty conversation memory.

    Returns:
        str: The identifier of the new session (a random UUID4 string).
    """
    new_id = str(uuid.uuid4())
    entry = {
        "memory": ConversationBufferMemory(return_messages=True),
        "last_accessed": datetime.now(),
    }
    with _lock:
        _sessions[new_id] = entry
    return new_id
def _restore_persisted_message(memory: ConversationBufferMemory, message: object) -> None:
    """
    Append a single persisted message to *memory* as a LangChain message object.

    Persisted entries are dicts of the form {"role": ..., "content": ...}.
    Converting them back into message objects lets downstream code rely on
    attributes such as ``msg.type`` and ``msg.content``.

    Non-dict entries are ignored. A missing or non-string role falls back to
    "human", and unknown roles map to HumanMessage. A None content becomes ""
    and any other non-string content is converted with str().
    """
    if isinstance(message, dict):
        raw_role = message.get("role", "human")
        role_key = raw_role.lower() if isinstance(raw_role, str) else "human"

        raw_content = message.get("content", "")
        if raw_content is None:
            text = ""
        elif isinstance(raw_content, str):
            text = raw_content
        else:
            text = str(raw_content)

        message_cls = _ROLE_TO_MESSAGE_CLASS.get(role_key, HumanMessage)
        memory.chat_memory.add_message(message_cls(content=text))
def get_session(session_id: str) -> Optional[ConversationBufferMemory]:
    """
    Look up the conversation memory for *session_id*, restoring it from disk if needed.

    On an in-memory hit the session's ``last_accessed`` is refreshed. On a miss,
    the persisted history is loaded, each message is rebuilt into a new
    ConversationBufferMemory, and the result is cached in memory.

    Args:
        session_id (str): The session identifier.
    Returns:
        Optional[ConversationBufferMemory]: The session's memory, or None when
        the session is neither in memory nor has a non-empty persisted history.
    """
    with _lock:
        cached = _sessions.get(session_id)
        if not cached:
            # Loading happens under the lock so a concurrent delete cannot
            # interleave between the disk read and the cache insert.
            persisted = load_session(session_id)
            if not persisted:
                return None
            restored = ConversationBufferMemory(return_messages=True)
            for entry in persisted:
                _restore_persisted_message(restored, entry)
            _sessions[session_id] = {
                "memory": restored,
                "last_accessed": datetime.now(),
            }
            return restored
        cached["last_accessed"] = datetime.now()
        return cached["memory"]
async def get_session_async(session_id: str) -> Optional[ConversationBufferMemory]:
    """
    Awaitable version of get_session().

    get_session() acquires the module lock and may read persisted history,
    so it is run in a worker thread to avoid blocking the event loop.
    """
    memory = await asyncio.to_thread(get_session, session_id)
    return memory
def persist_session(session_id: str) -> None:
    """
    Write the session's current messages to persistent storage.

    Each message is serialized as {"role": msg.type, "content": msg.content}
    and the full list is handed to append_message(). Does nothing when the
    session cannot be found (get_session() returns a falsy value).

    Args:
        session_id (str): The session identifier.
    """
    memory = get_session(session_id)
    if not memory:
        return
    snapshot = [
        {"role": message.type, "content": message.content}
        for message in memory.chat_memory.messages
    ]
    append_message(session_id, snapshot)
def delete_session(session_id: str) -> bool:
    """
    Remove a session from memory and delete its persisted file.

    The persisted file is deleted only when the session was present in
    memory. A ``None`` session_id is treated as a no-op that reports True.

    Args:
        session_id (str): The session identifier.
    Returns:
        bool: True if the session was found in memory and removed (or
        session_id is None), False otherwise.
    """
    with _lock:
        if session_id is None:
            return True
        removed = _sessions.pop(session_id, None)
        if removed is None:
            return False
        delete_session_file(session_id)
        return True
def session_exists(session_id: str) -> bool:
    """
    Report whether *session_id* is currently loaded in memory.

    Persisted-only sessions are not consulted; this is a memory-only check.

    Args:
        session_id (str): The session identifier.
    Returns:
        bool: True if the session is in the in-memory registry.
    """
    with _lock:
        found = session_id in _sessions
    return found
def reset_sessions():
    """
    Empty the in-memory session registry (intended for tests).

    Only the in-memory map is cleared; no persisted files are touched.
    """
    with _lock:
        _sessions.clear()
def reload_persisted_sessions() -> int:
    """
    Pull every persisted session into memory.

    Meant to run once at startup so that session_exists() can stay a cheap,
    memory-only lookup. Each id is restored through get_session().

    Returns:
        int: How many sessions were successfully restored.
    """
    return sum(
        1
        for persisted_id in get_persisted_session_ids()
        if get_session(persisted_id) is not None
    )
def get_last_accessed(session_id: str) -> Optional[datetime]:
    """
    Get the last accessed timestamp for a given session.

    Only in-memory sessions carry a ``last_accessed`` timestamp. The persisted
    history restored by get_session() is a list of {"role", "content"}
    message dicts with no timestamp. Indexing it with "last_accessed" used to
    raise TypeError, so a disk-only session now yields None instead.

    Args:
        session_id (str): The session identifier.
    Returns:
        Optional[datetime]: The last accessed timestamp if the session is in
        memory (or its persisted snapshot carries one), else None.
    """
    with _lock:
        session_data = _sessions.get(session_id)
        if session_data is not None:
            return session_data["last_accessed"]
    # The disk read does not touch _sessions, so it is done outside the lock
    # to avoid stalling other sessions on I/O.
    history = load_session(session_id)
    # NOTE(review): persist_session() stores a list of message dicts, which
    # has no timestamp. Only a dict-shaped snapshot (if load_session ever
    # returns one — confirm) can carry "last_accessed".
    if isinstance(history, dict):
        return history.get("last_accessed")
    return None
def set_last_accessed(session_id: str, timestamp: datetime) -> bool:
    """
    Set the last accessed timestamp for a given session (for testing purposes).

    A session that exists only on disk is first restored into memory via
    get_session(). The timestamp is then stored on the live entry, where
    cleanup_expired_sessions() will see it. Previously it was written to a
    throwaway copy of the loaded history and had no effect.

    Args:
        session_id (str): The session identifier.
        timestamp (datetime): The timestamp to set.
    Returns:
        bool: True if the session exists (in memory or on disk) and the
        timestamp was set, False otherwise.
    """
    # get_session() acquires _lock itself, and threading.Lock is not
    # re-entrant, so any restore must happen before taking the lock here.
    if get_session(session_id) is None:
        return False
    with _lock:
        session_data = _sessions.get(session_id)
        if session_data is None:
            # Removed concurrently (deleted or expired) after the restore.
            return False
        session_data["last_accessed"] = timestamp
        return True
def list_sessions(page: int = 1, page_size: int = 20) -> dict:
    """
    Return a paginated list of all active in-memory sessions with basic metadata.

    Each entry includes the session ID, the number of messages exchanged, and
    the ISO-8601 last-accessed timestamp. Sessions are ordered by session ID.

    Out-of-range values (``page < 1`` or ``page_size < 1``) produce an empty
    ``sessions`` list. Previously, negative values silently returned items
    from the end of the list through negative slice indices.

    Args:
        page (int): 1-indexed page number. Defaults to 1.
        page_size (int): Maximum sessions per page. Defaults to 20.
    Returns:
        dict: Contains ``sessions`` (list of metadata dicts), ``total`` (total
        count before pagination), ``page``, and ``page_size``.
    """
    with _lock:
        all_ids = sorted(_sessions)
        total = len(all_ids)
        if page < 1 or page_size < 1:
            page_ids = []
        else:
            start = (page - 1) * page_size
            page_ids = all_ids[start:start + page_size]

        # The lock is held throughout, so every id in page_ids is still present.
        sessions = []
        for session_id in page_ids:
            session_data = _sessions[session_id]
            last_accessed: datetime = session_data["last_accessed"]
            sessions.append({
                "session_id": session_id,
                "message_count": len(session_data["memory"].chat_memory.messages),
                "last_accessed": last_accessed.isoformat(),
            })

    return {
        "sessions": sessions,
        "total": total,
        "page": page,
        "page_size": page_size,
    }
def get_session_count() -> int:
    """
    Return how many sessions are currently held in memory (used by tests).

    Returns:
        int: Size of the in-memory session registry.
    """
    with _lock:
        count = len(_sessions)
    return count
def cleanup_expired_sessions() -> int:
    """
    Evict sessions idle for longer than the configured timeout.

    The timeout comes from CONFIG["session"]["timeout_hours"] and defaults to
    24 hours. Every session whose ``last_accessed`` is older than the cutoff is
    removed from memory. Its persisted file is also deleted when one exists.

    Returns:
        int: The number of sessions identified as expired.
    """
    session_cfg = CONFIG.get("session", {})
    max_idle = timedelta(hours=session_cfg.get("timeout_hours", 24))
    cutoff_time = datetime.now() - max_idle

    with _lock:
        stale_ids = [
            sid
            for sid, data in _sessions.items()
            if data["last_accessed"] < cutoff_time
        ]
        for sid in stale_ids:
            if _sessions.pop(sid, None) is None:
                continue
            if session_exists_in_json(sid):
                delete_session_file(sid)

    return len(stale_ids)