Skip to content

Commit 0475bbc

Browse files
committed
Allow two consecutive identical tool calls
Signed-off-by: Nikola Forró <[email protected]>
1 parent c0bb481 commit 0475bbc

File tree

6 files changed

+26
-5
lines changed

6 files changed

+26
-5
lines changed

agents/backport_agent.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,14 @@
5959
GitPreparePackageSources,
6060
)
6161
from triage_agent import BackportData, ErrorData
62-
from utils import check_subprocess, get_agent_execution_config, get_chat_model, mcp_tools, render_prompt
62+
from utils import (
63+
check_subprocess,
64+
get_agent_execution_config,
65+
get_chat_model,
66+
get_tool_call_checker_config,
67+
mcp_tools,
68+
render_prompt,
69+
)
6370
from specfile import Specfile
6471

6572
logger = logging.getLogger(__name__)
@@ -140,6 +147,7 @@ def create_backport_agent(_: list[Tool], local_tool_options: dict[str, Any]) ->
140147
return RequirementAgent(
141148
name="BackportAgent",
142149
llm=get_chat_model(),
150+
tool_call_checker=get_tool_call_checker_config(),
143151
tools=[
144152
ThinkTool(),
145153
DuckDuckGoSearchTool(),

agents/build_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ViewTool,
2323
SearchTextTool,
2424
)
25-
from utils import get_chat_model
25+
from utils import get_chat_model, get_tool_call_checker_config
2626

2727

2828
def get_instructions() -> str:
@@ -55,6 +55,7 @@ def create_build_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]
5555
return RequirementAgent(
5656
name="BuildAgent",
5757
llm=get_chat_model(),
58+
tool_call_checker=get_tool_call_checker_config(),
5859
tools=[
5960
ThinkTool(),
6061
DuckDuckGoSearchTool(),

agents/log_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ViewTool,
2424
SearchTextTool,
2525
)
26-
from utils import get_chat_model
26+
from utils import get_chat_model, get_tool_call_checker_config
2727

2828

2929
def get_instructions() -> str:
@@ -69,6 +69,7 @@ def create_log_agent(_: list[Tool], local_tool_options: dict[str, Any]) -> Requi
6969
return RequirementAgent(
7070
name="LogAgent",
7171
llm=get_chat_model(),
72+
tool_call_checker=get_tool_call_checker_config(),
7273
tools=[
7374
ThinkTool(),
7475
DuckDuckGoSearchTool(),

agents/rebase_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
SearchTextTool,
5252
)
5353
from triage_agent import RebaseData, ErrorData
54-
from utils import get_agent_execution_config, get_chat_model, mcp_tools, render_prompt
54+
from utils import get_agent_execution_config, get_chat_model, get_tool_call_checker_config, mcp_tools, render_prompt
5555

5656
logger = logging.getLogger(__name__)
5757

@@ -140,6 +140,7 @@ def create_rebase_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any
140140
return RequirementAgent(
141141
name="RebaseAgent",
142142
llm=get_chat_model(),
143+
tool_call_checker=get_tool_call_checker_config(),
143144
tools=[
144145
ThinkTool(),
145146
DuckDuckGoSearchTool(),

agents/triage_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from tools.commands import RunShellCommandTool
4545
from tools.patch_validator import PatchValidatorTool
4646
from tools.version_mapper import VersionMapperTool
47-
from utils import get_agent_execution_config, get_chat_model, mcp_tools, run_tool
47+
from utils import get_agent_execution_config, get_chat_model, get_tool_call_checker_config, mcp_tools, run_tool
4848

4949
logger = logging.getLogger(__name__)
5050

@@ -297,6 +297,7 @@ async def run_workflow(jira_issue):
297297
triage_agent = RequirementAgent(
298298
name="TriageAgent",
299299
llm=get_chat_model(),
300+
tool_call_checker=get_tool_call_checker_config(),
300301
tools=[ThinkTool(), RunShellCommandTool(), PatchValidatorTool(), VersionMapperTool()]
301302
+ [t for t in gateway_tools if t.name in ["get_jira_details", "set_jira_fields"]],
302303
memory=UnconstrainedMemory(),

agents/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import Any, AsyncGenerator, Awaitable, Callable, TypeVar, Tuple
1010

11+
from beeai_framework.agents.tool_calling.utils import ToolCallCheckerConfig
1112
from beeai_framework.backend import ChatModel, ChatModelParameters
1213
from mcp import ClientSession
1314
from mcp.client.sse import sse_client
@@ -51,6 +52,14 @@ def get_agent_execution_config() -> dict[str, int]:
5152
max_iterations=int(os.getenv("BEEAI_MAX_ITERATIONS", 255)),
5253
)
5354

55+
def get_tool_call_checker_config() -> ToolCallCheckerConfig:
56+
return ToolCallCheckerConfig(
57+
# allow two consecutive identical tool calls
58+
max_strike_length=2,
59+
max_total_occurrences=5,
60+
window_size=10,
61+
)
62+
5463

5564
def render_prompt(template: str, input: BaseModel) -> str:
5665
"""Renders a prompt template with the specified input, according to its schema."""

0 commit comments

Comments
 (0)