Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions src/gaia/agents/base/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from gaia.agents.base.console import AgentConsole, SilentConsole
from gaia.agents.base.errors import format_execution_trace
from gaia.agents.base.tool_loader import ToolLoader
from gaia.agents.base.tools import _TOOL_REGISTRY

# First-party imports
Expand Down Expand Up @@ -253,6 +254,11 @@ def __init__(

# Register tools for this agent (may call rebuild_system_prompt via MCP loading;
# _response_format_template must be set above before this call).
# Bundle-based tool loading (opt-in: subclasses set self.tool_loader
# before calling super().__init__ or in _register_tools).
if not hasattr(self, "tool_loader"):
self.tool_loader: Optional[ToolLoader] = None

self._register_tools()

# Note: system_prompt is now a lazy @property that composes on first access.
Expand Down Expand Up @@ -327,7 +333,9 @@ def _get_mixin_prompts(self) -> list[str]:

return prompts

def _compose_system_prompt(self) -> str:
def _compose_system_prompt(
self, tool_registry: Optional[Dict[str, dict]] = None
) -> str:
"""
Compose final system prompt from mixin fragments + agent custom + tools + format.

Expand Down Expand Up @@ -358,7 +366,7 @@ def _compose_system_prompt(self) -> str:

# Add tool descriptions (if tools registered)
if hasattr(self, "_format_tools_for_prompt"):
tools_description = self._format_tools_for_prompt()
tools_description = self._format_tools_for_prompt(registry=tool_registry)
if tools_description:
parts.append(f"==== AVAILABLE TOOLS ====\n{tools_description}")

Expand Down Expand Up @@ -439,11 +447,23 @@ def _register_tools(self):
"""
raise NotImplementedError("Subclasses must implement _register_tools")

def _format_tools_for_prompt(self) -> str:
"""Format the registered tools into a string for the prompt."""
def _format_tools_for_prompt(
self, registry: Optional[Dict[str, dict]] = None
) -> str:
"""Format the registered tools into a string for the prompt.

Parameters
----------
registry:
If provided, use this dict of tools instead of the global
``_TOOL_REGISTRY``. ``ToolLoader.resolve()`` returns such a
filtered dict each turn.
"""
tool_descriptions = []

for name, tool_info in _TOOL_REGISTRY.items():
source = registry if registry is not None else _TOOL_REGISTRY

for name, tool_info in source.items():
params_str = ", ".join(
[
f"{param_name}{'' if param_info['required'] else '?'}: {param_info['type']}"
Expand Down Expand Up @@ -492,6 +512,20 @@ def rebuild_system_prompt(self) -> None:
# mixin prompts, tool descriptions, and response format are all included.
self._system_prompt_cache = self._compose_system_prompt()

def resolve_tools_for_turn(self, user_message: str) -> None:
"""Re-evaluate tool bundles for the current turn and rebuild the prompt.

If no ``tool_loader`` is configured this is a no-op, preserving the
existing behaviour (all registered tools always visible).

Called at the top of ``process_query`` before the first LLM call so
that keyword-activated bundles can match the current user message.
"""
if self.tool_loader is None:
return
filtered = self.tool_loader.resolve(user_message, _TOOL_REGISTRY)
self._system_prompt_cache = self._compose_system_prompt(tool_registry=filtered)

def list_tools(self, verbose: bool = True) -> None:
"""
Display all tools registered for this agent with their parameters and descriptions.
Expand Down Expand Up @@ -1393,6 +1427,12 @@ def _execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Any:
try:
result = tool(**tool_args)
logger.debug(f"Tool execution result: {result}")
# Record usage so bundle-based loader can keep the owning bundle warm
if self.tool_loader is not None:
try:
self.tool_loader.record_tool_use(tool_name)
except Exception:
logger.exception("Failed to record tool usage")
return result
except subprocess.TimeoutExpired as e:
# Handle subprocess timeout specifically
Expand Down Expand Up @@ -1846,6 +1886,10 @@ def _process_query_impl(

# Store query for error context (used in _execute_tool for error formatting)
self._current_query = user_input
# Re-evaluate which tool bundles should be visible for this turn.
# Keyword-activated bundles match against user_input; session bundles
# that were used in prior turns stay warm automatically.
self.resolve_tools_for_turn(user_input)

logger.debug(f"Processing query: {user_input}")
conversation = []
Expand Down
274 changes: 274 additions & 0 deletions src/gaia/agents/base/tool_loader.py
Comment thread
theonlychant marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""
ToolLoader — bundle-based tool visibility for agents.

Gates which tools appear in the LLM prompt each turn without changing the
global ``_TOOL_REGISTRY``. The registry remains the source of truth for
*all* registered tools; the loader picks the subset that goes into the
system prompt.

Bundles
-------
A ``ToolBundle`` groups related tools under a name with an activation
policy. Three policies exist:

* **always** — included in every prompt (e.g. ``core``).
* **session** — stays active for the rest of the session once any tool
in the bundle has been used (e.g. ``scratchpad`` after ``create_table``).
* **keyword** — activated when the current user message matches one of a
set of trigger patterns (e.g. ``browser`` on URL patterns).

The loader evaluates bundles in priority order each turn and returns the
set of tool names that should appear in the prompt.
"""

from __future__ import annotations

import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, FrozenSet, List, Optional, Set

logger = logging.getLogger(__name__)


class ActivationPolicy(Enum):
"""How a bundle decides whether to be active."""

ALWAYS = "always"
SESSION = "session" # Active once any tool in the bundle was used this session
KEYWORD = "keyword" # Active when user message matches trigger patterns


@dataclass(frozen=True)
class ToolBundle:
"""An immutable group of tools sharing an activation policy.

Parameters
----------
name:
Human-readable bundle identifier (e.g. ``"rag"``, ``"scratchpad"``).
tools:
Frozenset of tool names that belong to this bundle.
policy:
When the bundle should be included in the prompt.
keywords:
Regex patterns (case-insensitive) checked against the user message
when ``policy`` is ``KEYWORD``. Ignored for other policies.
"""

name: str
tools: FrozenSet[str]
policy: ActivationPolicy
keywords: FrozenSet[str] = frozenset()


@dataclass
class _BundleState:
"""Mutable per-session state for a single bundle."""

activated: bool = False # True once the bundle has been activated this session
last_used_ts: float = 0.0 # Timestamp of most recent tool use in this bundle


class ToolLoader:
"""Selects which registered tools appear in the LLM prompt each turn.

Usage::

loader = ToolLoader()
loader.register_bundle(ToolBundle(
name="scratchpad",
tools=frozenset({"create_table", "insert_data", "query_data",
"list_tables", "drop_table"}),
policy=ActivationPolicy.SESSION,
))

# Each turn, ask which tools should be visible:
active_tools = loader.resolve(user_message, registry)

The loader does **not** modify ``_TOOL_REGISTRY``. It returns a
filtered view that the agent uses when building the prompt.
"""

# Warm-window: if a bundle was used in the last 24 h, keep it active
WARM_WINDOW_SECS: float = 24 * 3600

def __init__(self) -> None:
self._bundles: Dict[str, ToolBundle] = {}
self._state: Dict[str, _BundleState] = {}
# History of (tool_name, timestamp) for logging / warm-window checks
self._tool_history: List[tuple[str, float]] = []
# Reverse index: tool_name → bundle_name for fast lookup
self._tool_to_bundle: Dict[str, str] = {}

# ── registration ─────────────────────────────────────────────────────

def register_bundle(self, bundle: ToolBundle) -> None:
"""Register a bundle (idempotent — overwrites if name already exists)."""
self._bundles[bundle.name] = bundle
self._state.setdefault(bundle.name, _BundleState())
for tool_name in bundle.tools:
self._tool_to_bundle[tool_name] = bundle.name

def register_bundles(self, bundles: list[ToolBundle]) -> None:
for b in bundles:
self.register_bundle(b)

# ── per-turn resolution ──────────────────────────────────────────────

def resolve(
self,
user_message: str,
registry: Dict[str, dict],
) -> Dict[str, dict]:
"""Return the subset of *registry* that should appear in the prompt.

Parameters
----------
user_message:
The current user turn (used for keyword matching).
registry:
The full ``_TOOL_REGISTRY`` dict mapping tool names → metadata.

Returns
-------
dict
Filtered copy of *registry* containing only active tools.
"""
active_names: Set[str] = set()
activated_bundles: list[str] = []

for name, bundle in self._bundles.items():
state = self._state[name]

if bundle.policy == ActivationPolicy.ALWAYS:
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

if bundle.policy == ActivationPolicy.SESSION:
if state.activated:
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Warm-window: check if any tool in the bundle was used recently
if self._was_used_recently(bundle):
state.activated = True
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Also activate if keywords match (session bundles can have keywords)
if bundle.keywords and self._keywords_match(
bundle.keywords, user_message
):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

if bundle.policy == ActivationPolicy.KEYWORD:
if state.activated:
# Already activated this session — keep warm
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
if bundle.keywords and self._keywords_match(
bundle.keywords, user_message
):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Warm-window fallback
if self._was_used_recently(bundle):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

# Include any registered tools that are NOT in any bundle (backwards compat).
unbundled = {t for t in registry if t not in self._tool_to_bundle}
active_names.update(unbundled)

logger.debug(
"ToolLoader resolved %d/%d tools (bundles: %s)",
len(active_names & set(registry)),
len(registry),
", ".join(activated_bundles) or "none",
)

return {k: v for k, v in registry.items() if k in active_names}

# ── tool execution hook ──────────────────────────────────────────────

def record_tool_use(self, tool_name: str) -> None:
"""Record that a tool was executed (called from ``_execute_tool``).

This flips the owning bundle's ``activated`` flag so session-policy
bundles stay warm for the rest of the conversation.
"""
now = time.time()
self._tool_history.append((tool_name, now))
bundle_name = self._tool_to_bundle.get(tool_name)
if bundle_name and bundle_name in self._state:
self._state[bundle_name].activated = True
self._state[bundle_name].last_used_ts = now

# ── query helpers ────────────────────────────────────────────────────

def get_active_bundle_names(self) -> list[str]:
"""Return names of currently activated bundles."""
return [n for n, s in self._state.items() if s.activated]

def get_bundle_for_tool(self, tool_name: str) -> Optional[str]:
"""Return the bundle name that owns *tool_name*, or ``None``."""
return self._tool_to_bundle.get(tool_name)

def reset_session(self) -> None:
"""Clear per-session state (call between conversations)."""
for state in self._state.values():
state.activated = False
state.last_used_ts = 0.0
self._tool_history.clear()

def force_activate(self, bundle_name: str) -> None:
"""Force-activate a bundle by name.

This is the safe public API to mark a bundle as active for the
session. Callers should use this instead of touching ``_state``
directly to avoid encapsulation breaches.
"""
now = time.time()
if bundle_name in self._state:
self._state[bundle_name].activated = True
self._state[bundle_name].last_used_ts = now
logger.debug("ToolLoader: force-activated bundle '%s'", bundle_name)
else:
logger.warning(
"ToolLoader.force_activate: unknown bundle '%s'", bundle_name
)

# ── internals ────────────────────────────────────────────────────────

def _keywords_match(self, keywords: FrozenSet[str], message: str) -> bool:
"""Return True if any keyword pattern matches *message*."""
for pattern in keywords:
try:
if re.search(pattern, message, re.IGNORECASE):
return True
except re.error:
# Treat bad regex as a plain substring match
if pattern.lower() in message.lower():
return True
return False

def _was_used_recently(self, bundle: ToolBundle) -> bool:
"""Check if any tool in *bundle* was used within the warm window."""
cutoff = time.time() - self.WARM_WINDOW_SECS
for tool_name, ts in reversed(self._tool_history):
if ts < cutoff:
break
if tool_name in bundle.tools:
return True
return False
Loading
Loading