diff --git a/examples/async-tools.py b/examples/async-tools.py
index 16e123de..9fb3d8a1 100644
--- a/examples/async-tools.py
+++ b/examples/async-tools.py
@@ -53,9 +53,24 @@ def subtract_two_numbers(a: int, b: int) -> int:
 
 async def main():
   client = ollama.AsyncClient()
 
+  # --- Auto tool execution (max_tool_calls) ---
+  # When max_tool_calls is set, tools are executed automatically in a loop.
+  # The model calls tools, results are fed back, and the final response is returned.
+  print('\n--- Auto tool execution ---')
+  auto_response: ChatResponse = await client.chat(
+    'qwen3.5:4b',
+    messages=messages,
+    tools=[add_two_numbers, subtract_two_numbers_tool],
+    max_tool_calls=10,
+  )
+  print('Response:', auto_response.message.content)
+
+  # --- Manual tool handling ---
+  # Without max_tool_calls, tool calls are returned for you to handle manually.
+  print('\n--- Manual tool handling ---')
   response: ChatResponse = await client.chat(
-    'llama3.1',
+    'qwen3.5:4b',
     messages=messages,
     tools=[add_two_numbers, subtract_two_numbers_tool],
   )
@@ -79,7 +94,7 @@ async def main():
       messages.append({'role': 'tool', 'content': str(output), 'tool_name': tool.function.name})
 
       # Get final response from model with function outputs
-      final_response = await client.chat('llama3.1', messages=messages)
+      final_response = await client.chat('qwen3.5:4b', messages=messages)
       print('Final response:', final_response.message.content)
 
   else:
diff --git a/ollama/_client.py b/ollama/_client.py
index 18cb0fb4..45360232 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -319,6 +319,7 @@ def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> ChatResponse: ...
 
   @overload
