diff --git a/src/praisonai-agents/praisonaiagents/agent/agent.py b/src/praisonai-agents/praisonaiagents/agent/agent.py index 3d47b52ca..1a1334e03 100644 --- a/src/praisonai-agents/praisonaiagents/agent/agent.py +++ b/src/praisonai-agents/praisonaiagents/agent/agent.py @@ -6,8 +6,12 @@ import contextlib import threading from typing import List, Optional, Any, Dict, Union, Literal, TYPE_CHECKING, Callable, Generator +from collections import OrderedDict import inspect +# Module-level logger for thread safety errors and debugging +logger = logging.getLogger(__name__) + # ============================================================================ # Performance: Lazy imports for heavy dependencies # Rich, LLM, and display utilities are only imported when needed (output=verbose) @@ -1511,9 +1515,12 @@ def __init__( self.embedder_config = embedder_config self.knowledge = knowledge self.use_system_prompt = use_system_prompt - # Thread-safe chat_history with lazy lock for concurrent access + # Thread-safe chat_history with eager lock initialization self.chat_history = [] - self.__history_lock = None # Lazy initialized + self.__history_lock = threading.Lock() # Eager initialization to prevent race conditions + + # Thread-safe snapshot/redo stack lock - always available even when autonomy is disabled + self.__snapshot_lock = threading.Lock() self.markdown = markdown self.stream = stream self.metrics = metrics @@ -1632,10 +1639,11 @@ def __init__( # P8/G11: Tool timeout - prevent slow tools from blocking self._tool_timeout = tool_timeout - # Cache for system prompts and formatted tools with lazy thread-safe lock - self._system_prompt_cache = {} - self._formatted_tools_cache = {} - self.__cache_lock = None # Lazy initialized RLock + # Cache for system prompts and formatted tools with eager thread-safe lock + # Use OrderedDict for LRU behavior + self._system_prompt_cache = OrderedDict() + self._formatted_tools_cache = OrderedDict() + self.__cache_lock = threading.RLock() # Eager initialization to prevent race conditions # Limit cache size to prevent unbounded growth self._max_cache_size = 100 @@ -1747,20 +1755,106 @@ def _telemetry(self): @property def _history_lock(self): - """Lazy-loaded history lock for thread-safe chat history access.""" - if self.__history_lock is None: - import threading - self.__history_lock = threading.Lock() + """Thread-safe chat history lock.""" return self.__history_lock @property def _cache_lock(self): - """Lazy-loaded cache lock for thread-safe cache access.""" - if self.__cache_lock is None: - import threading - self.__cache_lock = threading.RLock() + """Thread-safe cache lock.""" return self.__cache_lock + @property + def _snapshot_lock(self): + """Thread-safe snapshot/redo stack lock.""" + return self.__snapshot_lock + + def _cache_put(self, cache_dict, key, value): + """Thread-safe LRU cache put operation. + + Args: + cache_dict: The cache dictionary (OrderedDict) + key: Cache key + value: Value to cache + """ + with self._cache_lock: + # Move to end if already exists (LRU update) + if key in cache_dict: + del cache_dict[key] + + # Add new entry + cache_dict[key] = value + + # Evict oldest if over limit + while len(cache_dict) > self._max_cache_size: + cache_dict.popitem(last=False) # Remove oldest (FIFO) + + def _add_to_chat_history(self, role, content): + """Thread-safe method to add messages to chat history. 
+ + Args: + role: Message role ("user", "assistant", "system") + content: Message content + """ + with self._history_lock: + self.chat_history.append({"role": role, "content": content}) + + def _add_to_chat_history_if_not_duplicate(self, role, content): + """Thread-safe method to add messages to chat history only if not duplicate. + + Atomically checks for duplicate and adds message under the same lock to prevent TOCTOU races. + + Args: + role: Message role ("user", "assistant", "system") + content: Message content + + Returns: + bool: True if message was added, False if duplicate was detected + """ + with self._history_lock: + # Check for duplicate within the same critical section + if (self.chat_history and + self.chat_history[-1].get("role") == role and + self.chat_history[-1].get("content") == content): + return False + + # Not a duplicate, add the message + self.chat_history.append({"role": role, "content": content}) + return True + + def _get_chat_history_length(self): + """Thread-safe method to get chat history length.""" + with self._history_lock: + return len(self.chat_history) + + def _truncate_chat_history(self, length): + """Thread-safe method to truncate chat history to specified length. + + Args: + length: Target length for chat history + """ + with self._history_lock: + self.chat_history = self.chat_history[:length] + + def _cache_get(self, cache_dict, key): + """Thread-safe LRU cache get operation. + + Args: + cache_dict: The cache dictionary (OrderedDict) + key: Cache key + + Returns: + Value if found, None otherwise + """ + with self._cache_lock: + if key not in cache_dict: + return None + + # Move to end (mark as recently used) + value = cache_dict[key] + del cache_dict[key] + cache_dict[key] = value + return value + @property def auto_memory(self): """AutoMemory instance for automatic memory extraction.""" @@ -2220,19 +2314,23 @@ def undo(self) -> bool: result = agent.start("Refactor utils.py") agent.undo() # Restore original files """ - if self._file_snapshot is None or not self._snapshot_stack: - return False - try: - target_hash = self._snapshot_stack.pop() - # Get current hash before restore (for redo) - current_hash = self._file_snapshot.get_current_hash() - if current_hash: - self._redo_stack.append(current_hash) - self._file_snapshot.restore(target_hash) - return True - except Exception as e: - logger.debug(f"Undo failed: {e}") + if self._file_snapshot is None: return False + + with self._snapshot_lock: + if not self._snapshot_stack: + return False + try: + target_hash = self._snapshot_stack.pop() + # Get current hash before restore (for redo) + current_hash = self._file_snapshot.get_current_hash() + if current_hash: + self._redo_stack.append(current_hash) + self._file_snapshot.restore(target_hash) + return True + except Exception as e: + logger.debug(f"Undo failed: {e}") + return False def redo(self) -> bool: """Redo a previously undone set of file changes. @@ -2242,18 +2340,22 @@ def redo(self) -> bool: Returns: True if redo was successful, False if nothing to redo. 
""" - if self._file_snapshot is None or not self._redo_stack: - return False - try: - target_hash = self._redo_stack.pop() - current_hash = self._file_snapshot.get_current_hash() - if current_hash: - self._snapshot_stack.append(current_hash) - self._file_snapshot.restore(target_hash) - return True - except Exception as e: - logger.debug(f"Redo failed: {e}") + if self._file_snapshot is None: return False + + with self._snapshot_lock: + if not self._redo_stack: + return False + try: + target_hash = self._redo_stack.pop() + current_hash = self._file_snapshot.get_current_hash() + if current_hash: + self._snapshot_stack.append(current_hash) + self._file_snapshot.restore(target_hash) + return True + except Exception as e: + logger.debug(f"Redo failed: {e}") + return False def diff(self, from_hash: Optional[str] = None): """Get file diffs from autonomous execution. @@ -2279,8 +2381,11 @@ def diff(self, from_hash: Optional[str] = None): return [] try: base = from_hash - if base is None and self._snapshot_stack: - base = self._snapshot_stack[0] + if base is None: + # Protect snapshot stack read with lock to prevent TOCTOU with undo/redo + with self._snapshot_lock: + if self._snapshot_stack: + base = self._snapshot_stack[0] if base is None: return [] return self._file_snapshot.diff(base) @@ -2477,8 +2582,9 @@ def run_autonomous( if self._file_snapshot is not None and self.autonomy_config.get("snapshot", False): try: snap_info = self._file_snapshot.track(message="pre-autonomous") - self._snapshot_stack.append(snap_info.commit_hash) - self._redo_stack.clear() + with self._snapshot_lock: + self._snapshot_stack.append(snap_info.commit_hash) + self._redo_stack.clear() except Exception as e: logging.debug(f"Pre-autonomous snapshot failed: {e}") @@ -4304,8 +4410,9 @@ def _build_system_prompt(self, tools=None): tools_key = self._get_tools_cache_key(tools) cache_key = f"{self.role}:{self.goal}:{tools_key}" - if cache_key in self._system_prompt_cache: - return self._system_prompt_cache[cache_key] + cached_prompt = self._cache_get(self._system_prompt_cache, cache_key) + if cached_prompt is not None: + return cached_prompt else: cache_key = None # Don't cache when memory is enabled @@ -4371,9 +4478,9 @@ def _build_system_prompt(self, tools=None): system_prompt += "\n\nExplain Before Acting: Before calling a tool, provide a brief one-sentence explanation of what you are about to do and why. Skip explanations only for repetitive low-level operations where narration would be noisy. When performing a batch of similar operations (e.g. searching for multiple items), explain the group once rather than narrating each call individually." 
# Cache the generated system prompt (only if cache_key is set, i.e., memory not enabled) - # Simple cache size limit to prevent unbounded growth - if cache_key and len(self._system_prompt_cache) < self._max_cache_size: - self._system_prompt_cache[cache_key] = system_prompt + # Use LRU eviction to prevent unbounded growth + if cache_key: + self._cache_put(self._system_prompt_cache, cache_key, system_prompt) return system_prompt def _build_response_format(self, schema_model): @@ -4567,8 +4674,9 @@ def _format_tools_for_completion(self, tools=None): # Check cache first tools_key = self._get_tools_cache_key(tools) - if tools_key in self._formatted_tools_cache: - return self._formatted_tools_cache[tools_key] + cached_tools = self._cache_get(self._formatted_tools_cache, tools_key) + if cached_tools is not None: + return cached_tools formatted_tools = [] for tool in tools: @@ -4619,10 +4727,8 @@ def _format_tools_for_completion(self, tools=None): logging.error(f"Tools are not JSON serializable: {e}") return [] - # Cache the formatted tools - # Simple cache size limit to prevent unbounded growth - if len(self._formatted_tools_cache) < self._max_cache_size: - self._formatted_tools_cache[tools_key] = formatted_tools + # Cache the formatted tools with LRU eviction + self._cache_put(self._formatted_tools_cache, tools_key, formatted_tools) return formatted_tools def generate_task(self) -> 'Task': @@ -6279,12 +6385,9 @@ def _chat_impl(self, prompt, temperature, tools, output_json, output_pydantic, r # Extract text from multimodal prompts normalized_content = next((item["text"] for item in prompt if item.get("type") == "text"), "") - # Prevent duplicate messages - if not (self.chat_history and - self.chat_history[-1].get("role") == "user" and - self.chat_history[-1].get("content") == normalized_content): - # Add user message to chat history BEFORE LLM call so handoffs can access it - self.chat_history.append({"role": "user", "content": normalized_content}) + # Add user message to chat history BEFORE LLM call so handoffs can access it + # Use atomic check-then-act to prevent TOCTOU race conditions + if self._add_to_chat_history_if_not_duplicate("user", normalized_content): # Persist user message to DB self._persist_message("user", normalized_content) @@ -6334,7 +6437,7 @@ def _chat_impl(self, prompt, temperature, tools, output_json, output_pydantic, r response_text = self.llm_instance.get_response(**llm_kwargs) - self.chat_history.append({"role": "assistant", "content": response_text}) + self._add_to_chat_history("assistant", response_text) # Persist assistant message to DB self._persist_message("assistant", response_text) @@ -8595,12 +8698,12 @@ async def handle_agent_query(request: Request, query_data: Optional[AgentQuery] print(f"🚀 Agent '{self.name}' available at http://{host}:{port}") - # Start the server if it's not already running for this port + # Check and mark server as started atomically to prevent race conditions should_start = not _server_started.get(port, False) if should_start: _server_started[port] = True - # Server start/wait outside the lock to avoid holding it during sleep + # Server start/wait outside the lock to avoid holding it during sleep if should_start: # Start the server in a separate thread def run_server():
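Notes on the concurrency patterns in this patch, with minimal standalone sketches. All names introduced in the sketches are illustrative stand-ins, not code from this repository.

The _cache_put/_cache_get helpers implement a standard OrderedDict-based LRU cache. A minimal sketch of the same pattern, assuming nothing beyond the standard library (ThreadSafeLRU and its members are hypothetical names):

    import threading
    from collections import OrderedDict

    class ThreadSafeLRU:
        def __init__(self, max_size=100):
            self._data = OrderedDict()
            self._lock = threading.RLock()   # re-entrant, like __cache_lock above
            self._max_size = max_size

        def put(self, key, value):
            with self._lock:
                if key in self._data:
                    del self._data[key]              # drop the stale position
                self._data[key] = value              # insert at the MRU end
                while len(self._data) > self._max_size:
                    self._data.popitem(last=False)   # evict the LRU entry

        def get(self, key):
            with self._lock:
                if key not in self._data:
                    return None
                self._data.move_to_end(key)          # mark as recently used
                return self._data[key]

    cache = ThreadSafeLRU(max_size=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")       # "a" becomes most recently used
    cache.put("c", 3)    # evicts "b", the least recently used
    assert cache.get("b") is None and cache.get("a") == 1

OrderedDict.move_to_end(key) is the built-in equivalent of the patch's delete-and-reinsert; both leave the entry at the most-recently-used end, so popitem(last=False) always evicts the least recently used key.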
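undo() and redo() now mutate both stacks only while holding _snapshot_lock. A sketch of that locking discipline, assuming a store object exposing the get_current_hash()/restore() methods the patch calls on self._file_snapshot (UndoRedo itself is a hypothetical stand-in):

    import threading

    class UndoRedo:
        """Hypothetical stand-in showing the lock discipline of undo()/redo()."""

        def __init__(self, store):
            self._store = store          # needs get_current_hash() and restore()
            self._undo_stack = []
            self._redo_stack = []
            self._lock = threading.Lock()

        def record(self, commit_hash):
            # Mirrors run_autonomous(): push the pre-run snapshot and
            # invalidate any redo history, atomically.
            with self._lock:
                self._undo_stack.append(commit_hash)
                self._redo_stack.clear()

        def undo(self):
            with self._lock:
                if not self._undo_stack:
                    return False
                target = self._undo_stack.pop()
                current = self._store.get_current_hash()
                if current:
                    self._redo_stack.append(current)   # enable redo
                self._store.restore(target)
                return True

        def redo(self):
            with self._lock:
                if not self._redo_stack:
                    return False
                target = self._redo_stack.pop()
                current = self._store.get_current_hash()
                if current:
                    self._undo_stack.append(current)
                self._store.restore(target)
                return True

Holding one lock across the pop, get_current_hash(), and restore() is what makes the pair safe: a concurrent undo and redo can no longer interleave between reading the current hash and restoring the target.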
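The duplicate check in _chat_impl moved into _add_to_chat_history_if_not_duplicate so that the check and the append share one lock acquisition. A runnable sketch of why that matters:

    import threading

    history = []
    history_lock = threading.Lock()

    def add_if_not_duplicate(role, content):
        # Check and append under the same lock: no thread can slip an
        # append in between (the TOCTOU window the old code had).
        with history_lock:
            if (history and
                    history[-1].get("role") == role and
                    history[-1].get("content") == content):
                return False
            history.append({"role": role, "content": content})
            return True

    # Two threads racing with the same user message: exactly one wins.
    threads = [threading.Thread(target=add_if_not_duplicate, args=("user", "hi"))
               for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert len(history) == 1

With the old unlocked check, both threads could pass the duplicate test before either appended, leaving two copies of the message in the history.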
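The final hunk keeps the check-and-mark of _server_started atomic while doing the slow server startup outside the critical section. A sketch of the pattern, assuming a module-level lock guards the flag dict (the hunk shows only the flag logic, so _servers_lock and ensure_server here are assumed names):

    import threading

    _server_started = {}              # port -> bool, as in the patch
    _servers_lock = threading.Lock()  # assumed guard; not visible in the hunk

    def ensure_server(port, start_fn):
        # Claim the port while holding the lock...
        with _servers_lock:
            should_start = not _server_started.get(port, False)
            if should_start:
                _server_started[port] = True
        # ...but boot the (slow) server outside it, so other callers
        # are never blocked behind the startup sleep.
        if should_start:
            threading.Thread(target=start_fn, daemon=True).start()
        return should_start

Only the first caller per port spawns the server thread; every later caller observes the flag already set and returns immediately.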