Skip to content

Commit be35a50

Browse files
committed
max retry for auto toolcalls
1 parent dcb5e57 commit be35a50

File tree

6 files changed

+203
-14
lines changed

6 files changed

+203
-14
lines changed

sdk/ai/azure-ai-projects/azure/ai/projects/aio/operations/_patch.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import logging
1414
import os
1515
import time
16+
import json
1617
from pathlib import Path
1718
from typing import (
1819
IO,
@@ -664,6 +665,7 @@ class AgentsOperations(AgentsOperationsGenerated):
664665
def __init__(self, *args, **kwargs) -> None:
665666
super().__init__(*args, **kwargs)
666667
self._function_tool = _models.AsyncFunctionTool(set())
668+
self._function_tool_max_retry = 10
667669

668670
# pylint: disable=arguments-differ
669671
@overload
@@ -1622,6 +1624,7 @@ async def create_and_process_run(
16221624
)
16231625

16241626
# Monitor and process the run status
1627+
current_retry = 0
16251628
while run.status in [
16261629
RunStatus.QUEUED,
16271630
RunStatus.IN_PROGRESS,
@@ -1643,6 +1646,17 @@ async def create_and_process_run(
16431646
toolset.add(self._function_tool)
16441647
tool_outputs = await toolset.execute_tool_calls(tool_calls)
16451648

1649+
if self.has_errors_in_toolcalls_output(tool_outputs):
1650+
if current_retry >= self._function_tool_max_retry:
1651+
logging.warning(
1652+
f"Tool outputs contain errors - reaching max retry {self._function_tool_max_retry}"
1653+
)
1654+
self.cancel_run(thread_id=thread_id, run_id=run.id)
1655+
break
1656+
else:
1657+
logging.warning(f"Tool outputs contain errors - retrying")
1658+
current_retry += 1
1659+
16461660
logging.info("Tool outputs: %s", tool_outputs)
16471661
if tool_outputs:
16481662
await self.submit_tool_outputs_to_run(
@@ -1653,6 +1667,25 @@ async def create_and_process_run(
16531667

16541668
return run
16551669

1670+
def has_errors_in_toolcalls_output(self, tool_outputs: List[Dict]) -> bool:
1671+
"""
1672+
Check if any tool output contains an error.
1673+
1674+
:param List[Dict] tool_outputs: A list of tool outputs to check.
1675+
:return: True if any output contains an error, False otherwise.
1676+
:rtype: bool
1677+
"""
1678+
for tool_output in tool_outputs:
1679+
output = tool_output.get("output")
1680+
if isinstance(output, str):
1681+
try:
1682+
output_json = json.loads(output)
1683+
if "error" in output_json:
1684+
return True
1685+
except json.JSONDecodeError:
1686+
continue
1687+
return False
1688+
16561689
@overload
16571690
async def create_stream(
16581691
self,
@@ -3128,25 +3161,31 @@ async def delete_agent(self, agent_id: str, **kwargs: Any) -> _models.AgentDelet
31283161
return await super().delete_agent(agent_id, **kwargs)
31293162

31303163
@overload
3131-
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]]) -> None:
3164+
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]], max_retry: int = 10) -> None:
31323165
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31333166
If this is not set, functions must be called manually.
3167+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3168+
function call or figure out the answer with its knowledge.
31343169
:keyword functions: A set of callable functions to be used as tools.
31353170
:type functions: Set[Callable[..., Any]]
31363171
"""
31373172

31383173
@overload
3139-
def enable_auto_function_calls(self, *, function_tool: _models.AsyncFunctionTool) -> None:
3174+
def enable_auto_function_calls(self, *, function_tool: _models.AsyncFunctionTool, max_retry: int = 10) -> None:
31403175
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31413176
If this is not set, functions must be called manually.
3177+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3178+
function call or figure out the answer with its knowledge.
31423179
:keyword function_tool: An AsyncFunctionTool object representing the tool to be used.
31433180
:type function_tool: Optional[_models.AsyncFunctionTool]
31443181
"""
31453182

