diff --git a/src/gaia/agents/base/agent.py b/src/gaia/agents/base/agent.py index 082c28a19..3ce6d9a4a 100644 --- a/src/gaia/agents/base/agent.py +++ b/src/gaia/agents/base/agent.py @@ -19,6 +19,7 @@ from gaia.agents.base.console import AgentConsole, SilentConsole from gaia.agents.base.errors import format_execution_trace +from gaia.agents.base.tool_loader import ToolLoader from gaia.agents.base.tools import _TOOL_REGISTRY # First-party imports @@ -216,6 +217,11 @@ def __init__( self.total_plan_steps = 0 self.plan_iterations = 0 # Track number of plan cycles + # Bundle-based tool loading (opt-in: subclasses set self.tool_loader + # before calling super().__init__ or in _register_tools). + if not hasattr(self, "tool_loader"): + self.tool_loader: Optional[ToolLoader] = None + # Initialize the console/output handler for display # If output_handler is provided, use it; otherwise create based on silent_mode if output_handler is not None: @@ -309,10 +315,16 @@ def _get_mixin_prompts(self) -> list[str]: return prompts - def _compose_system_prompt(self) -> str: + def _compose_system_prompt( + self, tool_registry: Optional[Dict[str, dict]] = None + ) -> str: """ Compose final system prompt from mixin fragments + agent custom + tools + format. + Parameters: + tool_registry: If provided, use this filtered tool dict instead of the + full ``_TOOL_REGISTRY``. Passed through to ``_format_tools_for_prompt``. + Override this method for complete control over prompt composition order. 
Returns: @@ -340,7 +352,7 @@ def _compose_system_prompt(self) -> str: # Add tool descriptions (if tools registered) if hasattr(self, "_format_tools_for_prompt"): - tools_description = self._format_tools_for_prompt() + tools_description = self._format_tools_for_prompt(registry=tool_registry) if tools_description: parts.append(f"==== AVAILABLE TOOLS ====\n{tools_description}") @@ -416,11 +428,22 @@ def _register_tools(self): """ raise NotImplementedError("Subclasses must implement _register_tools") - def _format_tools_for_prompt(self) -> str: - """Format the registered tools into a string for the prompt.""" + def _format_tools_for_prompt( + self, registry: Optional[Dict[str, dict]] = None + ) -> str: + """Format the registered tools into a string for the prompt. + + Parameters + ---------- + registry: + If provided, use this dict of tools instead of the global + ``_TOOL_REGISTRY``. ``ToolLoader.resolve()`` returns such a + filtered dict each turn. + """ + source = registry if registry is not None else _TOOL_REGISTRY tool_descriptions = [] - for name, tool_info in _TOOL_REGISTRY.items(): + for name, tool_info in source.items(): params_str = ", ".join( [ f"{param_name}{'' if param_info['required'] else '?'}: {param_info['type']}" @@ -460,6 +483,20 @@ def rebuild_system_prompt(self) -> None: # mixin prompts, tool descriptions, and response format are all included. self._system_prompt_cache = self._compose_system_prompt() + def resolve_tools_for_turn(self, user_message: str) -> None: + """Re-evaluate tool bundles for the current turn and rebuild the prompt. + + If no ``tool_loader`` is configured this is a no-op, preserving the + existing behaviour (all registered tools always visible). + + Called at the top of ``process_query`` before the first LLM call so + that keyword-activated bundles can match the current user message. 
+ """ + if self.tool_loader is None: + return + filtered = self.tool_loader.resolve(user_message, _TOOL_REGISTRY) + self._system_prompt_cache = self._compose_system_prompt(tool_registry=filtered) + def list_tools(self, verbose: bool = True) -> None: """ Display all tools registered for this agent with their parameters and descriptions. @@ -1252,6 +1289,9 @@ def _execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Any: try: result = tool(**tool_args) logger.debug(f"Tool execution result: {result}") + # Record usage so bundle-based loader can keep the owning bundle warm + if self.tool_loader is not None: + self.tool_loader.record_tool_use(tool_name) return result except subprocess.TimeoutExpired as e: # Handle subprocess timeout specifically @@ -1588,6 +1628,11 @@ def process_query( # Store query for error context (used in _execute_tool for error formatting) self._current_query = user_input + # Re-evaluate which tool bundles should be visible for this turn. + # Keyword-activated bundles match against user_input; session bundles + # that were used in prior turns stay warm automatically. + self.resolve_tools_for_turn(user_input) + logger.debug(f"Processing query: {user_input}") conversation = [] # Build messages array for chat completions diff --git a/src/gaia/agents/base/tool_loader.py b/src/gaia/agents/base/tool_loader.py new file mode 100644 index 000000000..09fa8079b --- /dev/null +++ b/src/gaia/agents/base/tool_loader.py @@ -0,0 +1,269 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +""" +ToolLoader — bundle-based tool visibility for agents. + +Gates which tools appear in the LLM prompt each turn without changing the +global ``_TOOL_REGISTRY``. The registry remains the source of truth for +*all* registered tools; the loader picks the subset that goes into the +system prompt. + +Bundles +------- +A ``ToolBundle`` groups related tools under a name with an activation +policy. 
Three policies exist: + +* **always** — included in every prompt (e.g. ``core``). +* **session** — stays active for the rest of the session once any tool + in the bundle has been used (e.g. ``scratchpad`` after ``create_table``). +* **keyword** — activated when the current user message matches one of a + set of trigger patterns (e.g. ``browser`` on URL patterns). + +The loader evaluates bundles in priority order each turn and returns the +set of tool names that should appear in the prompt. +""" + +from __future__ import annotations + +import logging +import re +import time +from dataclasses import dataclass +from enum import Enum +from typing import Dict, FrozenSet, List, Optional, Set + +logger = logging.getLogger(__name__) + + +class ActivationPolicy(Enum): + """How a bundle decides whether to be active.""" + + ALWAYS = "always" + SESSION = "session" # Active once any tool in the bundle was used this session + KEYWORD = "keyword" # Active when user message matches trigger patterns + + +@dataclass(frozen=True) +class ToolBundle: + """An immutable group of tools sharing an activation policy. + + Parameters + ---------- + name: + Human-readable bundle identifier (e.g. ``"rag"``, ``"scratchpad"``). + tools: + Frozenset of tool names that belong to this bundle. + policy: + When the bundle should be included in the prompt. + keywords: + Regex patterns (case-insensitive) checked against the user message + when ``policy`` is ``KEYWORD``. Ignored for other policies. + """ + + name: str + tools: FrozenSet[str] + policy: ActivationPolicy + keywords: FrozenSet[str] = frozenset() + + +@dataclass +class _BundleState: + """Mutable per-session state for a single bundle.""" + + activated: bool = False # True once the bundle has been activated this session + last_used_ts: float = 0.0 # Timestamp of most recent tool use in this bundle + + +class ToolLoader: + """Selects which registered tools appear in the LLM prompt each turn. 
+ + Usage:: + + loader = ToolLoader() + loader.register_bundle(ToolBundle( + name="scratchpad", + tools=frozenset({"create_table", "insert_data", "query_data", + "list_tables", "drop_table"}), + policy=ActivationPolicy.SESSION, + )) + + # Each turn, ask which tools should be visible: + active_tools = loader.resolve(user_message, registry) + + The loader does **not** modify ``_TOOL_REGISTRY``. It returns a + filtered view that the agent uses when building the prompt. + """ + + # Warm-window: if a bundle was used in the last 24 h, keep it active + WARM_WINDOW_SECS: float = 24 * 3600 + + def __init__(self) -> None: + self._bundles: Dict[str, ToolBundle] = {} + self._state: Dict[str, _BundleState] = {} + # History of (tool_name, timestamp) for logging / warm-window checks + self._tool_history: List[tuple[str, float]] = [] + # Reverse index: tool_name → bundle_name for fast lookup + self._tool_to_bundle: Dict[str, str] = {} + + # ── registration ───────────────────────────────────────────────────── + + def register_bundle(self, bundle: ToolBundle) -> None: + """Register a bundle (idempotent — overwrites if name already exists).""" + self._bundles[bundle.name] = bundle + self._state.setdefault(bundle.name, _BundleState()) + for tool_name in bundle.tools: + self._tool_to_bundle[tool_name] = bundle.name + + def register_bundles(self, bundles: list[ToolBundle]) -> None: + for b in bundles: + self.register_bundle(b) + + # ── per-turn resolution ────────────────────────────────────────────── + + def resolve( + self, + user_message: str, + registry: Dict[str, dict], + ) -> Dict[str, dict]: + """Return the subset of *registry* that should appear in the prompt. + + Parameters + ---------- + user_message: + The current user turn (used for keyword matching). + registry: + The full ``_TOOL_REGISTRY`` dict mapping tool names → metadata. + + Returns + ------- + dict + Filtered copy of *registry* containing only active tools. 
+ """ + active_names: Set[str] = set() + activated_bundles: list[str] = [] + + for name, bundle in self._bundles.items(): + state = self._state[name] + + if bundle.policy == ActivationPolicy.ALWAYS: + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + + if bundle.policy == ActivationPolicy.SESSION: + if state.activated: + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + # Warm-window: check if any tool in the bundle was used recently + if self._was_used_recently(bundle): + state.activated = True + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + # Also activate if keywords match (session bundles can have keywords) + if bundle.keywords and self._keywords_match( + bundle.keywords, user_message + ): + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + + if bundle.policy == ActivationPolicy.KEYWORD: + if state.activated: + # Already activated this session — keep warm + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + if bundle.keywords and self._keywords_match( + bundle.keywords, user_message + ): + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + # Warm-window fallback + if self._was_used_recently(bundle): + active_names.update(bundle.tools) + activated_bundles.append(name) + continue + + # Include any registered tools that are NOT in any bundle (backwards compat). 
+ unbundled = {t for t in registry if t not in self._tool_to_bundle} + active_names.update(unbundled) + + logger.debug( + "ToolLoader resolved %d/%d tools (bundles: %s)", + len(active_names & set(registry)), + len(registry), + ", ".join(activated_bundles) or "none", + ) + + return {k: v for k, v in registry.items() if k in active_names} + + # ── tool execution hook ────────────────────────────────────────────── + + def record_tool_use(self, tool_name: str) -> None: + """Record that a tool was executed (called from ``_execute_tool``). + + This flips the owning bundle's ``activated`` flag so session-policy + bundles stay warm for the rest of the conversation. + """ + now = time.time() + self._tool_history.append((tool_name, now)) + bundle_name = self._tool_to_bundle.get(tool_name) + if bundle_name and bundle_name in self._state: + self._state[bundle_name].activated = True + self._state[bundle_name].last_used_ts = now + + # ── query helpers ──────────────────────────────────────────────────── + + def get_active_bundle_names(self) -> list[str]: + """Return names of currently activated bundles.""" + return [n for n, s in self._state.items() if s.activated] + + def get_bundle_for_tool(self, tool_name: str) -> Optional[str]: + """Return the bundle name that owns *tool_name*, or ``None``.""" + return self._tool_to_bundle.get(tool_name) + + def force_activate(self, bundle_name: str) -> None: + """Force-activate a bundle for the current session. + + This is a public API for callers that need to mark a bundle active + without reaching into ``_state`` directly. 
+ """ + if bundle_name not in self._state: + raise KeyError(f"Unknown bundle: {bundle_name}") + now = time.time() + self._state[bundle_name].activated = True + self._state[bundle_name].last_used_ts = now + + def reset_session(self) -> None: + """Clear per-session state (call between conversations).""" + for state in self._state.values(): + state.activated = False + state.last_used_ts = 0.0 + self._tool_history.clear() + + # ── internals ──────────────────────────────────────────────────────── + + def _keywords_match(self, keywords: FrozenSet[str], message: str) -> bool: + """Return True if any keyword pattern matches *message*.""" + for pattern in keywords: + try: + if re.search(pattern, message, re.IGNORECASE): + return True + except re.error: + # Treat bad regex as a plain substring match + if pattern.lower() in message.lower(): + return True + return False + + def _was_used_recently(self, bundle: ToolBundle) -> bool: + """Check if any tool in *bundle* was used within the warm window.""" + cutoff = time.time() - self.WARM_WINDOW_SECS + for tool_name, ts in reversed(self._tool_history): + if ts < cutoff: + break + if tool_name in bundle.tools: + return True + return False diff --git a/src/gaia/agents/chat/agent.py b/src/gaia/agents/chat/agent.py index a15725106..f71d49db4 100644 --- a/src/gaia/agents/chat/agent.py +++ b/src/gaia/agents/chat/agent.py @@ -1633,6 +1633,11 @@ def text_to_speech( except Exception as _mcp_err: logger.warning("MCP server load failed: %s", _mcp_err) + # Set up bundle-based tool loading (#688 Phase 1). + # All tools remain registered in _TOOL_REGISTRY; the loader decides + # which subset appears in the LLM prompt each turn. 
+ self._setup_tool_bundles() + # NOTE: The actual tool definitions are in the mixin classes: # - RAGToolsMixin (rag_tools.py): RAG and document indexing tools # - FileToolsMixin (file_tools.py): Directory monitoring @@ -1704,6 +1709,247 @@ def search_web(query: str) -> dict: f" search_web={'registered' if has_perplexity else 'skipped (no PERPLEXITY_API_KEY)'}" ) + def _setup_tool_bundles(self) -> None: + """Configure bundle-based tool loading for ChatAgent (#688 Phase 1). + + Groups the ~30+ registered tools into semantic bundles so the + ToolLoader can gate which ones appear in the LLM prompt each turn. + Tools that are not assigned to any bundle remain always-visible + (backwards compatible). + """ + from gaia.agents.base.tool_loader import ( + ActivationPolicy, + ToolBundle, + ToolLoader, + ) + from gaia.agents.base.tools import _TOOL_REGISTRY + + self.tool_loader = ToolLoader() + + # ── always-on: core tools the LLM needs every turn ────────────── + self.tool_loader.register_bundle( + ToolBundle( + name="core", + tools=frozenset( + { + "read_file", + "list_files", + "run_shell_command", + "get_system_info", + } + ), + policy=ActivationPolicy.ALWAYS, + ) + ) + + # ── RAG bundle: active when docs are indexed or user asks about docs ─ + self.tool_loader.register_bundle( + ToolBundle( + name="rag", + tools=frozenset( + { + "query_documents", + "query_specific_file", + "search_indexed_chunks", + "evaluate_retrieval", + "index_document", + "list_indexed_documents", + "rag_status", + "summarize_document", + "dump_document", + "index_directory", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"document|pdf|index|search.*file|rag|chunk|summarize|summary", + r"what does .+ say", + r"find.*in .+(report|handbook|manual|doc)", + } + ), + ) + ) + + # Pre-activate RAG bundle when documents are already indexed + if ( + hasattr(self, "rag") + and self.rag + and getattr(self.rag, "indexed_files", None) + ): + # Use public API to avoid touching internal 
_state + try: + self.tool_loader.force_activate("rag") + except Exception: + # Fallback to direct state update for very old loaders + if "rag" in getattr(self.tool_loader, "_state", {}): + self.tool_loader._state["rag"].activated = True + + # ── filesystem: file navigation and search ────────────────────── + self.tool_loader.register_bundle( + ToolBundle( + name="filesystem", + tools=frozenset( + { + "search_file", + "search_directory", + "search_file_content", + "browse_directory", + "get_file_info", + "list_recent_files", + "add_watch_directory", + "write_file", + "edit_file", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"\b(?:file|folder|directory|path|find|search|browse|tree|ls|dir)\b", + r"\.[a-z]{1,5}\b", # file extensions like .py, .json + r"\b(?:[A-Za-z]:\\|/)", # absolute or drive-letter paths + r"\b(?:write|edit|save|create).*file|\bmodify\b", + } + ), + ) + ) + + # ── browser: web fetching and URL tools ───────────────────────── + self.tool_loader.register_bundle( + ToolBundle( + name="browser", + tools=frozenset( + { + "open_url", + "fetch_webpage", + "search_web", + "search_documentation", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"\bhttps?://", + r"\b(?:url|website|webpage|web\s*page|browse|internet)\b", + r"\bsearch\s+(?:the\s+)?web\b|\bgoogle\b|\blook\s+up\b", + } + ), + ) + ) + + # ── data: data analysis and execution ─────────────────────────── + self.tool_loader.register_bundle( + ToolBundle( + name="data", + tools=frozenset( + { + "analyze_data_file", + "execute_python_file", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"csv|excel|xlsx|data|analy[sz]e|statistics|chart|plot", + r"run.*python|execute.*script|\.py\b", + } + ), + ) + ) + + # ── desktop: system interaction tools ─────────────────────────── + self.tool_loader.register_bundle( + ToolBundle( + name="desktop", + tools=frozenset( + { + "take_screenshot", + "list_windows", + "read_clipboard", + 
"write_clipboard", + "notify_desktop", + "text_to_speech", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"screenshot|screen|capture|window|clipboard|copy|paste", + r"notify|notification|alert|speak|say|tts|voice|read.*aloud", + } + ), + ) + ) + + # ── vlm: vision tools (only if registered) ───────────────────── + vlm_tools = {"analyze_image", "answer_question_about_image"} & set( + _TOOL_REGISTRY + ) + if vlm_tools: + self.tool_loader.register_bundle( + ToolBundle( + name="vlm", + tools=frozenset(vlm_tools), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"image|photo|picture|screenshot|jpg|png|gif|visual", + r"what.*see|describe.*image|look.*at", + } + ), + ) + ) + + # ── sd: image generation (only if registered) ─────────────────── + sd_tools = { + "generate_image", + "list_sd_models", + "get_generation_history", + } & set(_TOOL_REGISTRY) + if sd_tools: + self.tool_loader.register_bundle( + ToolBundle( + name="sd", + tools=frozenset(sd_tools), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset( + { + r"generate.*image|create.*image|draw|stable.?diffusion|sdxl", + r"art|illustration|render|picture.*of", + } + ), + ) + ) + + # ── mcp_*: one bundle per MCP server ──────────────────────────── + # MCP tools are prefixed as mcp__. Group by server. 
+ mcp_tools_by_server: Dict[str, set] = {} + for tool_name in _TOOL_REGISTRY: + if tool_name.startswith("mcp_"): + parts = tool_name.split("_", 2) # mcp, server, rest + if len(parts) >= 3: + server = parts[1] + mcp_tools_by_server.setdefault(server, set()).add(tool_name) + + for server, tools in mcp_tools_by_server.items(): + self.tool_loader.register_bundle( + ToolBundle( + name=f"mcp_{server}", + tools=frozenset(tools), + policy=ActivationPolicy.SESSION, + ) + ) + + # Log the bundle configuration + total = len(_TOOL_REGISTRY) + bundled = sum(len(b.tools) for b in self.tool_loader._bundles.values()) + logger.info( + "ToolLoader configured: %d bundles, %d/%d tools bundled", + len(self.tool_loader._bundles), + bundled, + total, + ) + def _index_documents(self, documents: List[str]) -> None: """Index initial documents.""" for doc in documents: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..43c1ccca5 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Make tests a package so `import tests.*` works in CI diff --git a/tests/fixtures/email/conftest.py b/tests/fixtures/email/conftest.py new file mode 100644 index 000000000..9f06d4f80 --- /dev/null +++ b/tests/fixtures/email/conftest.py @@ -0,0 +1,118 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +import mailbox +from pathlib import Path +from typing import Any, Callable + +import pytest + +from tests.fixtures.email.generate_mbox import generate + +EMAIL_FIXTURE_DIR = Path(__file__).resolve().parent +CHECKIN_MBOX = EMAIL_FIXTURE_DIR / "synthetic_inbox.mbox" +CHECKIN_GT = EMAIL_FIXTURE_DIR / "ground_truth.json" + + +def _ensure_generated_to(out_mbox: Path, out_gt: Path) -> None: + """Generate fixtures deterministically into the provided paths.""" + if out_mbox.exists() and out_gt.exists(): + return + generate(out_mbox, out_gt) + + +@pytest.fixture(scope="session") +def synthetic_mbox(tmp_path_factory) -> mailbox.mbox: + """Load the synthetic mbox fixture. + + If checked-in fixtures exist under the fixtures directory, use them. + Otherwise generate deterministic fixtures into a temporary directory so + test runs do not mutate the repository working tree. + """ + if CHECKIN_MBOX.exists() and CHECKIN_GT.exists(): + return mailbox.mbox(str(CHECKIN_MBOX), create=False) + + td = tmp_path_factory.mktemp("email_fixtures") + out_mbox = td / "synthetic_inbox.mbox" + out_gt = td / "ground_truth.json" + _ensure_generated_to(out_mbox, out_gt) + return mailbox.mbox(str(out_mbox), create=False) + + +@pytest.fixture(scope="session") +def email_ground_truth(tmp_path_factory) -> dict[str, dict[str, Any]]: + """Load Message-ID keyed ground truth metadata. + + Mirrors the generation strategy used by ``synthetic_mbox`` so the + mbox and ground-truth remain paired when generated into a temp dir. 
+ """ + if CHECKIN_MBOX.exists() and CHECKIN_GT.exists(): + return json.loads(CHECKIN_GT.read_text(encoding="utf-8")) + + td = tmp_path_factory.mktemp("email_fixtures") + out_mbox = td / "synthetic_inbox.mbox" + out_gt = td / "ground_truth.json" + _ensure_generated_to(out_mbox, out_gt) + return json.loads(out_gt.read_text(encoding="utf-8")) + + +@pytest.fixture(scope="session") +def _messages_with_ids( + synthetic_mbox: mailbox.mbox, +) -> list[tuple[str, Any]]: + rows: list[tuple[str, Any]] = [] + for msg in synthetic_mbox: + msg_id = msg.get("Message-ID") + if msg_id: + rows.append((msg_id, msg)) + return rows + + +@pytest.fixture() +def single_email( + _messages_with_ids: list[tuple[str, Any]], + email_ground_truth: dict[str, dict[str, Any]], +) -> Callable[[str], Any]: + """Return a callable that fetches one message by triage category.""" + + def _get(category: str) -> Any: + for msg_id, msg in _messages_with_ids: + meta = email_ground_truth.get(msg_id) + if meta and meta.get("category") == category: + return msg + raise KeyError(f"No email found for category: {category}") + + return _get + + +@pytest.fixture() +def spam_emails( + _messages_with_ids: list[tuple[str, Any]], + email_ground_truth: dict[str, dict[str, Any]], +) -> list[Any]: + """Return all spam/phishing emails for filter testing.""" + out = [] + for msg_id, msg in _messages_with_ids: + meta = email_ground_truth.get(msg_id) + if not meta: + continue + if meta.get("is_spam") or meta.get("is_phishing"): + out.append(msg) + return out + + +@pytest.fixture() +def ambiguous_emails( + _messages_with_ids: list[tuple[str, Any]], + email_ground_truth: dict[str, dict[str, Any]], +) -> list[Any]: + """Return ambiguous/borderline emails for decision-boundary tests.""" + out = [] + for msg_id, msg in _messages_with_ids: + meta = email_ground_truth.get(msg_id) + if meta and meta.get("ambiguous"): + out.append(msg) + return out diff --git a/tests/fixtures/email/generate_mbox.py 
b/tests/fixtures/email/generate_mbox.py new file mode 100644 index 000000000..b990c8559 --- /dev/null +++ b/tests/fixtures/email/generate_mbox.py @@ -0,0 +1,1114 @@ +#!/usr/bin/env python3 +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +"""Generate deterministic synthetic .mbox fixtures for email triage tests. + +This script creates: +- tests/fixtures/email/synthetic_inbox.mbox +- tests/fixtures/email/ground_truth.json + +The dataset is fully synthetic (RFC 2606 domains) and deterministic. +""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import mailbox +import random +import re +import tempfile +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from email import policy +from email.message import EmailMessage +from email.utils import format_datetime +from pathlib import Path +from typing import Any + +SEED = 23023 +TOTAL_MESSAGES = 220 + +OUT_DIR = Path(__file__).resolve().parent +OUT_MBOX = OUT_DIR / "synthetic_inbox.mbox" +OUT_GT = OUT_DIR / "ground_truth.json" + +CATEGORIES = ["urgent", "actionable", "informational", "low_priority"] + +TARGET_COUNTS = { + "urgent": 24, + "actionable": 51, + "informational": 66, + "low_priority": 37, + "spam": 20, + "ambiguous": 15, + "malformed": 7, +} + +if sum(TARGET_COUNTS.values()) != TOTAL_MESSAGES: + raise ValueError("Target counts do not sum to TOTAL_MESSAGES") + + +@dataclass(frozen=True) +class Persona: + key: str + display_name: str + address: str + role: str + + +PERSONAS = { + "sarah_chen": Persona( + "sarah_chen", + "Sarah Chen", + "sarah.chen@acme-corp.example.com", + "VP Engineering", + ), + "alex_kumar": Persona( + "alex_kumar", + "Alex Kumar", + "alex.kumar@acme-corp.example.com", + "Senior Engineer", + ), + "jordan_lee": Persona( + "jordan_lee", + "Jordan Lee", + "jordan.lee@acme-corp.example.com", + "Product Manager", + ), + "it_systems": Persona( + "it_systems", + "IT 
Systems", + "noreply@acme-corp.example.com", + "Automated", + ), + "hr_team": Persona( + "hr_team", + "HR Team", + "hr@acme-corp.example.com", + "Automated", + ), + "maria_santos": Persona( + "maria_santos", + "Maria Santos", + "maria.santos@globaltech.example.net", + "External partner", + ), + "devops_bot": Persona( + "devops_bot", + "DevOps Bot", + "alerts@acme-corp.example.com", + "CI/CD alerts", + ), + "newsletter_tech": Persona( + "newsletter_tech", + "Tech Insider Weekly", + "digest@tech-insider.example.com", + "Newsletter", + ), + "newsletter_market": Persona( + "newsletter_market", + "Market Pulse", + "news@marketpulse.example.com", + "Newsletter", + ), + "marketing_vendor": Persona( + "marketing_vendor", + "GrowthStack Solutions", + "hello@growthstack.example.com", + "Cold outreach", + ), +} + +PRIMARY_TO = "Taylor Morgan " +PRIMARY_CC = "Eng Leadership " + +CORP_DISCLAIMER = ( + "This email and any attachments are confidential and intended only for the " + "named recipient. If you received this in error, please notify the sender " + "and delete this message." +) + +RECEIVED_TEMPLATE = ( + "from smtp-{hop}.acme-corp.example.com (smtp-{hop}.acme-corp.example.com " + "[192.0.2.{octet}]) by mx-{next_hop}.acme-corp.example.com with ESMTPS id {rid}; " + "{date}" +) + + +class IdFactory: + def __init__(self, rng: random.Random) -> None: + self.rng = rng + self.counter = 0 + + def make(self, domain: str = "mail.acme-corp.example.com") -> str: + self.counter += 1 + left = f"msg{self.counter:04d}.{self.rng.randrange(100000, 999999)}" + return f"<{left}@{domain}>" + + +def _sha256(path: Path) -> str: + return hashlib.sha256(path.read_bytes()).hexdigest() + + +def _weekday_weighted_datetimes(rng: random.Random, count: int) -> list[datetime]: + base = datetime(2026, 3, 2, 8, 0, tzinfo=timezone.utc) # Monday + values: list[datetime] = [] + for i in range(count): + # Spread across two weeks with weekday clustering around 9 AM and 4-5 PM. 
+ day_offset = rng.randint(0, 13) + d = base + timedelta(days=day_offset) + if d.weekday() < 5: + slot = rng.choices([9, 16, 17, 11], weights=[45, 28, 20, 7], k=1)[0] + minute = rng.randint(0, 59) + else: + # Weekend batch (overnight triage scenario) + slot = rng.choices([1, 2, 3, 4], weights=[20, 35, 30, 15], k=1)[0] + minute = rng.randint(0, 59) + values.append(d.replace(hour=slot, minute=minute)) + # Deterministic but not strictly sorted; inbox can arrive somewhat shuffled. + rng.shuffle(values) + # Nudge a batch to Monday pre-work hours from weekend. + for i in range(min(12, len(values))): + if i % 3 == 0: + values[i] = values[i].replace(hour=5, minute=rng.randint(0, 59)) + return values + + +def _received_headers(dt: datetime, rng: random.Random) -> list[str]: + items = [] + for hop in range(1, 3): + items.append( + RECEIVED_TEMPLATE.format( + hop=hop, + next_hop=hop + 1, + octet=20 + hop, + rid=f"R{rng.randrange(10000, 99999)}", + date=format_datetime(dt - timedelta(minutes=hop * 2)), + ) + ) + return items + + +def _make_attachment( + name: str, + size: int, + mime: tuple[str, str], + rng: random.Random, +) -> tuple[str, bytes, str, str]: + body = bytes(rng.randrange(0, 255) for _ in range(size)) + maintype, subtype = mime + return name, body, maintype, subtype + + +def _base_message( + *, + persona: Persona, + subject: str, + body_text: str, + date_value: datetime, + message_id: str, + category: str, + sender_persona: str, + rng: random.Random, + to: str = PRIMARY_TO, + cc: str | None = None, + html_body: str | None = None, + x_priority: str | None = None, + importance: str | None = None, + x_mailer: str | None = None, + reply_to: str | None = None, + list_unsubscribe: str | None = None, + in_reply_to: str | None = None, + references: list[str] | None = None, + has_inline_image: bool = False, + attachments: list[tuple[str, bytes, str, str, str]] | None = None, + forwarded_raw: str | None = None, +) -> tuple[EmailMessage, dict[str, Any]]: + msg = 
EmailMessage(policy=policy.SMTP) + msg["From"] = f"{persona.display_name} <{persona.address}>" + msg["To"] = to + if cc: + msg["Cc"] = cc + msg["Subject"] = subject + msg["Date"] = format_datetime(date_value) + msg["Message-ID"] = message_id + if in_reply_to: + msg["In-Reply-To"] = in_reply_to + if references: + msg["References"] = " ".join(references) + if x_priority: + msg["X-Priority"] = x_priority + if importance: + msg["Importance"] = importance + if x_mailer: + msg["X-Mailer"] = x_mailer + if reply_to: + msg["Reply-To"] = reply_to + if list_unsubscribe: + msg["List-Unsubscribe"] = list_unsubscribe + + for rec in _received_headers(date_value, rng): + msg["Received"] = rec + + text = f"{body_text}\n\n{CORP_DISCLAIMER}\n" + if html_body: + html = ( + "" + f"
<p>{html_body}</p>
" + "
" + f"{CORP_DISCLAIMER}" + "" + ) + msg.set_content(text) + msg.add_alternative(html, subtype="html") + else: + msg.set_content(text) + + if has_inline_image: + if msg.get_content_maintype() != "multipart": + msg.make_mixed() + inline_bytes = bytes((i % 255 for i in range(1024))) + msg.add_attachment( + inline_bytes, + maintype="image", + subtype="png", + filename="logo-inline.png", + disposition="inline", + cid="", + ) + + if attachments: + for filename, payload, mt, st, disposition in attachments: + msg.add_attachment( + payload, + maintype=mt, + subtype=st, + filename=filename, + disposition=disposition, + ) + + if forwarded_raw: + if msg.get_content_maintype() != "multipart": + msg.make_mixed() + forwarded = EmailMessage(policy=policy.SMTP) + forwarded.set_content(forwarded_raw) + forwarded["Subject"] = "Fwd payload" + forwarded["From"] = "legacy.sender@example.org" + forwarded["To"] = PRIMARY_TO + forwarded["Date"] = format_datetime(date_value - timedelta(days=1)) + msg.attach(forwarded) + + meta = { + "category": category, + "priority": "high" if category == "urgent" else "normal", + "is_thread_root": in_reply_to is None, + "thread_id": ( + message_id + if in_reply_to is None + else (references[0] if references else in_reply_to) + ), + "has_attachment": bool(attachments or has_inline_image or forwarded_raw), + "is_spam": False, + "is_phishing": False, + "ambiguous": False, + "rationale": "", + "sender_persona": sender_persona, + } + return msg, meta + + +def _make_thread( + *, + root_subject: str, + persona_keys: list[str], + depth: int, + date_values: list[datetime], + id_factory: IdFactory, + rng: random.Random, + category: str, + ambiguous: bool = False, + rationale: str = "", +) -> list[tuple[EmailMessage, dict[str, Any], str]]: + messages: list[tuple[EmailMessage, dict[str, Any], str]] = [] + refs: list[str] = [] + root_id = id_factory.make() + for i in range(depth): + persona_key = persona_keys[i % len(persona_keys)] + persona = PERSONAS[persona_key] + 
message_id = root_id if i == 0 else id_factory.make() + refs_local = refs[:] if refs else None + in_reply = refs[-1] if refs else None + subject = root_subject if i == 0 else f"Re: {root_subject}" + text = ( + "Update " + f"{i + 1}/{depth}: please review the latest status and confirm owner.\n" + "Thanks,\n" + f"{persona.display_name}\n\n" + "> Previous thread context included below." + ) + msg, meta = _base_message( + persona=persona, + subject=subject, + body_text=text, + date_value=date_values[i], + message_id=message_id, + category=category, + sender_persona=persona_key, + rng=rng, + cc=PRIMARY_CC, + html_body=( + "

Top-posted update for thread participants.

" + "
> Prior message quoted content
" + ), + in_reply_to=in_reply, + references=refs_local, + x_mailer=rng.choice(["Microsoft Outlook 16.0", "Thunderbird", "Gmail"]), + ) + if ambiguous: + meta["ambiguous"] = True + meta["rationale"] = rationale + if i == 0: + meta["is_thread_root"] = True + meta["thread_id"] = message_id + else: + meta["is_thread_root"] = False + meta["thread_id"] = root_id + messages.append((msg, meta, message_id)) + refs.append(message_id) + return messages + + +def _mailbox_from_records( + records: list[tuple[EmailMessage, dict[str, Any], str]], + malformed_raw: list[tuple[str, dict[str, Any], str]], + out_mbox: Path, + out_gt: Path, +) -> None: + out_mbox.parent.mkdir(parents=True, exist_ok=True) + if out_mbox.exists(): + out_mbox.unlink() + + box = mailbox.mbox(str(out_mbox), create=True) + gt: dict[str, Any] = {} + + for msg, meta, message_id in records: + box.add(msg) + gt[message_id] = meta + + box.flush() + box.close() + + # Append malformed messages directly as raw mbox entries. Use a + # deterministic timestamp on the From_ line so generated mbox files are + # reproducible across runs. 
+ deterministic_from_time = datetime(2026, 3, 2, 0, 0, tzinfo=timezone.utc) + with out_mbox.open("ab") as f: + for raw_msg, meta, message_id in malformed_raw: + from_line = ( + "From malformed@example.com " f"{deterministic_from_time.ctime()}\n" + ) + payload = raw_msg.replace("\n", "\n").encode("utf-8", errors="replace") + if not payload.endswith(b"\n"): + payload += b"\n" + f.write(from_line.encode("utf-8")) + f.write(payload) + f.write(b"\n") + gt[message_id] = meta + + out_gt.write_text(json.dumps(gt, indent=2, sort_keys=True), encoding="utf-8") + + +def _build_dataset( + seed: int = SEED, +) -> tuple[ + list[tuple[EmailMessage, dict[str, Any], str]], + list[tuple[str, dict[str, Any], str]], +]: + rng = random.Random(seed) + id_factory = IdFactory(rng) + date_pool = _weekday_weighted_datetimes(rng, TOTAL_MESSAGES) + date_idx = 0 + + def next_date() -> datetime: + nonlocal date_idx + d = date_pool[date_idx] + date_idx += 1 + return d + + records: list[tuple[EmailMessage, dict[str, Any], str]] = [] + malformed: list[tuple[str, dict[str, Any], str]] = [] + + # Ensure recurring personas appear between 3-8 times. + persona_usage = {k: 0 for k in PERSONAS} + + def pick_persona(keys: list[str]) -> str: + key = rng.choice(keys) + persona_usage[key] += 1 + return key + + # Build threaded corp conversations (3-5 deep). 
+ thread_specs = [ + ( + "Prod incident follow-up", + ["sarah_chen", "alex_kumar", "devops_bot"], + 5, + "urgent", + ), + ( + "Q2 contract redlines", + ["maria_santos", "sarah_chen", "jordan_lee"], + 4, + "urgent", + ), + ( + "Roadmap dependency sync", + ["jordan_lee", "alex_kumar", "sarah_chen"], + 4, + "actionable", + ), + ( + "Security advisory triage", + ["it_systems", "alex_kumar", "devops_bot"], + 3, + "urgent", + ), + ( + "Audit evidence request", + ["hr_team", "jordan_lee", "alex_kumar"], + 3, + "actionable", + ), + ] + for subject, pkeys, depth, cat in thread_specs: + for msg, meta, mid in _make_thread( + root_subject=subject, + persona_keys=pkeys, + depth=depth, + date_values=[next_date() for _ in range(depth)], + id_factory=id_factory, + rng=rng, + category=cat, + ): + persona_usage[meta["sender_persona"]] += 1 + records.append((msg, meta, mid)) + + # Utility to add single message. + def add_single( + *, + persona_key: str, + subject: str, + body: str, + category: str, + spam: bool = False, + phishing: bool = False, + ambiguous: bool = False, + rationale: str = "", + html: bool = False, + x_priority: str | None = None, + importance: str | None = None, + list_unsub: bool = False, + attachments: list[tuple[str, bytes, str, str, str]] | None = None, + inline_image: bool = False, + reply_to: str | None = None, + x_mailer: str | None = None, + forwarded: bool = False, + to: str = PRIMARY_TO, + cc: str | None = None, + ) -> None: + persona_usage[persona_key] += 1 + message_id = id_factory.make() + msg, meta = _base_message( + persona=PERSONAS[persona_key], + subject=subject, + body_text=body, + date_value=next_date(), + message_id=message_id, + category=category, + sender_persona=persona_key, + rng=rng, + cc=cc, + to=to, + html_body=( + f'

{body}

logo

' + if html + else None + ), + x_priority=x_priority, + importance=importance, + x_mailer=x_mailer, + reply_to=reply_to, + list_unsubscribe=( + ", " + if list_unsub + else None + ), + has_inline_image=inline_image, + attachments=attachments, + forwarded_raw=( + "-----Forwarded Message-----\nLegacy Outlook payload" + if forwarded + else None + ), + ) + meta["is_spam"] = spam + meta["is_phishing"] = phishing + meta["ambiguous"] = ambiguous + meta["rationale"] = rationale + if spam: + meta["priority"] = "low" + records.append((msg, meta, message_id)) + + # Create content pools by category. + corp_templates = { + "urgent": [ + "[SEV1] API latency above SLA - owner needed", + "Client deadline: contract signature required by EOD", + "Security advisory: rotate credentials within 4 hours", + "Prod incident report requires exec review", + "Compliance acknowledgment due by EOD", + ], + "actionable": [ + "Please review PR #4821 by tomorrow", + "Can you approve expense report TR-2288?", + "Need your decision on vendor shortlist", + "Meeting invite: launch readiness review", + "JIRA ticket assigned: GAIA-2024", + ], + "informational": [ + "VPN maintenance window this Saturday", + "Benefits enrollment reminder", + "All-hands recap and recording", + "Confluence page updated: onboarding checklist", + "Shipping confirmation for office equipment", + "Quarterly financial digest", + ], + "low_priority": [ + "Try our premium analytics package", + "Top 10 growth hacks for your startup", + "You were mentioned in a social thread", + "Special discount expires tonight", + "Weekly promo digest", + ], + } + + attachment_specs = [ + ("report.pdf", ("application", "pdf")), + ("costs.csv", ("text", "csv")), + ("diagram.png", ("image", "png")), + ( + "brief.docx", + ( + "application", + "vnd.openxmlformats-officedocument.wordprocessingml.document", + ), + ), + ("invite.ics", ("text", "calendar")), + ] + + # Generate regular corp/personal mail such that total stays exactly 220. 
+ # Breakdown: + # - threaded corp messages: len(records) (already added) + # - regular messages: 129 + # - personal/consumer messages: 30 + # - spam/phishing: 20 + # - ambiguous: 15 + # - malformed: 7 + regular_count = 129 + + corp_personas = [ + "sarah_chen", + "alex_kumar", + "jordan_lee", + "it_systems", + "hr_team", + "maria_santos", + "devops_bot", + "newsletter_tech", + "newsletter_market", + "marketing_vendor", + ] + + regular_category_weights = [ + ("urgent", 20), + ("actionable", 34), + ("informational", 53), + ("low_priority", 22), + ] + + for i in range(regular_count): + cat = rng.choices( + [c for c, _ in regular_category_weights], + weights=[w for _, w in regular_category_weights], + k=1, + )[0] + persona_key = pick_persona(corp_personas) + template = rng.choice(corp_templates[cat]) + topic = rng.choice( + [ + "IT maintenance", + "expense approval", + "meeting update", + "policy notice", + "build status", + "client follow-up", + "receipt", + "newsletter", + ] + ) + subject = f"{template} - {topic}" + body = ( + f"Hello Taylor,\n\n{template}." + " Please review details and reply if needed." 
+ ) + attachment_template = rng.choice(attachment_specs) + add_single( + persona_key=persona_key, + subject=subject, + body=body, + category=cat, + html=rng.random() < 0.4, + x_priority=( + "1 (Highest)" if cat == "urgent" and rng.random() < 0.6 else None + ), + importance=("High" if cat == "urgent" and rng.random() < 0.6 else None), + list_unsub=( + cat in {"informational", "low_priority"} and rng.random() < 0.5 + ), + attachments=( + [ + ( + name, + payload, + mt, + st, + "attachment", + ) + for name, payload, mt, st in [ + _make_attachment( + name=attachment_template[0], + size=rng.randint(1024, 4096), + mime=attachment_template[1], + rng=rng, + ) + ] + ] + if rng.random() < 0.25 + else None + ), + inline_image=rng.random() < 0.08, + forwarded=rng.random() < 0.06, + x_mailer=rng.choice( + [ + "Microsoft Outlook 16.0", + "Gmail", + "Thunderbird", + "Jira Mailer", + "PagerDuty", + ] + ), + cc=(PRIMARY_CC if rng.random() < 0.35 else None), + ) + + # Personal / consumer messages (~30) blended into informational/ + # low-priority/actionable buckets. 
+ personal_subjects = [ + ("Your Amazon order has shipped", "informational"), + ("Flight confirmation ACM-8722", "actionable"), + ("Bank alert: transaction posted", "informational"), + ("LinkedIn: 3 new profile views", "low_priority"), + ("Hotel booking confirmation", "informational"), + ("GitHub: 12 new stars this week", "informational"), + ("Calendar reminder: dentist appointment", "actionable"), + ("Retail receipt for your purchase", "informational"), + ] + for i in range(30): + subj, cat = personal_subjects[i % len(personal_subjects)] + add_single( + persona_key=rng.choice( + ["newsletter_tech", "newsletter_market", "marketing_vendor"] + ), + subject=f"{subj} [{i + 1}]", + body="Automated consumer notification for synthetic triage dataset.", + category=cat, + list_unsub=(cat in {"informational", "low_priority"}), + html=rng.random() < 0.55, + x_mailer=rng.choice(["Gmail", "Outlook.com", "MailChimp"]), + ) + + # Spam / phishing (pre-triage filter). + spam_templates = [ + ( + "URGENT: verify account password immediately", + True, + True, + "IT Support ", + "verify your account to avoid deactivation", + "security-team@acme-corp.example.com", + "invoice-update.pdf.exe", + ), + ( + "Wire transfer needed now - from CEO", + True, + True, + "Sarah Chen ", + "send funds before 30 minutes", + "payments@acme-corp.example.com", + "wire_instructions.doc", + ), + ( + "Package held at customs - click to release", + True, + False, + "FedEx Notice ", + "tracking update requires payment", + None, + "tracking_label.zip", + ), + ( + "You won the crypto lottery", + True, + False, + "Claims Desk ", + "confirm wallet and private key", + None, + "claim_form.pdf", + ), + ] + for i in range(TARGET_COUNTS["spam"]): + ( + subj, + spam_flag, + phishing_flag, + from_display, + body_hint, + reply_to, + bad_attachment, + ) = spam_templates[i % len(spam_templates)] + message_id = id_factory.make("mail.suspicious.example.org") + dt = next_date() + msg = EmailMessage(policy=policy.SMTP) + 
msg["From"] = from_display + msg["To"] = PRIMARY_TO + msg["Subject"] = f"{subj} #{i + 1}" + msg["Date"] = format_datetime(dt) + msg["Message-ID"] = message_id + msg["X-Mailer"] = rng.choice(["PHPMailer", "Roundcube", "Unknown MTA"]) + if reply_to: + msg["Reply-To"] = reply_to + msg.set_content( + f"Dear user, {body_hint}. immediate action needed!!! please do not delay." + ) + att_name, payload, mt, st = _make_attachment( + bad_attachment, + size=rng.randint(1024, 3072), + mime=("application", "octet-stream"), + rng=rng, + ) + msg.add_attachment( + payload, + maintype=mt, + subtype=st, + filename=att_name, + disposition="attachment", + ) + meta = { + "category": rng.choice(["informational", "low_priority"]), + "priority": "low", + "is_thread_root": True, + "thread_id": message_id, + "has_attachment": True, + "is_spam": spam_flag, + "is_phishing": phishing_flag, + "ambiguous": False, + "rationale": "", + "sender_persona": "spam_unknown", + } + records.append((msg, meta, message_id)) + + # Ambiguous / borderline. 
+ ambiguous_pool = [ + ( + "Meeting invite from unknown external contact", + "actionable", + "Invite could be relevant partnership kickoff but sender is unknown.", + ), + ( + "Vendor invoice with no prior relationship", + "informational", + "Could be fraud; triage as informational pending verification.", + ), + ( + "Automated JIRA ticket for unfamiliar project", + "informational", + "No ownership signal; likely informational for this user.", + ), + ( + "Newsletter from tool you actively use", + "informational", + "Product updates can impact active workflows.", + ), + ( + "Quick question from unknown sender", + "low_priority", + "No context and no direct urgency indicators.", + ), + ( + "Compliance notice requiring acknowledgement by EOD", + "urgent", + "Hard deadline suggests urgency, though action is lightweight.", + ), + ( + "Reply-all where user only CC'd", + "informational", + "Likely FYI unless directly asked a question.", + ), + ] + for i in range(TARGET_COUNTS["ambiguous"]): + subject, cat, rationale = ambiguous_pool[i % len(ambiguous_pool)] + add_single( + persona_key=rng.choice( + ["jordan_lee", "maria_santos", "hr_team", "newsletter_tech"] + ), + subject=f"[Ambiguous] {subject} ({i + 1})", + body="Boundary-case message intentionally designed for triage calibration.", + category=cat, + ambiguous=True, + rationale=rationale, + html=rng.random() < 0.3, + x_mailer=rng.choice(["Microsoft Outlook 16.0", "Gmail", "Zendesk Mailer"]), + ) + + # Malformed / parser edge cases (raw entries). 
+ malformed_specs = [ + ( + "missing-subject", + "From: IT Systems \n" + f"To: {PRIMARY_TO}\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "\n" + "Body with missing subject header.\n", + {"category": "informational", "sender_persona": "it_systems"}, + ), + ( + "empty-body", + "From: HR Team \n" + f"To: {PRIMARY_TO}\n" + "Subject: Empty body test\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "\n", + {"category": "informational", "sender_persona": "hr_team"}, + ), + ( + "truncated-multipart", + "From: Jordan Lee \n" + f"To: {PRIMARY_TO}\n" + "Subject: Truncated MIME boundary\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "MIME-Version: 1.0\n" + "Content-Type: multipart/alternative; boundary=XYZBOUND\n" + "\n" + "--XYZBOUND\n" + "Content-Type: text/plain; charset=utf-8\n" + "\n" + "Part one only, boundary never closes properly.\n", + {"category": "actionable", "sender_persona": "jordan_lee"}, + ), + ( + "invalid-date", + "From: DevOps Bot \n" + f"To: {PRIMARY_TO}\n" + "Subject: Invalid Date header example\n" + "Date: not-a-real-date\n" + f"Message-ID: {id_factory.make()}\n" + "\n" + "Invalid date parser edge case.\n", + {"category": "urgent", "sender_persona": "devops_bot"}, + ), + ( + "double-encoded-subject", + "From: Alex Kumar \n" + f"To: {PRIMARY_TO}\n" + "Subject: =?UTF-8?B?PT89VVRGLTg/Qj9VbWx6WldRZ1VIVnlaU0U9Pz0=?=\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "\n" + "Subject intentionally odd for decoder robustness.\n", + {"category": "informational", "sender_persona": "alex_kumar"}, + ), + ( + "bad-base64", + "From: Maria Santos \n" + f"To: {PRIMARY_TO}\n" + "Subject: Base64 body padding issue\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "Content-Transfer-Encoding: base64\n" + "\n" + "SGVsbG8gd29ybGQh====\n", + {"category": "actionable", 
"sender_persona": "maria_santos"}, + ), + ( + "no-from-long-subject", + f"To: {PRIMARY_TO}\n" + "Subject: " + ("LongSubject-" * 45) + "\n" + f"Date: {format_datetime(next_date())}\n" + f"Message-ID: {id_factory.make()}\n" + "\n" + "No From header and very long subject.\n", + {"category": "low_priority", "sender_persona": "unknown"}, + ), + ] + + for _, raw, extras in malformed_specs: + mid_match = re.search(r"^Message-ID:\s*(<[^>]+>)", raw, re.MULTILINE) + if not mid_match: + raise ValueError("Malformed raw message missing Message-ID") + message_id = mid_match.group(1) + meta = { + "category": extras["category"], + "priority": "normal", + "is_thread_root": True, + "thread_id": message_id, + "has_attachment": False, + "is_spam": False, + "is_phishing": False, + "ambiguous": False, + "rationale": "Malformed parser edge-case message", + "sender_persona": extras["sender_persona"], + } + malformed.append((raw, meta, message_id)) + + # Enforce total count and postconditions. + if len(records) + len(malformed) != TOTAL_MESSAGES: + raise ValueError( + "Generated " + f"{len(records) + len(malformed)} messages, expected {TOTAL_MESSAGES}" + ) + + # Ensure UTF-8 and ISO-8859-1 non-ASCII coverage by replacing one record in-place. + message_id = id_factory.make() + non_ascii_msg, meta = _base_message( + persona=PERSONAS["jordan_lee"], + subject="Résumé update – équipe européenne", + body_text="Voici la mise à jour du résumé du projet pour l'équipe européenne.", + date_value=date_pool[-1], + message_id=message_id, + category="informational", + sender_persona="jordan_lee", + rng=rng, + html_body="

Olá, atualização com caracteres acentuados.

", + ) + # Avoid calling set_charset on multipart messages (Message.get_payload() + # may be a list for multipart), instead set charset on text parts. + if non_ascii_msg.is_multipart(): + for part in non_ascii_msg.walk(): + if part.get_content_maintype() == "text": + try: + part.set_charset("iso-8859-1") + except Exception: + # Be defensive: on some stdlib versions this may raise. + pass + else: + non_ascii_msg.set_charset("iso-8859-1") + # Insert the non-ASCII message by replacing the first informational + # record, keeping the total message counts stable and preserving + # category distribution. + replaced = False + for i, (_msg, _meta, _mid) in enumerate(records): + if _meta.get("category") == "informational": + records[i] = (non_ascii_msg, meta, message_id) + replaced = True + break + if not replaced: + # Fallback: append if no informational slot found (shouldn't happen) + records.append((non_ascii_msg, meta, message_id)) + + # Validate persona repeat ranges for core personas. + core_personas = [ + "sarah_chen", + "alex_kumar", + "jordan_lee", + "it_systems", + "hr_team", + "maria_santos", + "devops_bot", + "newsletter_tech", + "newsletter_market", + ] + for key in core_personas: + if persona_usage.get(key, 0) < 3: + raise ValueError(f"Persona {key} appears fewer than 3 times") + + return records, malformed + + +def generate( + out_mbox: Path = OUT_MBOX, + out_gt: Path = OUT_GT, + seed: int = SEED, +) -> tuple[str, str]: + records, malformed = _build_dataset(seed) + _mailbox_from_records(records, malformed, out_mbox, out_gt) + + mbox_size = out_mbox.stat().st_size + if mbox_size >= 1024 * 1024: + raise ValueError(f"Generated mbox exceeds 1 MB ({mbox_size} bytes)") + + mbox_hash = _sha256(out_mbox) + gt_hash = _sha256(out_gt) + return mbox_hash, gt_hash + + +def verify() -> int: + if not OUT_MBOX.exists() or not OUT_GT.exists(): + print("Missing pre-built fixtures; run without --verify first.") + return 1 + + existing_mbox_hash = _sha256(OUT_MBOX) + 
existing_gt_hash = _sha256(OUT_GT) + + with tempfile.TemporaryDirectory() as td: + temp_mbox = Path(td) / "synthetic_inbox.mbox" + temp_gt = Path(td) / "ground_truth.json" + gen_mbox_hash, gen_gt_hash = generate(temp_mbox, temp_gt, seed=SEED) + + ok = existing_mbox_hash == gen_mbox_hash and existing_gt_hash == gen_gt_hash + print(f"existing mbox sha256: {existing_mbox_hash}") + print(f"generated mbox sha256: {gen_mbox_hash}") + print(f"existing gt sha256: {existing_gt_hash}") + print(f"generated gt sha256: {gen_gt_hash}") + + if ok: + print("VERIFY OK: checked-in fixtures match deterministic generator output") + return 0 + + print("VERIFY FAILED: checked-in fixtures differ from generated output") + return 2 + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--verify", action="store_true", help="Verify checked-in fixtures" + ) + parser.add_argument("--seed", type=int, default=SEED, help="Random seed") + args = parser.parse_args() + + if args.verify: + return verify() + + mbox_hash, gt_hash = generate(seed=args.seed) + print(f"Wrote: {OUT_MBOX} ({OUT_MBOX.stat().st_size} bytes)") + print(f"Wrote: {OUT_GT} ({OUT_GT.stat().st_size} bytes)") + print(f"mbox sha256: {mbox_hash}") + print(f"gt sha256: {gt_hash}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/unit/test_synthetic_mbox.py b/tests/unit/test_synthetic_mbox.py new file mode 100644 index 000000000..f3aa28b29 --- /dev/null +++ b/tests/unit/test_synthetic_mbox.py @@ -0,0 +1,198 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +import mailbox +import tempfile +from collections import Counter +from pathlib import Path + +import pytest + +from tests.fixtures.email.generate_mbox import SEED, TARGET_COUNTS, generate + +EMAIL_DIR = Path("tests/fixtures/email") +GEN_SCRIPT = EMAIL_DIR / "generate_mbox.py" + + +def _ensure_generated_to(out_mbox: Path, out_gt: Path) -> None: + if out_mbox.exists() and out_gt.exists(): + return + generate(out_mbox, out_gt) + + +@pytest.fixture(scope="module") +def mbox_obj(tmp_path_factory) -> mailbox.mbox: + td = tmp_path_factory.mktemp("email_fixtures_unit") + out_mbox = td / "synthetic_inbox.mbox" + out_gt = td / "ground_truth.json" + _ensure_generated_to(out_mbox, out_gt) + return mailbox.mbox(str(out_mbox), create=False) + + +@pytest.fixture(scope="module") +def gt(tmp_path_factory) -> dict: + td = tmp_path_factory.mktemp("email_fixtures_unit") + out_mbox = td / "synthetic_inbox.mbox" + out_gt = td / "ground_truth.json" + _ensure_generated_to(out_mbox, out_gt) + return json.loads(out_gt.read_text(encoding="utf-8")) + + +def test_fixture_files_exist(tmp_path_factory) -> None: + # Generator script should exist and be runnable; we don't require + # checked-in artifacts — generator may produce fixtures into a tmp dir. 
+ assert GEN_SCRIPT.exists() + td = tmp_path_factory.mktemp("email_fixtures_exist") + out_mbox = td / "synthetic_inbox.mbox" + out_gt = td / "ground_truth.json" + _ensure_generated_to(out_mbox, out_gt) + assert out_mbox.exists() + assert out_gt.exists() + + +def test_mbox_under_1mb() -> None: + # Generate into a temp location and assert output size is under 1 MB + with tempfile.TemporaryDirectory() as td: + out_mbox = Path(td) / "synthetic_inbox.mbox" + out_gt = Path(td) / "ground_truth.json" + mbox_hash, gt_hash = generate(out_mbox, out_gt, seed=SEED) + assert out_mbox.stat().st_size < 1024 * 1024 + + +def test_message_count_matches_target(mbox_obj: mailbox.mbox, gt: dict) -> None: + msg_ids = [msg.get("Message-ID") for msg in mbox_obj if msg.get("Message-ID")] + assert len(msg_ids) == sum(TARGET_COUNTS.values()) + assert len(gt) == sum(TARGET_COUNTS.values()) + + +def test_category_coverage_and_counts(gt: dict) -> None: + category_counts = Counter(meta["category"] for meta in gt.values()) + assert category_counts["urgent"] >= 20 + assert category_counts["actionable"] >= 45 + assert category_counts["informational"] >= 55 + assert category_counts["low_priority"] >= 30 + + spam_count = sum(1 for meta in gt.values() if meta["is_spam"]) + phishing_count = sum(1 for meta in gt.values() if meta["is_phishing"]) + assert spam_count >= 20 + assert phishing_count >= 8 + + +def test_ambiguous_messages_have_rationale(gt: dict) -> None: + ambiguous = [meta for meta in gt.values() if meta["ambiguous"]] + assert len(ambiguous) >= 15 + assert all(a["rationale"].strip() for a in ambiguous) + + +def test_malformed_edge_cases_present(mbox_obj: mailbox.mbox) -> None: + msgs = list(mbox_obj) + missing_subject = any(msg.get("Subject") is None for msg in msgs) + missing_from = any(msg.get("From") is None for msg in msgs) + invalid_date = any(msg.get("Date") == "not-a-real-date" for msg in msgs) + long_subject = any( + (msg.get("Subject") or "").count("LongSubject-") > 30 for msg in msgs 
+ ) + assert missing_subject + assert missing_from + assert invalid_date + assert long_subject + + +def test_threading_headers_exist(mbox_obj: mailbox.mbox) -> None: + messages = list(mbox_obj) + with_reply = [m for m in messages if m.get("In-Reply-To")] + with_refs = [m for m in messages if m.get("References")] + assert len(with_reply) >= 10 + assert len(with_refs) >= 10 + + +def test_attachments_and_multipart_coverage(mbox_obj: mailbox.mbox) -> None: + msgs = list(mbox_obj) + multipart_count = sum(1 for m in msgs if m.is_multipart()) + attachment_count = 0 + inline_count = 0 + for msg in msgs: + for part in msg.walk(): + disp = (part.get_content_disposition() or "").lower() + if disp == "attachment": + attachment_count += 1 + if disp == "inline": + inline_count += 1 + assert multipart_count >= 40 + assert attachment_count >= 20 + assert inline_count >= 4 + + +def test_persona_recurrence_range(gt: dict) -> None: + counts = Counter(meta["sender_persona"] for meta in gt.values()) + recurring = [ + "sarah_chen", + "alex_kumar", + "jordan_lee", + "it_systems", + "hr_team", + "maria_santos", + "devops_bot", + "newsletter_tech", + "newsletter_market", + ] + for key in recurring: + # Allows realistic frequency while still guaranteeing recurrence. + assert 3 <= counts[key] <= 80 + + +def test_ground_truth_required_fields(gt: dict) -> None: + required = { + "category", + "priority", + "is_thread_root", + "thread_id", + "has_attachment", + "is_spam", + "is_phishing", + "ambiguous", + "rationale", + "sender_persona", + } + for message_id, meta in gt.items(): + assert message_id.startswith("<") and message_id.endswith(">") + assert required.issubset(meta.keys()) + + +def test_generator_determinism_verify_mode() -> None: + # Generate twice into independent temp dirs and compare deterministic + # ground-truth JSON hashes and message-id sets. Raw mbox binary can + # differ due to stdlib multipart boundary generation, so avoid binary + # equality checks. 
+ with tempfile.TemporaryDirectory() as td: + a_mbox = Path(td) / "a.mbox" + a_gt = Path(td) / "a_gt.json" + b_mbox = Path(td) / "b.mbox" + b_gt = Path(td) / "b_gt.json" + h1_mbox, h1_gt = generate(a_mbox, a_gt, seed=SEED) + h2_mbox, h2_gt = generate(b_mbox, b_gt, seed=SEED) + # Ground-truth JSON should be identical + assert h1_gt == h2_gt + # And the set of Message-IDs in each mbox should match + import mailbox as _mb + + ids1 = { + m.get("Message-ID") for m in _mb.mbox(str(a_mbox)) if m.get("Message-ID") + } + ids2 = { + m.get("Message-ID") for m in _mb.mbox(str(b_mbox)) if m.get("Message-ID") + } + assert ids1 == ids2 + + +def test_generator_accepts_seed() -> None: + # Ensure generator runs with an explicit seed and produces outputs + with tempfile.TemporaryDirectory() as td: + out_mbox = Path(td) / "synthetic_inbox.mbox" + out_gt = Path(td) / "ground_truth.json" + mbox_hash, gt_hash = generate(out_mbox, out_gt, seed=SEED) + assert out_mbox.exists() and out_gt.exists() + assert len(mbox_hash) == 64 and len(gt_hash) == 64 diff --git a/tests/unit/test_tool_loader.py b/tests/unit/test_tool_loader.py new file mode 100644 index 000000000..6c3e4298d --- /dev/null +++ b/tests/unit/test_tool_loader.py @@ -0,0 +1,474 @@ +# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +""" +Tests for the ToolLoader — bundle-based tool visibility (#688 Phase 1). 
+ +Covers: +- Bundle registration and resolution +- Activation policies (always, session, keyword) +- Per-turn tool filtering +- Tool-use recording and warm-window behaviour +- Regression: scratchpad.query_data vs memory.recall disambiguation +""" + +import time + +import pytest + +from gaia.agents.base.tool_loader import ( + ActivationPolicy, + ToolBundle, + ToolLoader, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_registry(*tool_names: str) -> dict: + """Build a minimal fake _TOOL_REGISTRY for testing.""" + return { + name: { + "name": name, + "description": f"Tool {name}", + "parameters": {}, + "function": lambda: None, + "atomic": False, + } + for name in tool_names + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def loader(): + return ToolLoader() + + +@pytest.fixture() +def full_registry(): + """Simulates the post-#495/#606 ChatAgent registry (~27 tools).""" + return _make_registry( + # core + "read_file", + "list_files", + "run_shell_command", + # rag + "query_documents", + "query_specific_file", + "index_document", + "list_indexed_documents", + "search_indexed_chunks", + "index_directory", + # filesystem (from PR #495) + "browse_directory", + "tree", + "file_info", + "find_files", + "bookmark", + # scratchpad (from PR #495) + "create_table", + "insert_data", + "query_data", + "list_tables", + "drop_table", + # browser (from PR #495) + "fetch_page", + "search_web", + "download_file", + # memory (from PR #606) + "remember", + "recall", + "update_memory", + "forget", + "search_past_conversations", + ) + + +CORE_BUNDLE = ToolBundle( + name="core", + tools=frozenset({"read_file", "list_files", "run_shell_command"}), + policy=ActivationPolicy.ALWAYS, +) + +RAG_BUNDLE = 
ToolBundle( + name="rag", + tools=frozenset( + { + "query_documents", + "query_specific_file", + "index_document", + "list_indexed_documents", + "search_indexed_chunks", + "index_directory", + } + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset({r"document|pdf|index|rag|summarize"}), +) + +FILESYSTEM_BUNDLE = ToolBundle( + name="filesystem", + tools=frozenset( + {"browse_directory", "tree", "file_info", "find_files", "bookmark"} + ), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset({r"file|folder|directory|path|browse|tree"}), +) + +SCRATCHPAD_BUNDLE = ToolBundle( + name="scratchpad", + tools=frozenset( + {"create_table", "insert_data", "query_data", "list_tables", "drop_table"} + ), + policy=ActivationPolicy.SESSION, + keywords=frozenset({r"table|spreadsheet|csv|data.*entry|scratchpad"}), +) + +BROWSER_BUNDLE = ToolBundle( + name="browser", + tools=frozenset({"fetch_page", "search_web", "download_file"}), + policy=ActivationPolicy.KEYWORD, + keywords=frozenset({r"https?://|url|website|web|search.*online"}), +) + +MEMORY_BUNDLE = ToolBundle( + name="memory", + tools=frozenset( + { + "remember", + "recall", + "update_memory", + "forget", + "search_past_conversations", + } + ), + policy=ActivationPolicy.SESSION, + keywords=frozenset( + {r"remember|recall|forgot|memory|learned|last.*(week|time|session)"} + ), +) + +ALL_BUNDLES = [ + CORE_BUNDLE, + RAG_BUNDLE, + FILESYSTEM_BUNDLE, + SCRATCHPAD_BUNDLE, + BROWSER_BUNDLE, + MEMORY_BUNDLE, +] + + +# --------------------------------------------------------------------------- +# Tests — bundle registration +# --------------------------------------------------------------------------- + + +class TestBundleRegistration: + def test_register_single_bundle(self, loader): + loader.register_bundle(CORE_BUNDLE) + assert "core" in loader._bundles + + def test_register_multiple_bundles(self, loader): + loader.register_bundles(ALL_BUNDLES) + assert len(loader._bundles) == len(ALL_BUNDLES) + + def 
test_tool_to_bundle_index(self, loader): + loader.register_bundles(ALL_BUNDLES) + assert loader.get_bundle_for_tool("query_data") == "scratchpad" + assert loader.get_bundle_for_tool("recall") == "memory" + assert loader.get_bundle_for_tool("read_file") == "core" + + def test_overwrite_bundle(self, loader): + loader.register_bundle(CORE_BUNDLE) + new_core = ToolBundle( + name="core", + tools=frozenset({"read_file"}), + policy=ActivationPolicy.ALWAYS, + ) + loader.register_bundle(new_core) + assert loader._bundles["core"].tools == frozenset({"read_file"}) + + +# --------------------------------------------------------------------------- +# Tests — always-on policy +# --------------------------------------------------------------------------- + + +class TestAlwaysPolicy: + def test_core_always_visible(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("Hi there!", full_registry) + assert "read_file" in result + assert "list_files" in result + assert "run_shell_command" in result + + def test_core_visible_regardless_of_message(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + for msg in ["", "hello", "what is 2+2?", "tell me a joke"]: + result = loader.resolve(msg, full_registry) + for tool in CORE_BUNDLE.tools: + assert tool in result, f"Core tool {tool} missing for message: {msg!r}" + + +# --------------------------------------------------------------------------- +# Tests — keyword activation +# --------------------------------------------------------------------------- + + +class TestKeywordActivation: + def test_rag_activated_by_document_keyword(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("Summarize the document", full_registry) + assert "query_documents" in result + assert "index_document" in result + + def test_filesystem_activated_by_path(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("Show me files 
in /home/user", full_registry) + assert "browse_directory" in result + assert "find_files" in result + + def test_browser_activated_by_url(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("Fetch https://example.com/data", full_registry) + assert "fetch_page" in result + assert "search_web" in result + + def test_unrelated_message_hides_keyword_bundles(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("What is the capital of France?", full_registry) + # Keyword bundles should NOT be active + assert "query_documents" not in result + assert "browse_directory" not in result + assert "fetch_page" not in result + + def test_rag_stays_warm_after_first_keyword_activation(self, loader, full_registry): + """Keyword bundle that was activated should stay warm on subsequent turns.""" + loader.register_bundles(ALL_BUNDLES) + # First turn: activates RAG via keyword + loader.resolve("Summarize the document", full_registry) + # Simulate tool use + loader.record_tool_use("query_documents") + # Second turn: no keyword, but bundle should be warm + result = loader.resolve("What else does it say?", full_registry) + assert "query_documents" in result + + +# --------------------------------------------------------------------------- +# Tests — session activation +# --------------------------------------------------------------------------- + + +class TestSessionActivation: + def test_scratchpad_hidden_until_used(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("What is the weather?", full_registry) + assert "query_data" not in result + assert "create_table" not in result + + def test_scratchpad_activates_on_keyword(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("Create a table for my expenses", full_registry) + assert "create_table" in result + assert "query_data" in result + + def 
test_scratchpad_stays_active_after_use(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + # First turn: hidden + result = loader.resolve("Hello", full_registry) + assert "query_data" not in result + # Simulate: user creates a table, tool gets executed + loader.record_tool_use("create_table") + # Second turn: scratchpad should be warm + result = loader.resolve("Now query my expenses", full_registry) + assert "query_data" in result + assert "create_table" in result + + def test_memory_hidden_until_used(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("What time is it?", full_registry) + assert "recall" not in result + assert "remember" not in result + + def test_memory_activates_on_keyword(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("What did I learn about FTS5 last week?", full_registry) + assert "recall" in result + + def test_memory_stays_warm_after_remember(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + loader.record_tool_use("remember") + result = loader.resolve("OK, what else?", full_registry) + assert "recall" in result + + +# --------------------------------------------------------------------------- +# Tests — scratchpad vs memory disambiguation (#688 collision) +# --------------------------------------------------------------------------- + + +class TestDisambiguation: + """Verify that query_data and recall are not both visible unless justified.""" + + def test_spending_query_activates_scratchpad_not_memory( + self, loader, full_registry + ): + """'What did I spend on groceries in March?' 
→ query_data, not recall.""" + loader.register_bundles(ALL_BUNDLES) + # Pre-activate scratchpad (user created a table earlier) + loader.record_tool_use("create_table") + result = loader.resolve( + "What did I spend on groceries in March?", full_registry + ) + assert "query_data" in result + # Memory should NOT be active (no memory keywords matched) + assert "recall" not in result + + def test_learning_query_activates_memory_not_scratchpad( + self, loader, full_registry + ): + """'What did I learn about FTS5 last week?' → recall, not query_data.""" + loader.register_bundles(ALL_BUNDLES) + result = loader.resolve("What did I learn about FTS5 last week?", full_registry) + assert "recall" in result + # Scratchpad should NOT be active (never used, no keywords) + assert "query_data" not in result + + def test_both_visible_when_both_justified(self, loader, full_registry): + """Both visible if scratchpad was used AND memory keywords present.""" + loader.register_bundles(ALL_BUNDLES) + loader.record_tool_use("create_table") + result = loader.resolve( + "Do I remember anything about the data in my table from last week?", + full_registry, + ) + # Both should be active: scratchpad (session-warm), memory (keyword) + assert "query_data" in result + assert "recall" in result + + +# --------------------------------------------------------------------------- +# Tests — unbundled tools +# --------------------------------------------------------------------------- + + +class TestUnbundledTools: + def test_unbundled_tools_always_visible(self, loader): + """Tools not assigned to any bundle should appear in every turn.""" + registry = _make_registry("read_file", "custom_tool", "another_tool") + loader.register_bundle(CORE_BUNDLE) + result = loader.resolve("Hello", registry) + # read_file is in core → visible + assert "read_file" in result + # custom_tool and another_tool are unbundled → visible + assert "custom_tool" in result + assert "another_tool" in result + + +# 
---------------------------------------------------------------------------
+# Tests — token savings
+# ---------------------------------------------------------------------------
+
+
+class TestTokenSavings:
+    def test_typical_session_reduces_tool_count(self, loader, full_registry):
+        """A typical greeting should only expose core + unbundled tools."""
+        loader.register_bundles(ALL_BUNDLES)
+        result = loader.resolve("Hi there, what's up?", full_registry)
+        # Of the bundled tools, only the core bundle's should be visible
+        assert len(result) == len(CORE_BUNDLE.tools)
+
+    def test_max_tools_with_all_activated(self, loader, full_registry):
+        """Even with all bundles active, we get the full set."""
+        loader.register_bundles(ALL_BUNDLES)
+        # Activate everything
+        for bundle in ALL_BUNDLES:
+            for tool_name in bundle.tools:
+                loader.record_tool_use(tool_name)
+        result = loader.resolve("Do everything", full_registry)
+        assert len(result) == len(full_registry)
+
+
+# ---------------------------------------------------------------------------
+# Tests — session reset
+# ---------------------------------------------------------------------------
+
+
+class TestSessionReset:
+    def test_reset_clears_activation(self, loader, full_registry):
+        loader.register_bundles(ALL_BUNDLES)
+        loader.record_tool_use("create_table")
+        # Verify activated
+        assert loader._state["scratchpad"].activated is True
+        # Reset
+        loader.reset_session()
+        assert loader._state["scratchpad"].activated is False
+        result = loader.resolve("Hello", full_registry)
+        assert "query_data" not in result
+
+
+# ---------------------------------------------------------------------------
+# Tests — warm window
+# ---------------------------------------------------------------------------
+
+
+class TestWarmWindow:
+    def test_recent_use_keeps_bundle_warm(self, loader, full_registry):
+        loader.register_bundles(ALL_BUNDLES)
+        # Manually inject a recent use into history
+        loader._tool_history.append(("fetch_page", time.time()))
+ result = loader.resolve("Tell me something", full_registry) + assert "fetch_page" in result + + def test_old_use_does_not_keep_bundle_warm(self, loader, full_registry): + loader.register_bundles(ALL_BUNDLES) + # Inject an old use (26 hours ago) + old_ts = time.time() - 26 * 3600 + loader._tool_history.append(("fetch_page", old_ts)) + result = loader.resolve("Tell me something", full_registry) + assert "fetch_page" not in result + + +# --------------------------------------------------------------------------- +# Tests — mid-conversation pivot +# --------------------------------------------------------------------------- + + +class TestMidConversationPivot: + def test_pivot_from_files_to_web(self, loader, full_registry): + """Session starts with file browsing, pivots to web research.""" + loader.register_bundles(ALL_BUNDLES) + + # Turn 1: file browsing + result1 = loader.resolve("Show me files in /projects", full_registry) + assert "browse_directory" in result1 + assert "fetch_page" not in result1 + loader.record_tool_use("browse_directory") + + # Turn 2: pivot to web + result2 = loader.resolve( + "Now search the web for React tutorials", full_registry + ) + assert "fetch_page" in result2 + # Filesystem should still be warm (was used recently) + assert "browse_directory" in result2 + + def test_pivot_from_chat_to_rag(self, loader, full_registry): + """General chat pivots to document analysis.""" + loader.register_bundles(ALL_BUNDLES) + + # Turn 1: general chat + result1 = loader.resolve("Hi, how are you?", full_registry) + assert "query_documents" not in result1 + + # Turn 2: mentions a document + result2 = loader.resolve("Summarize the quarterly report PDF", full_registry) + assert "query_documents" in result2 + assert "index_document" in result2