Skip to content

Commit 32db67c

Browse files
theonlychant and Copilot authored
feat(email): synthetic .mbox dataset for email triage tests (#928)
Closes #848

## Summary
Adds a synthetic .mbox dataset for testing the email triage agent. The fixtures provide realistic email threads for unit and integration testing without requiring live mailbox access.

## Why GAIA needs it
The email triage agent currently has no test data to run against, making it impossible to validate triage logic in CI.

## Test plan
- `tests/unit/test_agents_split.py` - all tests passing

---------

Co-authored-by: Copilot <copilot@github.com>
1 parent cbcc95d commit 32db67c

8 files changed

Lines changed: 1677 additions & 47 deletions

File tree

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
"""
4+
ToolLoader — bundle-based tool visibility for agents.
5+
6+
Gates which tools appear in the LLM prompt each turn without changing the
7+
global ``_TOOL_REGISTRY``. The registry remains the source of truth for
8+
*all* registered tools; the loader picks the subset that goes into the
9+
system prompt.
10+
11+
Bundles
12+
-------
13+
A ``ToolBundle`` groups related tools under a name with an activation
14+
policy. Three policies exist:
15+
16+
* **always** — included in every prompt (e.g. ``core``).
17+
* **session** — stays active for the rest of the session once any tool
18+
in the bundle has been used (e.g. ``scratchpad`` after ``create_table``).
19+
* **keyword** — activated when the current user message matches one of a
20+
set of trigger patterns (e.g. ``browser`` on URL patterns).
21+
22+
The loader evaluates bundles in priority order each turn and returns the
23+
set of tool names that should appear in the prompt.
24+
"""
25+
26+
from __future__ import annotations
27+
28+
import logging
29+
import re
30+
import time
31+
from dataclasses import dataclass
32+
from enum import Enum
33+
from typing import Dict, FrozenSet, List, Optional, Set
34+
35+
# Module-level logger named after this module; ToolLoader.resolve() emits
# its per-turn debug summary through it.
logger = logging.getLogger(__name__)
36+
37+
38+
class ActivationPolicy(Enum):
    """How a bundle decides whether to be active.

    The member values are the lowercase policy names, so a policy can be
    looked up from its string form (``ActivationPolicy("session")``).
    """

    # Included in every prompt.
    ALWAYS = "always"
    # Active once any tool in the bundle was used this session.
    SESSION = "session"
    # Active when the user message matches one of the trigger patterns.
    KEYWORD = "keyword"
44+
45+
46+
@dataclass(frozen=True)
class ToolBundle:
    """An immutable group of tools sharing an activation policy.

    Parameters
    ----------
    name:
        Human-readable bundle identifier (e.g. ``"rag"``, ``"scratchpad"``).
    tools:
        Frozenset of tool names that belong to this bundle.  Any other
        iterable of names (or a single name given as a bare string) is
        converted to a frozenset on construction.
    policy:
        When the bundle should be included in the prompt.
    keywords:
        Regex patterns (case-insensitive) checked against the user message
        when ``policy`` is ``KEYWORD``.  Ignored for ``ALWAYS``; normalised
        the same way as ``tools``.
    """

    name: str
    tools: FrozenSet[str]
    policy: "ActivationPolicy"
    keywords: FrozenSet[str] = frozenset()

    def __post_init__(self) -> None:
        """Normalise ``tools`` and ``keywords`` to frozensets.

        A frozen dataclass derives ``__hash__`` from its fields, so a plain
        ``set`` or ``list`` passed by a caller would make the bundle
        unhashable.  A bare string is treated as one item rather than being
        split into characters.
        """
        for attr in ("tools", "keywords"):
            value = getattr(self, attr)
            if isinstance(value, frozenset):
                continue
            items = (value,) if isinstance(value, str) else value
            # The dataclass is frozen, so bypass its __setattr__ guard.
            object.__setattr__(self, attr, frozenset(items))
67+
68+
69+
@dataclass
class _BundleState:
    """Mutable per-session state for a single bundle."""

    # True once the bundle has been activated this session.
    activated: bool = False
    # Timestamp (``time.time()``) of the most recent tool use in this bundle;
    # 0.0 means no tool of the bundle has been used yet.
    last_used_ts: float = 0.0
75+
76+
77+
class ToolLoader:
    """Selects which registered tools appear in the LLM prompt each turn.

    Usage::

        loader = ToolLoader()
        loader.register_bundle(ToolBundle(
            name="scratchpad",
            tools=frozenset({"create_table", "insert_data", "query_data",
                             "list_tables", "drop_table"}),
            policy=ActivationPolicy.SESSION,
        ))

        # Each turn, ask which tools should be visible:
        active_tools = loader.resolve(user_message, registry)

    The loader does **not** modify ``_TOOL_REGISTRY``.  It returns a
    filtered view that the agent uses when building the prompt.
    """

    # Warm-window: if a bundle was used in the last 24 h, keep it active
    WARM_WINDOW_SECS: float = 24 * 3600

    def __init__(self) -> None:
        """Create a loader with no bundles and empty session state."""
        self._bundles: Dict[str, ToolBundle] = {}
        self._state: Dict[str, _BundleState] = {}
        # History of (tool_name, timestamp) for logging / warm-window checks
        self._tool_history: List[tuple[str, float]] = []
        # Reverse index: tool_name → bundle_name for fast lookup
        self._tool_to_bundle: Dict[str, str] = {}

    # ── registration ─────────────────────────────────────────────────────

    def register_bundle(self, bundle: "ToolBundle") -> None:
        """Register a bundle (idempotent — overwrites if name already exists).

        Re-registering a name keeps that bundle's existing session state but
        replaces its tool list.  Tools that the previous definition owned
        are removed from the reverse index first, so a tool dropped from the
        bundle goes back to being unbundled instead of staying mapped to a
        bundle that no longer lists it (which would hide it from every
        prompt).
        """
        previous = self._bundles.get(bundle.name)
        if previous is not None:
            for stale in previous.tools:
                # Only drop entries this bundle still owns; another bundle
                # registered later may have claimed the same tool name.
                if self._tool_to_bundle.get(stale) == bundle.name:
                    del self._tool_to_bundle[stale]

        self._bundles[bundle.name] = bundle
        if bundle.name not in self._state:
            self._state[bundle.name] = _BundleState()
        for tool_name in bundle.tools:
            self._tool_to_bundle[tool_name] = bundle.name

    def register_bundles(self, bundles: "list[ToolBundle]") -> None:
        """Register every bundle in *bundles*, in order."""
        for bundle in bundles:
            self.register_bundle(bundle)

    # ── per-turn resolution ──────────────────────────────────────────────

    def resolve(
        self,
        user_message: str,
        registry: Dict[str, dict],
    ) -> Dict[str, dict]:
        """Return the subset of *registry* that should appear in the prompt.

        Parameters
        ----------
        user_message:
            The current user turn (used for keyword matching).
        registry:
            The full ``_TOOL_REGISTRY`` dict mapping tool names → metadata.

        Returns
        -------
        dict
            Filtered copy of *registry* containing only active tools.
            Tools that belong to no bundle are always included; bundle
            tools missing from *registry* are silently left out.
        """
        active_names: Set[str] = set()
        activated_bundles: List[str] = []

        for name, bundle in self._bundles.items():
            if self._bundle_is_active(bundle, self._state[name], user_message):
                active_names.update(bundle.tools)
                activated_bundles.append(name)

        # Include any registered tools that are NOT in any bundle (backwards compat).
        active_names.update(t for t in registry if t not in self._tool_to_bundle)

        visible = {k: v for k, v in registry.items() if k in active_names}

        logger.debug(
            "ToolLoader resolved %d/%d tools (bundles: %s)",
            len(visible),
            len(registry),
            ", ".join(activated_bundles) or "none",
        )

        return visible

    def _bundle_is_active(
        self,
        bundle: "ToolBundle",
        state: "_BundleState",
        user_message: str,
    ) -> bool:
        """Decide whether *bundle* is visible this turn.

        Side effect: a SESSION bundle found through the warm window has its
        ``activated`` flag set.  KEYWORD bundles are never latched here; they
        only latch through ``record_tool_use``.
        """
        policy = bundle.policy

        if policy == ActivationPolicy.ALWAYS:
            return True

        if policy == ActivationPolicy.SESSION:
            if state.activated:
                return True
            # Warm-window: check if any tool in the bundle was used recently
            if self._was_used_recently(bundle):
                state.activated = True
                return True
            # Session bundles can have keywords; a match activates the
            # bundle for this turn only.
            return bool(bundle.keywords) and self._keywords_match(
                bundle.keywords, user_message
            )

        if policy == ActivationPolicy.KEYWORD:
            # Already activated this session — keep warm
            if state.activated:
                return True
            if bundle.keywords and self._keywords_match(
                bundle.keywords, user_message
            ):
                return True
            # Warm-window fallback
            return self._was_used_recently(bundle)

        # Unknown policy: leave the bundle out of the prompt.
        return False

    # ── tool execution hook ──────────────────────────────────────────────

    def record_tool_use(self, tool_name: str) -> None:
        """Record that a tool was executed (called from ``_execute_tool``).

        This flips the owning bundle's ``activated`` flag so session-policy
        bundles stay warm for the rest of the conversation.  Tools that
        belong to no bundle are still appended to the history.
        """
        now = time.time()
        self._tool_history.append((tool_name, now))
        bundle_name = self._tool_to_bundle.get(tool_name)
        if bundle_name and bundle_name in self._state:
            state = self._state[bundle_name]
            state.activated = True
            state.last_used_ts = now

    # ── query helpers ────────────────────────────────────────────────────

    def get_active_bundle_names(self) -> List[str]:
        """Return names of currently activated bundles."""
        return [name for name, state in self._state.items() if state.activated]

    def get_bundle_for_tool(self, tool_name: str) -> Optional[str]:
        """Return the bundle name that owns *tool_name*, or ``None``."""
        return self._tool_to_bundle.get(tool_name)

    def reset_session(self) -> None:
        """Clear per-session state (call between conversations).

        Registered bundles and the tool → bundle index are kept; only the
        activation flags, timestamps and tool-use history are reset.
        """
        for state in self._state.values():
            state.activated = False
            state.last_used_ts = 0.0
        self._tool_history.clear()

    # ── internals ────────────────────────────────────────────────────────

    def _keywords_match(self, keywords: FrozenSet[str], message: str) -> bool:
        """Return True if any keyword pattern matches *message*.

        Patterns are searched case-insensitively.  A pattern that is not a
        valid regex is treated as a plain, case-insensitive substring.  A
        ``None`` message is treated as empty text instead of raising.
        """
        text = message or ""
        lowered = text.lower()
        for pattern in keywords:
            try:
                if re.search(pattern, text, re.IGNORECASE):
                    return True
            except re.error:
                # Treat bad regex as a plain substring match
                if pattern.lower() in lowered:
                    return True
        return False

    def _was_used_recently(self, bundle: "ToolBundle") -> bool:
        """Check if any tool in *bundle* was used within the warm window."""
        cutoff = time.time() - self.WARM_WINDOW_SECS
        # History is appended in call order, so walk newest-first and stop
        # at the first entry that falls outside the window.
        for tool_name, ts in reversed(self._tool_history):
            if ts < cutoff:
                break
            if tool_name in bundle.tools:
                return True
        return False

src/gaia/agents/chat/agent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from gaia.agents.base.agent import Agent
2020
from gaia.agents.base.console import AgentConsole
21+
from gaia.agents.base.tool_loader import ToolLoader
2122
from gaia.agents.chat.session import SessionManager
2223
from gaia.agents.chat.tools import FileToolsMixin, RAGToolsMixin, ShellToolsMixin
2324
from gaia.agents.code.tools.file_io import FileIOToolsMixin
@@ -279,6 +280,10 @@ def __init__(self, config: Optional[ChatAgentConfig] = None):
279280
self.conversation_history: List[Dict[str, str]] = (
280281
[]
281282
) # Track conversation for persistence
283+
# Tool loader controls which tool bundles are active per-session.
284+
# Instantiate here so the agent can reset bundle activation when a
285+
# new conversation/session is created.
286+
self.tool_loader = ToolLoader()
282287

283288
# Store base URL for use in _register_tools() (VLM, etc.)
284289
self._base_url = effective_base_url
@@ -352,6 +357,14 @@ def __init__(self, config: Optional[ChatAgentConfig] = None):
352357
self.current_session = self.session_manager.create_session(
353358
config.ui_session_id
354359
)
360+
# New conversation started for this UI session; clear any
361+
# session-scoped tool activations so bundles don't persist
362+
# across distinct conversations.
363+
try:
364+
self.tool_loader.reset_session()
365+
except Exception:
366+
# Never fail agent init due to tool loader reset.
367+
pass
355368

356369
# Start watching directories
357370
if self.watch_directories:

src/gaia/agents/chat/app.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,12 @@ def interactive_mode(agent: ChatAgent):
317317
agent.save_current_session()
318318
# Create new session
319319
agent.current_session = agent.session_manager.create_session()
320+
# Reset per-session tool activation (bundle state)
321+
try:
322+
if hasattr(agent, "tool_loader"):
323+
agent.tool_loader.reset_session()
324+
except Exception:
325+
pass
320326
# Clear chat history (if agent tracks it)
321327
if hasattr(agent, "chat_history"):
322328
agent.chat_history = []
@@ -997,6 +1003,12 @@ def main():
9971003
# Create initial session if not loading one
9981004
if not agent.current_session:
9991005
agent.current_session = agent.session_manager.create_session()
1006+
# Reset tool loader session state on new session
1007+
try:
1008+
if hasattr(agent, "tool_loader"):
1009+
agent.tool_loader.reset_session()
1010+
except Exception:
1011+
pass
10001012
logger.debug(f"Created new session: {agent.current_session.session_id}")
10011013

10021014
# Index document if --index flag provided

src/gaia/cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,12 @@ async def async_main(action, **kwargs):
590590
# Create initial session if not loading one
591591
if not agent.current_session:
592592
agent.current_session = agent.session_manager.create_session()
593+
# Reset tool loader session state on new session
594+
try:
595+
if hasattr(agent, "tool_loader"):
596+
agent.tool_loader.reset_session()
597+
except Exception:
598+
pass
593599
log.debug(f"Created new session: {agent.current_session.session_id}")
594600

595601
# List tools if requested
@@ -785,6 +791,12 @@ def _launch_interactive_cli(log=None):
785791

786792
if not agent.current_session:
787793
agent.current_session = agent.session_manager.create_session()
794+
# Reset tool loader session state on new session
795+
try:
796+
if hasattr(agent, "tool_loader"):
797+
agent.tool_loader.reset_session()
798+
except Exception:
799+
pass
788800

789801
interactive_mode(agent)
790802
except KeyboardInterrupt:

0 commit comments

Comments
 (0)