Skip to content
55 changes: 50 additions & 5 deletions src/gaia/agents/base/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from gaia.agents.base.console import AgentConsole, SilentConsole
from gaia.agents.base.errors import format_execution_trace
from gaia.agents.base.tool_loader import ToolLoader
from gaia.agents.base.tools import _TOOL_REGISTRY

# First-party imports
Expand Down Expand Up @@ -216,6 +217,11 @@ def __init__(
self.total_plan_steps = 0
self.plan_iterations = 0 # Track number of plan cycles

# Bundle-based tool loading (opt-in: subclasses set self.tool_loader
# before calling super().__init__ or in _register_tools).
if not hasattr(self, "tool_loader"):
self.tool_loader: Optional[ToolLoader] = None

# Initialize the console/output handler for display
# If output_handler is provided, use it; otherwise create based on silent_mode
if output_handler is not None:
Expand Down Expand Up @@ -309,10 +315,16 @@ def _get_mixin_prompts(self) -> list[str]:

return prompts

def _compose_system_prompt(self) -> str:
def _compose_system_prompt(
self, tool_registry: Optional[Dict[str, dict]] = None
) -> str:
"""
Compose final system prompt from mixin fragments + agent custom + tools + format.

Parameters:
tool_registry: If provided, use this filtered tool dict instead of the
full ``_TOOL_REGISTRY``. Passed through to ``_format_tools_for_prompt``.

Override this method for complete control over prompt composition order.

Returns:
Expand Down Expand Up @@ -340,7 +352,7 @@ def _compose_system_prompt(self) -> str:

# Add tool descriptions (if tools registered)
if hasattr(self, "_format_tools_for_prompt"):
tools_description = self._format_tools_for_prompt()
tools_description = self._format_tools_for_prompt(registry=tool_registry)
if tools_description:
parts.append(f"==== AVAILABLE TOOLS ====\n{tools_description}")

Expand Down Expand Up @@ -421,11 +433,22 @@ def _register_tools(self):
"""
raise NotImplementedError("Subclasses must implement _register_tools")

def _format_tools_for_prompt(self) -> str:
"""Format the registered tools into a string for the prompt."""
def _format_tools_for_prompt(
self, registry: Optional[Dict[str, dict]] = None
) -> str:
"""Format the registered tools into a string for the prompt.

Parameters
----------
registry:
If provided, use this dict of tools instead of the global
``_TOOL_REGISTRY``. ``ToolLoader.resolve()`` returns such a
filtered dict each turn.
"""
source = registry if registry is not None else _TOOL_REGISTRY
tool_descriptions = []

for name, tool_info in _TOOL_REGISTRY.items():
for name, tool_info in source.items():
params_str = ", ".join(
[
f"{param_name}{'' if param_info['required'] else '?'}: {param_info['type']}"
Expand Down Expand Up @@ -474,6 +497,20 @@ def rebuild_system_prompt(self) -> None:
# mixin prompts, tool descriptions, and response format are all included.
self._system_prompt_cache = self._compose_system_prompt()

def resolve_tools_for_turn(self, user_message: str) -> None:
"""Re-evaluate tool bundles for the current turn and rebuild the prompt.

If no ``tool_loader`` is configured this is a no-op, preserving the
existing behaviour (all registered tools always visible).

Called at the top of ``process_query`` before the first LLM call so
that keyword-activated bundles can match the current user message.
"""
if self.tool_loader is None:
return
filtered = self.tool_loader.resolve(user_message, _TOOL_REGISTRY)
self._system_prompt_cache = self._compose_system_prompt(tool_registry=filtered)

def list_tools(self, verbose: bool = True) -> None:
"""
Display all tools registered for this agent with their parameters and descriptions.
Expand Down Expand Up @@ -1351,6 +1388,9 @@ def _execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> Any:
try:
result = tool(**tool_args)
logger.debug(f"Tool execution result: {result}")
# Record usage so bundle-based loader can keep the owning bundle warm
if self.tool_loader is not None:
self.tool_loader.record_tool_use(tool_name)
return result
except subprocess.TimeoutExpired as e:
# Handle subprocess timeout specifically
Expand Down Expand Up @@ -1687,6 +1727,11 @@ def process_query(
# Store query for error context (used in _execute_tool for error formatting)
self._current_query = user_input

# Re-evaluate which tool bundles should be visible for this turn.
# Keyword-activated bundles match against user_input; session bundles
# that were used in prior turns stay warm automatically.
self.resolve_tools_for_turn(user_input)

logger.debug(f"Processing query: {user_input}")
conversation = []
# Build messages array for chat completions
Expand Down
257 changes: 257 additions & 0 deletions src/gaia/agents/base/tool_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""
ToolLoader — bundle-based tool visibility for agents.

Gates which tools appear in the LLM prompt each turn without changing the
global ``_TOOL_REGISTRY``. The registry remains the source of truth for
*all* registered tools; the loader picks the subset that goes into the
system prompt.

Bundles
-------
A ``ToolBundle`` groups related tools under a name with an activation
policy. Three policies exist:

* **always** — included in every prompt (e.g. ``core``).
* **session** — stays active for the rest of the session once any tool
in the bundle has been used (e.g. ``scratchpad`` after ``create_table``).
* **keyword** — activated when the current user message matches one of a
set of trigger patterns (e.g. ``browser`` on URL patterns).

The loader evaluates bundles in priority order each turn and returns the
set of tool names that should appear in the prompt.
"""

from __future__ import annotations

import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, FrozenSet, List, Optional, Set

logger = logging.getLogger(__name__)


class ActivationPolicy(Enum):
"""How a bundle decides whether to be active."""

ALWAYS = "always"
SESSION = "session" # Active once any tool in the bundle was used this session
KEYWORD = "keyword" # Active when user message matches trigger patterns


@dataclass(frozen=True)
class ToolBundle:
"""An immutable group of tools sharing an activation policy.

Parameters
----------
name:
Human-readable bundle identifier (e.g. ``"rag"``, ``"scratchpad"``).
tools:
Frozenset of tool names that belong to this bundle.
policy:
When the bundle should be included in the prompt.
keywords:
Regex patterns (case-insensitive) checked against the user message
when ``policy`` is ``KEYWORD``. Ignored for other policies.
"""

name: str
tools: FrozenSet[str]
policy: ActivationPolicy
keywords: FrozenSet[str] = frozenset()


@dataclass
class _BundleState:
"""Mutable per-session state for a single bundle."""

activated: bool = False # True once the bundle has been activated this session
last_used_ts: float = 0.0 # Timestamp of most recent tool use in this bundle


class ToolLoader:
"""Selects which registered tools appear in the LLM prompt each turn.

Usage::

loader = ToolLoader()
loader.register_bundle(ToolBundle(
name="scratchpad",
tools=frozenset({"create_table", "insert_data", "query_data",
"list_tables", "drop_table"}),
policy=ActivationPolicy.SESSION,
))

# Each turn, ask which tools should be visible:
active_tools = loader.resolve(user_message, registry)

The loader does **not** modify ``_TOOL_REGISTRY``. It returns a
filtered view that the agent uses when building the prompt.
"""

# Warm-window: if a bundle was used in the last 24 h, keep it active
WARM_WINDOW_SECS: float = 24 * 3600

def __init__(self) -> None:
self._bundles: Dict[str, ToolBundle] = {}
self._state: Dict[str, _BundleState] = {}
# History of (tool_name, timestamp) for logging / warm-window checks
self._tool_history: List[tuple[str, float]] = []
# Reverse index: tool_name → bundle_name for fast lookup
self._tool_to_bundle: Dict[str, str] = {}

# ── registration ─────────────────────────────────────────────────────

def register_bundle(self, bundle: ToolBundle) -> None:
"""Register a bundle (idempotent — overwrites if name already exists)."""
self._bundles[bundle.name] = bundle
self._state.setdefault(bundle.name, _BundleState())
for tool_name in bundle.tools:
self._tool_to_bundle[tool_name] = bundle.name

def register_bundles(self, bundles: list[ToolBundle]) -> None:
for b in bundles:
self.register_bundle(b)

# ── per-turn resolution ──────────────────────────────────────────────

def resolve(
self,
user_message: str,
registry: Dict[str, dict],
) -> Dict[str, dict]:
"""Return the subset of *registry* that should appear in the prompt.

Parameters
----------
user_message:
The current user turn (used for keyword matching).
registry:
The full ``_TOOL_REGISTRY`` dict mapping tool names → metadata.

Returns
-------
dict
Filtered copy of *registry* containing only active tools.
"""
active_names: Set[str] = set()
activated_bundles: list[str] = []

for name, bundle in self._bundles.items():
state = self._state[name]

if bundle.policy == ActivationPolicy.ALWAYS:
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

if bundle.policy == ActivationPolicy.SESSION:
if state.activated:
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Warm-window: check if any tool in the bundle was used recently
if self._was_used_recently(bundle):
state.activated = True
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Also activate if keywords match (session bundles can have keywords)
if bundle.keywords and self._keywords_match(
bundle.keywords, user_message
):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

if bundle.policy == ActivationPolicy.KEYWORD:
if state.activated:
# Already activated this session — keep warm
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
if bundle.keywords and self._keywords_match(
bundle.keywords, user_message
):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue
# Warm-window fallback
if self._was_used_recently(bundle):
active_names.update(bundle.tools)
activated_bundles.append(name)
continue

# Include any registered tools that are NOT in any bundle (backwards compat).
unbundled = {t for t in registry if t not in self._tool_to_bundle}
active_names.update(unbundled)

logger.debug(
"ToolLoader resolved %d/%d tools (bundles: %s)",
len(active_names & set(registry)),
len(registry),
", ".join(activated_bundles) or "none",
)

return {k: v for k, v in registry.items() if k in active_names}

# ── tool execution hook ──────────────────────────────────────────────

def record_tool_use(self, tool_name: str) -> None:
"""Record that a tool was executed (called from ``_execute_tool``).

This flips the owning bundle's ``activated`` flag so session-policy
bundles stay warm for the rest of the conversation.
"""
now = time.time()
self._tool_history.append((tool_name, now))
bundle_name = self._tool_to_bundle.get(tool_name)
if bundle_name and bundle_name in self._state:
self._state[bundle_name].activated = True
self._state[bundle_name].last_used_ts = now

# ── query helpers ────────────────────────────────────────────────────

def get_active_bundle_names(self) -> list[str]:
"""Return names of currently activated bundles."""
return [n for n, s in self._state.items() if s.activated]

def get_bundle_for_tool(self, tool_name: str) -> Optional[str]:
"""Return the bundle name that owns *tool_name*, or ``None``."""
return self._tool_to_bundle.get(tool_name)

def reset_session(self) -> None:
"""Clear per-session state (call between conversations)."""
for state in self._state.values():
state.activated = False
state.last_used_ts = 0.0
self._tool_history.clear()

# ── internals ────────────────────────────────────────────────────────

def _keywords_match(self, keywords: FrozenSet[str], message: str) -> bool:
"""Return True if any keyword pattern matches *message*."""
for pattern in keywords:
try:
if re.search(pattern, message, re.IGNORECASE):
return True
except re.error:
# Treat bad regex as a plain substring match
if pattern.lower() in message.lower():
return True
return False

def _was_used_recently(self, bundle: ToolBundle) -> bool:
"""Check if any tool in *bundle* was used within the warm window."""
cutoff = time.time() - self.WARM_WINDOW_SECS
for tool_name, ts in reversed(self._tool_history):
if ts < cutoff:
break
if tool_name in bundle.tools:
return True
return False
Loading
Loading