Skip to content

Commit d91890c

Browse files
committed
rebase?
1 parent f512d65 commit d91890c

49 files changed

Lines changed: 7738 additions & 1473 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Literal
2+
3+
from pydantic_ai.exceptions import UserError
4+
from pydantic_ai.models import ModelRequestParameters
5+
from pydantic_ai.settings import ModelSettings
6+
from pydantic_ai.tools import ToolDefinition
7+
8+
# Public alias used by model implementations when normalizing the user's
# `tool_choice` model setting (see `validate_tool_choice` in this module).
ToolChoiceValue = Literal['none', 'auto', 'required'] | list[str]
"""The validated tool_choice value: a mode string or a list of specific tool names."""
12+
def filter_tools_for_choice(
    tool_choice: ToolChoiceValue | None,
    function_tools: list[ToolDefinition],
    output_tools: list[ToolDefinition],
) -> list[ToolDefinition]:
    """Select the subset of tools to send to the API for a given tool_choice.

    Model implementations call this before building the request; providers with
    native tool filtering (e.g. Google's `allowed_function_names`, OpenAI's
    `allowed_tools`) may not need it.

    Args:
        tool_choice: The normalized value produced by `validate_tool_choice`.
        function_tools: The available function tools.
        output_tools: The available output tools (for structured output).

    Returns:
        The tools to send:
        - None or 'auto': every tool (function + output)
        - 'required': function tools only
        - 'none': output tools only
        - list[str]: exactly the named tools from either list (no auto-inclusion)
    """
    if tool_choice is None or tool_choice == 'auto':
        # Default behavior: the model may use any available tool.
        return function_tools + output_tools
    if tool_choice == 'required':
        # The model must call a function tool; output tools are withheld.
        return function_tools.copy()
    if tool_choice == 'none':
        # Function tools are withheld; only structured-output tools remain.
        return output_tools.copy()
    # list[str]: keep only explicitly named tools, preserving original order.
    allowed = set(tool_choice)
    return [tool for tool in function_tools + output_tools if tool.name in allowed]
46+
47+
48+
def validate_tool_choice(
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> ToolChoiceValue | None:
    """Validate and normalize the `tool_choice` model setting.

    Public helper for model implementations (including custom ones) to turn the
    raw user setting into a normalized `ToolChoiceValue`.

    Args:
        model_settings: The model settings containing tool_choice.
        model_request_parameters: The request parameters containing tool definitions.

    Returns:
        The normalized tool_choice value:
        - None when tool_choice was not set (provider default behavior applies)
        - 'none', 'auto', or 'required' for mode strings
        - list[str] for specific tool names, validated against the available
          function and output tools

    Raises:
        UserError: If a name in a list[str] tool_choice is not an available tool.
    """
    user_tool_choice = (model_settings or {}).get('tool_choice')

    if user_tool_choice is None:
        return None

    # Mode strings pass through unchanged.
    if user_tool_choice in ('none', 'auto', 'required'):
        return user_tool_choice

    if isinstance(user_tool_choice, list):
        # An empty list is equivalent to disallowing all tools.
        if not user_tool_choice:
            return 'none'
        all_tool_names = {t.name for t in model_request_parameters.function_tools} | {
            t.name for t in model_request_parameters.output_tools
        }
        invalid_names = set(user_tool_choice) - all_tool_names
        if invalid_names:
            raise UserError(
                f'Invalid tool names in `tool_choice`: {invalid_names}. Available tools: {all_tool_names or "none"}'
            )
        return list(user_tool_choice)

    return None  # pragma: no cover

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 104 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import io
4+
import warnings
45
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
56
from contextlib import asynccontextmanager
67
from dataclasses import dataclass, field, replace
@@ -42,7 +43,15 @@
4243
from ..providers.anthropic import AsyncAnthropicClient
4344
from ..settings import ModelSettings, merge_model_settings
4445
from ..tools import ToolDefinition
45-
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
46+
from . import (
47+
Model,
48+
ModelRequestParameters,
49+
StreamedResponse,
50+
check_allow_model_requests,
51+
download_item,
52+
get_user_agent,
53+
)
54+
from ._tool_choice import filter_tools_for_choice, validate_tool_choice
4655

4756
_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
4857
'end_turn': 'stop',
@@ -386,11 +395,9 @@ async def _messages_create(
386395
This is the last step before sending the request to the API.
387396
Most preprocessing has happened in `prepare_request()`.
388397
"""
389-
tools = self._get_tools(model_request_parameters, model_settings)
398+
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
390399
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)
391400

392-
tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
393-
394401
system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
395402
self._limit_cache_points(system_prompt, anthropic_messages, tools)
396403
output_format = self._native_output_format(model_request_parameters)
@@ -474,11 +481,9 @@ async def _messages_count_tokens(
474481
raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')
475482

476483
# standalone function to make it easier to override
477-
tools = self._get_tools(model_request_parameters, model_settings)
484+
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
478485
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)
479486

480-
tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
481-
482487
system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
483488
self._limit_cache_points(system_prompt, anthropic_messages, tools)
484489
output_format = self._native_output_format(model_request_parameters)
@@ -585,22 +590,6 @@ async def _process_streamed_response(
585590
_provider_url=self._provider.base_url,
586591
)
587592

588-
def _get_tools(
589-
self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings
590-
) -> list[BetaToolUnionParam]:
591-
tools: list[BetaToolUnionParam] = [
592-
self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()
593-
]
594-
595-
# Add cache_control to the last tool if enabled
596-
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
597-
# If True, use '5m'; otherwise use the specified ttl value
598-
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
599-
last_tool = tools[-1]
600-
last_tool['cache_control'] = self._build_cache_control(ttl)
601-
602-
return tools
603-
604593
def _add_builtin_tools(
605594
self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
606595
) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]:
@@ -664,26 +653,105 @@ def _add_builtin_tools(
664653
)
665654
return tools, mcp_servers, beta_features
666655

667-
def _infer_tool_choice(  # noqa: C901
    self,
    model_settings: AnthropicModelSettings,
    model_request_parameters: ModelRequestParameters,
) -> tuple[list[BetaToolUnionParam], BetaToolChoiceParam | None]:
    """Determine which tools to send and the API tool_choice value.

    Filters the request's tool definitions according to the user's normalized
    `tool_choice` setting, maps them to Anthropic's format, and translates the
    setting into Anthropic's `tool_choice` parameter (`auto`/`any`/`tool`/`none`).

    Returns:
        A tuple of (filtered_tools, tool_choice).
    """
    # Anthropic rejects forced tool use ('any' / specific tool) when extended
    # thinking is enabled, so those branches raise below.
    thinking_enabled = model_settings.get('anthropic_thinking') is not None
    function_tools = model_request_parameters.function_tools
    output_tools = model_request_parameters.output_tools

    tool_choice_value = validate_tool_choice(model_settings, model_request_parameters)
    tool_defs_to_send = filter_tools_for_choice(tool_choice_value, function_tools, output_tools)

    # No tools survive filtering: send none and omit tool_choice entirely.
    if not tool_defs_to_send:
        return [], None

    # Map ToolDefinitions to Anthropic format
    tools: list[BetaToolUnionParam] = [self._map_tool_definition(t) for t in tool_defs_to_send]

    # Add cache_control to the last tool if enabled
    if cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions'):
        # If True, use the '5m' default; otherwise use the specified ttl value.
        ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
        last_tool = tools[-1]
        last_tool['cache_control'] = self._build_cache_control(ttl)

    # Check for parallel_tool_calls setting once
    disable_parallel: bool | None = None
    if 'parallel_tool_calls' in model_settings:
        disable_parallel = not model_settings['parallel_tool_calls']

    tool_choice: BetaToolChoiceParam

    if tool_choice_value is None or tool_choice_value == 'auto':
        # When text output is not allowed, a tool call is mandatory ('any').
        if not model_request_parameters.allow_text_output:
            tool_choice = {'type': 'any'}
        else:
            tool_choice = {'type': 'auto'}
        if disable_parallel is not None:
            tool_choice['disable_parallel_tool_use'] = disable_parallel

    elif tool_choice_value == 'required':
        if thinking_enabled:
            raise UserError(
                "Anthropic does not support `tool_choice='required'` with thinking mode. "
                'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
            )
        tool_choice = {'type': 'any'}
        if disable_parallel is not None:
            tool_choice['disable_parallel_tool_use'] = disable_parallel

    elif tool_choice_value == 'none':
        # 'none' still sends output tools (see filter_tools_for_choice), so the
        # API-level choice depends on how many output tools exist.
        if len(output_tools) == 0:
            assert model_request_parameters.allow_text_output, (
                'Internal error: tool_choice=none with no output tools but text output not allowed'
            )
            # BetaToolChoiceNoneParam doesn't support disable_parallel_tool_use
            tool_choice = {'type': 'none'}
        elif len(output_tools) == 1:
            # Force the single output tool so structured output is still produced.
            tool_choice = {'type': 'tool', 'name': output_tools[0].name}
            if disable_parallel is not None:
                tool_choice['disable_parallel_tool_use'] = disable_parallel
        else:
            warnings.warn(
                'Anthropic only supports forcing a single tool. '
                f"Falling back to '{'auto' if model_request_parameters.allow_text_output else 'any'}' "
                'for multiple output tools.'
            )
            if not model_request_parameters.allow_text_output:
                tool_choice = {'type': 'any'}
            else:
                tool_choice = {'type': 'auto'}
            if disable_parallel is not None:
                tool_choice['disable_parallel_tool_use'] = disable_parallel

    elif isinstance(tool_choice_value, list):
        # Specific tool names
        if thinking_enabled:
            raise UserError(
                'Anthropic does not support forcing specific tools with thinking mode. '
                'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
            )
        if len(tool_choice_value) == 1:
            tool_choice = {'type': 'tool', 'name': tool_choice_value[0]}
        else:
            warnings.warn(
                "Anthropic only supports forcing a single tool. Falling back to 'any' for multiple specific tools."
            )
            tool_choice = {'type': 'any'}
        if disable_parallel is not None:
            tool_choice['disable_parallel_tool_use'] = disable_parallel

    else:
        # Exhaustiveness check: all ToolChoiceValue variants handled above.
        assert_never(tool_choice_value)

    return tools, tool_choice
687755

688756
async def _map_message( # noqa: C901
689757
self,
@@ -888,9 +956,10 @@ async def _map_message( # noqa: C901
888956
system_prompt_parts.append(instructions)
889957
system_prompt = '\n\n'.join(system_prompt_parts)
890958

959+
ttl: Literal['5m', '1h']
891960
# Add cache_control to the last message content if anthropic_cache_messages is enabled
892961
if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')):
893-
ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages
962+
ttl = '5m' if cache_messages is True else cache_messages
894963
m = anthropic_messages[-1]
895964
content = m['content']
896965
if isinstance(content, str):
@@ -910,7 +979,7 @@ async def _map_message( # noqa: C901
910979
# If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
911980
if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
912981
# If True, use '5m'; otherwise use the specified ttl value
913-
ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions
982+
ttl = '5m' if cache_instructions is True else cache_instructions
914983
system_prompt_blocks = [
915984
BetaTextBlockParam(
916985
type='text',

0 commit comments

Comments
 (0)