31463183
@overload
3147-
def enable_auto_function_calls(self, *, toolset: _models.AsyncToolSet) -> None:
3184+
def enable_auto_function_calls(self, *, toolset: _models.AsyncToolSet, max_retry: int = 10) -> None:
31483185
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31493186
If this is not set, functions must be called manually.
3187+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3188+
function call or figure out the answer with its knowledge.
31503189
:keyword toolset: An AsyncToolSet object representing the set of tools to be used.
31513190
:type toolset: Optional[_models.AsyncToolSet]
31523191
"""
@@ -3157,9 +3196,12 @@ def enable_auto_function_calls(
31573196
functions: Optional[Set[Callable[..., Any]]] = None,
31583197
function_tool: Optional[_models.AsyncFunctionTool] = None,
31593198
toolset: Optional[_models.AsyncToolSet] = None,
3199+
max_retry: int = 10,
31603200
) -> None:
31613201
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31623202
If this is not set, functions must be called manually.
3203+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3204+
function call or figure out the answer with its knowledge.
31633205
:keyword functions: A set of callable functions to be used as tools.
31643206
:type functions: Set[Callable[..., Any]]
31653207
:keyword function_tool: An AsyncFunctionTool object representing the tool to be used.
@@ -3175,6 +3217,8 @@ def enable_auto_function_calls(
31753217
tool = toolset.get_tool(_models.AsyncFunctionTool)
31763218
self._function_tool = tool
31773219

3220+
self._function_tool_max_retry = max_retry
3221+
31783222

31793223
class _SyncCredentialWrapper(TokenCredential):
31803224
"""

sdk/ai/azure-ai-projects/azure/ai/projects/operations/_patch.py

+59-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import sys
1414
import time
15+
import json
1516
from pathlib import Path
1617
from typing import (
1718
IO,
@@ -844,6 +845,7 @@ class AgentsOperations(AgentsOperationsGenerated):
844845
def __init__(self, *args, **kwargs) -> None:
845846
super().__init__(*args, **kwargs)
846847
self._function_tool = _models.FunctionTool(set())
848+
self._function_tool_max_retry = 10
847849

848850
# pylint: disable=arguments-differ
849851
@overload
@@ -1803,6 +1805,7 @@ def create_and_process_run(
18031805
)
18041806

18051807
# Monitor and process the run status
1808+
current_retry = 0
18061809
while run.status in [
18071810
RunStatus.QUEUED,
18081811
RunStatus.IN_PROGRESS,
@@ -1826,14 +1829,47 @@ def create_and_process_run(
18261829
toolset.add(self._function_tool)
18271830
tool_outputs = toolset.execute_tool_calls(tool_calls)
18281831

1832+
if self.has_errors_in_toolcalls_output(tool_outputs):
1833+
if current_retry >= self._function_tool_max_retry:
1834+
logging.warning(
1835+
f"Tool outputs contain errors - reaching max retry {self._function_tool_max_retry}"
1836+
)
1837+
self.cancel_run(thread_id=thread_id, run_id=run.id)
1838+
break
1839+
else:
1840+
logging.warning(f"Tool outputs contain errors - retrying")
1841+
current_retry += 1
1842+
18291843
logging.info("Tool outputs: %s", tool_outputs)
18301844
if tool_outputs:
1831-
self.submit_tool_outputs_to_run(thread_id=thread_id, run_id=run.id, tool_outputs=tool_outputs)
1845+
run2 = self.submit_tool_outputs_to_run(
1846+
thread_id=thread_id, run_id=run.id, tool_outputs=tool_outputs
1847+
)
1848+
logging.info("Tool outputs submitted to run: %s", run2.id)
18321849

18331850
logging.info("Current run status: %s", run.status)
18341851

18351852
return run
18361853

1854+
def has_errors_in_toolcalls_output(self, tool_outputs: List[Dict]) -> bool:
1855+
"""
1856+
Check if any tool output contains an error.
1857+
1858+
:param List[Dict] tool_outputs: A list of tool outputs to check.
1859+
:return: True if any output contains an error, False otherwise.
1860+
:rtype: bool
1861+
"""
1862+
for tool_output in tool_outputs:
1863+
output = tool_output.get("output")
1864+
if isinstance(output, str):
1865+
try:
1866+
output_json = json.loads(output)
1867+
if "error" in output_json:
1868+
return True
1869+
except json.JSONDecodeError:
1870+
continue
1871+
return False
1872+
18371873
@overload
18381874
def create_stream(
18391875
self,
@@ -3309,27 +3345,39 @@ def delete_agent(self, agent_id: str, **kwargs: Any) -> _models.AgentDeletionSta
33093345
return super().delete_agent(agent_id, **kwargs)
33103346

33113347
@overload
3312-
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]]) -> None:
3348+
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]], max_retry: int = 10) -> None:
33133349
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
33143350
If this is not set, functions must be called manually.
3351+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3352+
function call or figure out the answer with its knowledge.
33153353
:keyword functions: A set of callable functions to be used as tools.
33163354
:type functions: Set[Callable[..., Any]]
3355+
:keyword max_retry: Maximum number of errors allowed and retry per run or stream. Default value is 10.
3356+
:type max_retry: int
33173357
"""
33183358

33193359
@overload
3320-
def enable_auto_function_calls(self, *, function_tool: _models.FunctionTool) -> None:
3360+
def enable_auto_function_calls(self, *, function_tool: _models.FunctionTool, max_retry: int = 10) -> None:
33213361
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
33223362
If this is not set, functions must be called manually.
3363+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3364+
function call or figure out the answer with its knowledge.
33233365
:keyword function_tool: A FunctionTool object representing the tool to be used.
33243366
:type function_tool: Optional[_models.FunctionTool]
3367+
:keyword max_retry: Maximum number of errors allowed and retry per run or stream. Default value is 10.
3368+
:type max_retry: int
33253369
"""
33263370