@@ -335,6 +336,7 @@ def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Iterator[ChatResponse]: ...
 
   def chat(
@@ -350,6 +352,7 @@ def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Union[ChatResponse, Iterator[ChatResponse]]:
     """
     Create a chat response using the requested model.
@@ -361,6 +364,9 @@ def chat(
         For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
       stream: Whether to stream the response.
       format: The format of the response.
+      max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
+        up to this many iterations, raising `RuntimeError` if the model still requests tools.
+        None (default) or a non-positive value disables it; ignored when `stream` is True.
 
     Example:
       def add_two_numbers(a: int, b: int) -> int:
@@ -376,7 +382,11 @@ def add_two_numbers(a: int, b: int) -> int:
         '''
         return a + b
 
-      client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
+      # Manual tool handling:
+      client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
+
+      # Auto tool execution (max 10 iterations):
+      client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
 
     Raises `RequestError` if a model is not provided.
 
@@ -384,24 +394,61 @@ def add_two_numbers(a: int, b: int) -> int:
 
     Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
     """
-    return self._request(
-      ChatResponse,
-      'POST',
-      '/api/chat',
-      json=ChatRequest(
-        model=model,
-        messages=list(_copy_messages(messages)),
-        tools=list(_copy_tools(tools)),
+    # MARK: standard path (no auto tool execution)
+    # Streaming and non-positive/unset max_tool_calls fall back to a single plain request.
+    if stream or max_tool_calls is None or max_tool_calls <= 0:
+      return self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(messages)),
+          tools=list(_copy_tools(tools)),
+          stream=stream,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
         stream=stream,
-        think=think,
-        logprobs=logprobs,
-        top_logprobs=top_logprobs,
-        format=format,
-        options=options,
-        keep_alive=keep_alive,
-      ).model_dump(exclude_none=True),
-      stream=stream,
-    )
+      )
+
+    # MARK: auto tool execution loop
+    tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
+    # Work on a copy so the caller's message list is never mutated.
+    msgs = list(messages or [])
+
+    for _ in range(max_tool_calls):
+      response = self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(msgs)),
+          tools=list(_copy_tools(tools)),
+          stream=False,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=False,
+      )
+      # No tool calls means the model produced its final answer.
+      if not response.message.tool_calls:
+        return response
+      msgs.append(response.message)
+      for tc in response.message.tool_calls:
+        output = _exec_tool(tool_map, tc)
+        msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
+
+    raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
 
   def embed(
     self,
@@ -951,6 +998,7 @@ async def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> ChatResponse: ...
 
   @overload
@@ -967,6 +1015,7 @@ async def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> AsyncIterator[ChatResponse]: ...
 
   async def chat(
@@ -982,6 +1031,7 @@ async def chat(
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
     """
     Create a chat response using the requested model.
@@ -993,6 +1043,9 @@ async def chat(
         For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
       stream: Whether to stream the response.
       format: The format of the response.
+      max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
+        up to this many iterations, raising `RuntimeError` if the model still requests tools.
+        None (default) or a non-positive value disables it; ignored when `stream` is True.
 
     Example:
       def add_two_numbers(a: int, b: int) -> int:
@@ -1008,7 +1061,11 @@ def add_two_numbers(a: int, b: int) -> int:
         '''
         return a + b
 
-      await client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
+      # Manual tool handling:
+      await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
+
+      # Auto tool execution (max 10 iterations):
+      await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
 
     Raises `RequestError` if a model is not provided.
 
@@ -1016,25 +1073,61 @@ def add_two_numbers(a: int, b: int) -> int:
 
     Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
     """
-
-    return await self._request(
-      ChatResponse,
-      'POST',
-      '/api/chat',
-      json=ChatRequest(
-        model=model,
-        messages=list(_copy_messages(messages)),
-        tools=list(_copy_tools(tools)),
+    # MARK: standard path (no auto tool execution)
+    # Streaming and non-positive/unset max_tool_calls fall back to a single plain request.
+    if stream or max_tool_calls is None or max_tool_calls <= 0:
+      return await self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(messages)),
+          tools=list(_copy_tools(tools)),
+          stream=stream,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
         stream=stream,
-        think=think,
-        logprobs=logprobs,
-        top_logprobs=top_logprobs,
-        format=format,
-        options=options,
-        keep_alive=keep_alive,
-      ).model_dump(exclude_none=True),
-      stream=stream,
-    )
+      )
+
+    # MARK: auto tool execution loop
+    tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
+    # Work on a copy so the caller's message list is never mutated.
+    msgs = list(messages or [])
+
+    for _ in range(max_tool_calls):
+      response = await self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(msgs)),
+          tools=list(_copy_tools(tools)),
+          stream=False,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=False,
+      )
+      # No tool calls means the model produced its final answer.
+      if not response.message.tool_calls:
+        return response
+      msgs.append(response.message)
+      for tc in response.message.tool_calls:
+        output = await _aexec_tool(tool_map, tc)
+        msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
+
+    raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
 
   async def embed(
     self,
@@ -1330,6 +1423,43 @@ def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable
     yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
 
 
+def _tool_output_to_str(output: Any) -> str:
+  """Serialize a tool's return value to the string sent back to the model."""
+  # Strings pass through untouched; everything else is JSON-encoded, with str() as
+  # the fallback for values json cannot represent.
+  return output if isinstance(output, str) else json.dumps(output, default=str)
+
+
+def _exec_tool(tool_map: dict, tc: Message.ToolCall) -> str:
+  """
+  Execute a tool call and return its result as a string.
+
+  An unknown tool name yields a JSON error payload so the model can recover.
+  Exceptions raised by the tool itself propagate to the caller.
+  """
+  fn = tool_map.get(tc.function.name)
+  if fn is None:
+    return json.dumps({'error': f'Tool {tc.function.name} not found'})
+  return _tool_output_to_str(fn(**tc.function.arguments))
+
+
+async def _aexec_tool(tool_map: dict, tc: Message.ToolCall) -> str:
+  """
+  Async counterpart of `_exec_tool`.
+
+  Awaits the tool's result when it is awaitable, so `async def` tools work with `AsyncClient`.
+  """
+  import inspect
+
+  fn = tool_map.get(tc.function.name)
+  if fn is None:
+    return json.dumps({'error': f'Tool {tc.function.name} not found'})
+  output = fn(**tc.function.arguments)
+  if inspect.isawaitable(output):
+    output = await output
+  return _tool_output_to_str(output)
+
+
 def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
   if isinstance(s, (str, Path)):
     try: