diff --git a/src/praisonai-agents/praisonaiagents/memory/memory.py b/src/praisonai-agents/praisonaiagents/memory/memory.py index 9c58986a4..cdcafce0e 100644 --- a/src/praisonai-agents/praisonaiagents/memory/memory.py +++ b/src/praisonai-agents/praisonaiagents/memory/memory.py @@ -3,6 +3,7 @@ import json import time import shutil +import threading from typing import Any, Dict, List, Optional, Union, Literal import logging from datetime import datetime @@ -189,6 +190,16 @@ def __init__(self, config: Dict[str, Any], verbose: int = 0): self.cfg = config or {} self.verbose = verbose + # Thread-local storage for SQLite connections (thread-safe) + self._local = threading.local() + + # Write lock for serializing database modifications (thread-safe) + self._write_lock = threading.Lock() + + # Connection registry for cleanup across threads (use regular set with careful cleanup) + self._all_connections = set() + self._connection_lock = threading.Lock() # Protect the connection registry + # Set logger level based on verbose if verbose >= 5: logger.setLevel(logging.INFO) @@ -247,6 +258,50 @@ def __init__(self, config: Dict[str, Any], verbose: int = 0): elif self.use_rag: self._init_chroma() + def _get_stm_conn(self): + """Get thread-local short-term memory SQLite connection.""" + if not hasattr(self._local, 'stm_conn') or self._local.stm_conn is None: + self._local.stm_conn = sqlite3.connect( + self.short_db, + check_same_thread=False, # Allow cross-thread cleanup + timeout=30.0 # 30 second timeout for lock contention + ) + # Configure busy timeout for better contention handling + self._local.stm_conn.execute("PRAGMA busy_timeout=30000") # 30 seconds + + # Enable WAL mode for concurrent read/write without blocking + result = self._local.stm_conn.execute("PRAGMA journal_mode=WAL").fetchone() + if result and result[0].upper() != 'WAL': + logger.warning(f"WAL mode not enabled for STM, got: {result[0]}") + self._local.stm_conn.commit() + + # Register connection for cleanup + with self._connection_lock: + self._all_connections.add(self._local.stm_conn) + return self._local.stm_conn + + def _get_ltm_conn(self): + """Get thread-local long-term memory SQLite connection.""" + if not hasattr(self._local, 'ltm_conn') or self._local.ltm_conn is None: + self._local.ltm_conn = sqlite3.connect( + self.long_db, + check_same_thread=False, # Allow cross-thread cleanup + timeout=30.0 # 30 second timeout for lock contention + ) + # Configure busy timeout for better contention handling + self._local.ltm_conn.execute("PRAGMA busy_timeout=30000") # 30 seconds + + # Enable WAL mode for concurrent read/write without blocking + result = self._local.ltm_conn.execute("PRAGMA journal_mode=WAL").fetchone() + if result and result[0].upper() != 'WAL': + logger.warning(f"WAL mode not enabled for LTM, got: {result[0]}") + self._local.ltm_conn.commit() + + # Register connection for cleanup + with self._connection_lock: + self._all_connections.add(self._local.ltm_conn) + return self._local.ltm_conn + def _log_verbose(self, msg: str, level: int = logging.INFO): """Only log if verbose >= 5""" if self.verbose >= 5: @@ -276,7 +331,7 @@ def _emit_memory_event(self, event_type: str, memory_type: str, def _init_stm(self): """Creates or verifies short-term memory table.""" os.makedirs(os.path.dirname(self.short_db) or ".", exist_ok=True) - conn = sqlite3.connect(self.short_db) + conn = self._get_stm_conn() c = conn.cursor() c.execute(""" CREATE TABLE IF NOT EXISTS short_mem ( @@ -287,12 +342,11 @@ def _init_stm(self): ) """) conn.commit() - conn.close() def _init_ltm(self): """Creates or verifies long-term memory table.""" os.makedirs(os.path.dirname(self.long_db) or ".", exist_ok=True) - conn = sqlite3.connect(self.long_db) + conn = self._get_ltm_conn() c = conn.cursor() c.execute(""" CREATE TABLE IF NOT EXISTS long_mem ( @@ -303,7 +357,6 @@ def _init_ltm(self): ) """) conn.commit() - conn.close() def _init_mem0(self): """Initialize Mem0 client for agent or user memory with optional graph support.""" @@ -579,15 +632,15 @@ def store_short_term( logger.error(f"Failed to store in MongoDB short-term memory: {e}") raise - # Existing SQLite store logic + # Existing SQLite store logic (with write lock for concurrency safety) try: - conn = sqlite3.connect(self.short_db) - conn.execute( - "INSERT INTO short_mem (id, content, meta, created_at) VALUES (?,?,?,?)", - (ident, text, json.dumps(metadata), created_at) - ) - conn.commit() - conn.close() + conn = self._get_stm_conn() + with self._write_lock: # Serialize write operations + conn.execute( + "INSERT INTO short_mem (id, content, meta, created_at) VALUES (?,?,?,?)", + (ident, text, json.dumps(metadata), created_at) + ) + conn.commit() logger.info(f"Successfully stored in SQLite short-term memory with ID: {ident}") except Exception as e: logger.error(f"Failed to store in SQLite short-term memory: {e}") @@ -713,13 +766,12 @@ def search_short_term( else: # Local fallback - conn = sqlite3.connect(self.short_db) + conn = self._get_stm_conn() c = conn.cursor() rows = c.execute( "SELECT id, content, meta FROM short_mem WHERE content LIKE ? LIMIT ?", (f"%{query}%", limit) ).fetchall() - conn.close() results = [] for row in rows: @@ -739,10 +791,10 @@ def search_short_term( def reset_short_term(self): """Completely clears short-term memory.""" - conn = sqlite3.connect(self.short_db) - conn.execute("DELETE FROM short_mem") - conn.commit() - conn.close() + conn = self._get_stm_conn() + with self._write_lock: # Serialize write operations + conn.execute("DELETE FROM short_mem") + conn.commit() # ------------------------------------------------------------------------- # Long-Term Methods @@ -813,15 +865,15 @@ def store_long_term( logger.error(f"Failed to store in MongoDB long-term memory: {e}") # Continue to SQLite fallback - # Store in SQLite + # Store in SQLite (with write lock for concurrency safety) try: - conn = sqlite3.connect(self.long_db) - conn.execute( - "INSERT INTO long_mem (id, content, meta, created_at) VALUES (?,?,?,?)", - (ident, text, json.dumps(metadata), created) - ) - conn.commit() - conn.close() + conn = self._get_ltm_conn() + with self._write_lock: # Serialize write operations + conn.execute( + "INSERT INTO long_mem (id, content, meta, created_at) VALUES (?,?,?,?)", + (ident, text, json.dumps(metadata), created) + ) + conn.commit() logger.info(f"Successfully stored in SQLite with ID: {ident}") except Exception as e: logger.error(f"Error storing in SQLite: {e}") @@ -1002,13 +1054,12 @@ def search_long_term( self._log_verbose(f"Error searching ChromaDB: {e}", logging.ERROR) # Always try SQLite as fallback or additional source - conn = sqlite3.connect(self.long_db) + conn = self._get_ltm_conn() c = conn.cursor() rows = c.execute( "SELECT id, content, meta, created_at FROM long_mem WHERE content LIKE ? LIMIT ?", (f"%{query}%", limit) ).fetchall() - conn.close() for row in rows: meta = json.loads(row[2] or "{}") @@ -1051,10 +1102,10 @@ def search_long_term( def reset_long_term(self): """Clear local LTM DB, plus Chroma, MongoDB, or mem0 if in use.""" - conn = sqlite3.connect(self.long_db) - conn.execute("DELETE FROM long_mem") - conn.commit() - conn.close() + conn = self._get_ltm_conn() + with self._write_lock: # Serialize write operations + conn.execute("DELETE FROM long_mem") + conn.commit() if self.use_mem0 and hasattr(self, "mem0_client"): # Mem0 has no universal reset API. Could implement partial or no-op. @@ -1085,16 +1136,16 @@ def delete_short_term(self, memory_id: str) -> bool: """ deleted = False - # Delete from SQLite + # Delete from SQLite (with write lock for concurrency safety) try: - conn = sqlite3.connect(self.short_db) - cursor = conn.execute( - "DELETE FROM short_mem WHERE id = ?", (memory_id,) - ) - if cursor.rowcount > 0: - deleted = True - conn.commit() - conn.close() + conn = self._get_stm_conn() + with self._write_lock: # Serialize write operations + cursor = conn.execute( + "DELETE FROM short_mem WHERE id = ?", (memory_id,) + ) + if cursor.rowcount > 0: + deleted = True + conn.commit() except Exception as e: self._log_verbose(f"Error deleting from SQLite short-term: {e}", logging.ERROR) @@ -1126,16 +1177,16 @@ def delete_long_term(self, memory_id: str) -> bool: """ deleted = False - # Delete from SQLite + # Delete from SQLite (with write lock for concurrency safety) try: - conn = sqlite3.connect(self.long_db) - cursor = conn.execute( - "DELETE FROM long_mem WHERE id = ?", (memory_id,) - ) - if cursor.rowcount > 0: - deleted = True - conn.commit() - conn.close() + conn = self._get_ltm_conn() + with self._write_lock: # Serialize write operations + cursor = conn.execute( + "DELETE FROM long_mem WHERE id = ?", (memory_id,) + ) + if cursor.rowcount > 0: + deleted = True + conn.commit() except Exception as e: self._log_verbose(f"Error deleting from SQLite long-term: {e}", logging.ERROR) @@ -1790,10 +1841,9 @@ def get_all_memories(self) -> List[Dict[str, Any]]: try: # Get short-term memories - conn = sqlite3.connect(self.short_db) + conn = self._get_stm_conn() c = conn.cursor() rows = c.execute("SELECT id, content, meta, created_at FROM short_mem").fetchall() - conn.close() for row in rows: meta = json.loads(row[2] or "{}") @@ -1806,10 +1856,9 @@ def get_all_memories(self) -> List[Dict[str, Any]]: }) # Get long-term memories - conn = sqlite3.connect(self.long_db) + conn = self._get_ltm_conn() c = conn.cursor() rows = c.execute("SELECT id, content, meta, created_at FROM long_mem").fetchall() - conn.close() for row in rows: meta = json.loads(row[2] or "{}") @@ -1873,3 +1922,57 @@ def get_learn_context(self) -> str: if self.learn is None: return "" return self.learn.to_system_prompt_context() + + def close_connections(self): + """ + Close database connections. + + Closes the current thread's connections and attempts to close all known + connections from other threads. Each thread should call this method before + terminating to ensure proper cleanup. + """ + # Close current thread's connections + if hasattr(self._local, 'stm_conn') and self._local.stm_conn: + try: + self._local.stm_conn.close() + self._local.stm_conn = None + except Exception as e: + logger.warning(f"Error closing current thread's STM connection: {e}") + + if hasattr(self._local, 'ltm_conn') and self._local.ltm_conn: + try: + self._local.ltm_conn.close() + self._local.ltm_conn = None + except Exception as e: + logger.warning(f"Error closing current thread's LTM connection: {e}") + + # Close all known connections from the registry + with self._connection_lock: # Ensure thread safety during cleanup + connections_to_close = list(self._all_connections) + for conn in connections_to_close: + try: + conn.close() + except Exception as e: + logger.debug(f"Error closing registered connection: {e}") + # Clear the registry + self._all_connections.clear() + + def __enter__(self): + """Allow Memory to be used as a context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensure connections are closed when leaving a context manager block.""" + self.close_connections() + + def __del__(self): + """ + Attempt to clean up any open SQLite connections when this instance + is garbage-collected. Errors are suppressed to avoid issues during + interpreter shutdown. + """ + try: + self.close_connections() + except Exception: + # Best-effort cleanup during garbage collection + pass