33273371
@overload
3328-
def enable_auto_function_calls(self, *, toolset: _models.ToolSet) -> None:
3372+
def enable_auto_function_calls(self, *, toolset: _models.ToolSet, max_retry: int = 10) -> None:
33293373
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
33303374
If this is not set, functions must be called manually.
3375+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3376+
function call or figure out the answer with its knowledge.
33313377
:keyword toolset: A ToolSet object representing the set of tools to be used.
33323378
:type toolset: Optional[_models.ToolSet]
3379+
:keyword max_retry: Maximum number of errors allowed and retry per run or stream. Default value is 10.
3380+
:type max_retry: int
33333381
"""
33343382

33353383
@distributed_trace
@@ -3339,15 +3387,20 @@ def enable_auto_function_calls(
33393387
functions: Optional[Set[Callable[..., Any]]] = None,
33403388
function_tool: Optional[_models.FunctionTool] = None,
33413389
toolset: Optional[_models.ToolSet] = None,
3390+
max_retry: int = 10,
33423391
) -> None:
33433392
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
33443393
If this is not set, functions must be called manually.
3394+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3395+
function call or figure out the answer with its knowledge.
33453396
:keyword functions: A set of callable functions to be used as tools.
33463397
:type functions: Set[Callable[..., Any]]
33473398
:keyword function_tool: A FunctionTool object representing the tool to be used.
33483399
:type function_tool: Optional[_models.FunctionTool]
33493400
:keyword toolset: A ToolSet object representing the set of tools to be used.
33503401
:type toolset: Optional[_models.ToolSet]
3402+
:keyword max_retry: Maximum number of errors allowed and retry per run or stream. Default value is 10.
3403+
:type max_retry: int
33513404
"""
33523405
if functions:
33533406
self._function_tool = _models.FunctionTool(functions)
@@ -3357,6 +3410,8 @@ def enable_auto_function_calls(
33573410
tool = toolset.get_tool(_models.FunctionTool)
33583411
self._function_tool = tool
33593412

3413+
self._function_tool_max_retry = max_retry
3414+
33603415

33613416
__all__: List[str] = [
33623417
"AgentsOperations",

sdk/ai/azure-ai-projects/samples/agents/sample_agents_run_with_toolset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
toolset.add(code_interpreter)
4646

4747
# To enable tool calls executed automatically
48-
project_client.agents.enable_auto_function_calls(toolset=toolset)
48+
project_client.agents.enable_auto_function_calls(toolset=toolset, max_retry=3)
4949

5050
agent = project_client.agents.create_agent(
5151
model=os.environ["MODEL_DEPLOYMENT_NAME"],

sdk/ai/azure-ai-projects/samples/agents/user_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def fetch_current_datetime(format: Optional[str] = None) -> str:
2828
time_format = "%Y-%m-%d %H:%M:%S"
2929

3030
time_json = json.dumps({"current_time": current_time.strftime(time_format)})
31+
raise ValueError("Just a minute")
3132
return time_json
3233

3334

sdk/ai/azure-ai-projects/tests/agents/test_agent_operations.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ def function2():
5252
return "output from the second agent"
5353

5454

55+
def function_throw_exception():
56+
raise ValueError("Just a minute")
57+
58+
5559
class TestAgentsOperations:
5660
"""Tests for agent operations"""
5761

@@ -217,16 +221,17 @@ def _set_toolcalls(
217221
self, project_client: AgentsOperations, toolset1: Optional[ToolSet], toolset2: Optional[ToolSet]
218222
) -> None:
219223
"""Set the tool calls for the agent."""
224+
max_retry = 3
220225
if toolset1 and toolset2:
221226
function_in_toolset1 = set(toolset1.get_tool(tool_type=FunctionTool)._functions.values())
222227
function_in_toolset2 = set(toolset2.get_tool(tool_type=FunctionTool)._functions.values())
223228
function_tool = FunctionTool(function_in_toolset1)
224229
function_tool.add_functions(function_in_toolset2)
225-
project_client.enable_auto_function_calls(function_tool=function_tool)
230+
project_client.enable_auto_function_calls(function_tool=function_tool, max_retry=max_retry)
226231
elif toolset1:
227-
project_client.enable_auto_function_calls(toolset=toolset1)
232+
project_client.enable_auto_function_calls(toolset=toolset1, max_retry=max_retry)
228233
elif toolset2:
229-
project_client.enable_auto_function_calls(toolset=toolset2)
234+
project_client.enable_auto_function_calls(toolset=toolset2, max_retry=max_retry)
230235

231236
@patch("azure.ai.projects._patch.PipelineClient")
232237
@pytest.mark.parametrize(
@@ -296,6 +301,45 @@ def test_multiple_agents_create(
296301
project_client.agents.delete_agent(agent1.id)
297302
project_client.agents.delete_agent(agent2.id)
298303

304+
@patch("azure.ai.projects.operations._operations.AgentsOperations.cancel_run")
305+
@patch("azure.ai.projects._patch.PipelineClient")
306+
def test_auto_function_calls_retry(
307+
self,
308+
mock_pipeline_client_gen: MagicMock,
309+
mock_cancel_run: MagicMock,
310+
) -> None:
311+
"""Test azure function with toolset."""
312+
toolset = self.get_toolset("file_for_agent1", function_throw_exception)
313+
mock_response = MagicMock()
314+
mock_response.status_code = 200
315+
mock_response.json.side_effect = [
316+
self._get_agent_json("first", "123", toolset),
317+
self._get_run("run1", toolset), # create_run
318+
self._get_run("run2", toolset), # get_run
319+
self._get_run("run3", toolset), # get_run
320+
self._get_run("run4", toolset), # get_run
321+
self._get_run("run5", toolset), # get_run
322+
]
323+
mock_pipeline_response = MagicMock()
324+
mock_pipeline_response.http_response = mock_response
325+
mock_pipeline = MagicMock()
326+
mock_pipeline._pipeline.run.return_value = mock_pipeline_response
327+
mock_pipeline_client_gen.return_value = mock_pipeline
328+
project_client = self.get_mock_client()
329+
with project_client:
330+
# Check that pipelines are created as expected.
331+
self._set_toolcalls(project_client.agents, toolset, None)
332+
agent1 = project_client.agents.create_agent(
333+
model="gpt-4-1106-preview",
334+
name="first",
335+
instructions="You are a helpful assistant",
336+
toolset=toolset,
337+
)
338+
# Create run with new tool set, which also can be none.
339+
project_client.agents.create_and_process_run(thread_id="some_thread_id", agent_id=agent1.id)
340+
assert mock_cancel_run.call_count == 1
341+
assert project_client.agents.submit_tool_outputs_to_run.call_count == 3
342+
299343
@patch("azure.ai.projects._patch.PipelineClient")
300344
@pytest.mark.parametrize(
301345
"file_agent_1,file_agent_2",

0 commit comments

Comments
 (0)