Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 115 additions & 8 deletions src/gaia/agents/base/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,49 @@
}


# Issue #1023: smaller LLMs (Gemma-4-E4B-class) sometimes emit Windows paths
# in tool-call arguments with single backslashes (e.g. ``C:\Users\Klaus``)
# where strict JSON requires the backslashes doubled. ``json.loads`` rejects
# any backslash NOT followed by one of: " \ / b f n r t u -- so ``\U`` (and
# the other dozen-or-so cases) errors with "Invalid \escape: ...".
# ``_repair_invalid_json_escapes`` doubles only the offending backslashes,
# leaving valid escapes (\n, \t, \uXXXX, etc.) untouched. Idempotent.
#
# The pattern matches a backslash followed by one character (DOTALL so the
# next char can also be a newline). The replacement callback decides whether
# to keep the pair (valid escape) or double the leading backslash (invalid).
# Consuming the pair in a single match -- rather than just matching the lone
# backslash -- keeps the function idempotent: running it twice on ``\\X``
# does not produce ``\\\\X``.
_JSON_BACKSLASH_PAIR_RE = re.compile(r"\\(.)", re.DOTALL)
_VALID_JSON_ESCAPE_NEXT_CHARS = frozenset('"\\/bfnrtu')


def _repair_invalid_json_escapes(s: str) -> str:
"""Double any backslash that strict JSON would reject as an invalid escape.

Used as a one-shot recovery pass when ``json.loads`` raises on tool-call
arguments emitted by under-escaping LLMs. Already-valid JSON is returned
unchanged (idempotent).
"""

def _fix(match: "re.Match[str]") -> str:
nxt = match.group(1)
if nxt in _VALID_JSON_ESCAPE_NEXT_CHARS:
return match.group(0)
return "\\\\" + nxt

return _JSON_BACKSLASH_PAIR_RE.sub(_fix, s)


# Capability tools whose post-call outcome (success / error) gates the
# verbose-failure override in ``process_query``. Currently a single-element
# tuple, but kept as a tuple to mirror the prefix-match style used elsewhere
# (``startswith``) and so future capability tools (e.g. ``generate_video``)
# slot in without restructuring the guard logic.
_SD_CAPABILITY_TOOLS = ("generate_image",)


class Agent(abc.ABC):
"""
Base Agent class that provides core functionality for domain-specific agents.
Expand Down Expand Up @@ -1041,13 +1084,36 @@ def _parse_llm_response(self, response: str) -> Dict[str, Any]:
elif isinstance(arguments_raw, dict):
tool_args = arguments_raw
elif isinstance(arguments_raw, (str, bytes, bytearray)):
args_str = (
arguments_raw.decode("utf-8")
if isinstance(arguments_raw, (bytes, bytearray))
else arguments_raw
)
try:
tool_args = json.loads(arguments_raw)
tool_args = json.loads(args_str)
except json.JSONDecodeError as exc:
raise ValueError(
f"Malformed tool_call arguments for '{name}': {exc}. "
f"Raw arguments: {str(arguments_raw)[:200]}"
) from exc
# Issue #1023: Windows paths emitted with single
# backslashes (``C:\Users\Klaus``) -> ``\U`` is invalid
# JSON. Repair invalid escapes and retry once before
# surfacing the error to the recovery layer.
repaired = _repair_invalid_json_escapes(args_str)
if repaired == args_str:
raise ValueError(
f"Malformed tool_call arguments for '{name}': {exc}. "
f"Raw arguments: {args_str[:200]}"
) from exc
try:
tool_args = json.loads(repaired)
except json.JSONDecodeError as exc2:
raise ValueError(
f"Malformed tool_call arguments for '{name}': {exc2}. "
f"Raw arguments: {args_str[:200]}"
) from exc2
logger.debug(
"[PARSE] repaired invalid backslash escape(s) in "
"tool_call args for '%s'",
name,
)
else:
# Unexpected shape (list / int / None-ish) — treat as malformed
# so the recovery layer in process_query nudges the model to
Expand Down Expand Up @@ -1903,6 +1969,12 @@ def _process_query_impl(
tool_call_log = (
[]
) # Full unbounded log of all tool calls this turn (for workflow guards)
# Issue #1023: track the latest outcome of any capability tool
# (currently ``generate_image``) so the verbose-failure override
# downstream fires only when the tool actually errored. ``None``
# = not called yet, ``True`` = last call succeeded, ``False`` =
# last call returned an error.
capability_tool_last_succeeded: Optional[bool] = None
query_result_cache: dict[str, int] = (
{}
) # result_hash → call count (result-based dedup)
Expand Down Expand Up @@ -2015,6 +2087,20 @@ def _process_query_impl(
# Stop progress indicator
self.console.stop_progress()

# Issue #1023: record success/failure of capability tools
# so the verbose-failure override downstream can fire
# only when the tool actually errored. ``.lower()``
# mirrors the defensive check at
# ``has_tried_capability_tool`` so a model that emits
# ``Generate_Image`` doesn't slip past the tracker.
if any(
tool_name.lower().startswith(_s) for _s in _SD_CAPABILITY_TOOLS
):
capability_tool_last_succeeded = not (
isinstance(tool_result, dict)
and tool_result.get("status") in ("error", "denied")
)

# Handle domain-specific post-processing
self._post_process_tool_result(tool_name, tool_args, tool_result)

Expand Down Expand Up @@ -2904,6 +2990,17 @@ def _process_query_impl(
# Stop progress indicator
self.console.stop_progress()

# Issue #1023: record success/failure of capability tools so
# the verbose-failure override downstream fires only when the
# tool actually errored. ``.lower()`` mirrors the defensive
# check at ``has_tried_capability_tool`` so a model that emits
# ``Generate_Image`` doesn't slip past the tracker.
if any(tool_name.lower().startswith(_s) for _s in _SD_CAPABILITY_TOOLS):
capability_tool_last_succeeded = not (
isinstance(tool_result, dict)
and tool_result.get("status") in ("error", "denied")
)

# Result-based dedup: if this tool (query family) returns the same result
# it returned in a prior call, inject a correction so the agent stops looping.
_QUERY_TOOLS = (
Expand Down Expand Up @@ -3271,9 +3368,8 @@ def _process_query_impl(
r"i can.*create.*image",
r"when.*--sd",
]
_SD_TOOLS = ("generate_image",)
has_tried_capability_tool = any(
any(_tname.lower().startswith(_s) for _s in _SD_TOOLS)
any(_tname.lower().startswith(_s) for _s in _SD_CAPABILITY_TOOLS)
for _tname, _ in tool_call_log
)
is_capability_claim = any(
Expand Down Expand Up @@ -3360,7 +3456,18 @@ def _process_query_impl(
# Post-failure verbosity guard: when generate_image was called and
# failed, the LLM often apologises and explains "what it would have done"
# with prompt-engineering tips. Intercept and replace with a clean response.
if has_tried_capability_tool:
#
# Issue #1023: gate on the LATEST outcome of the capability tool.
# When generate_image succeeded and a *different* tool's parse
# error provoked a verbose apology, the override used to clobber
# the model's reply with a misleading "Image generation is not
# available" message even though the image was generated. Now
# the override fires only when the most recent capability call
# actually returned an error.
if (
has_tried_capability_tool
and capability_tool_last_succeeded is False
):
_SD_POST_FAILURE_VERBOSE = [
r"would have done",
r"what i would",
Expand Down
Loading
Loading