Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,10 @@ class CachePoint:

Supported by:

* Anthropic (automatically omitted for Bedrock, as it does not support explicit TTL). See https://docs.claude.com/en/docs/build-with-claude/prompt-caching#1-hour-cache-duration for more information."""
* Anthropic — see https://docs.claude.com/en/docs/build-with-claude/prompt-caching#1-hour-cache-duration for more information.
* Amazon Bedrock (Converse API) — see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html for more information.

Note: TTL is automatically omitted when using `AnthropicModel` with `AsyncAnthropicBedrock`."""


UploadedFileProviderName: TypeAlias = Literal['anthropic', 'openai', 'google-gla', 'google-vertex', 'bedrock', 'xai']
Expand Down
76 changes: 50 additions & 26 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from dataclasses import dataclass, field
from datetime import datetime
from itertools import count
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast, overload
from urllib.parse import parse_qs, urlparse

import anyio.to_thread
from botocore.exceptions import ClientError
from typing_extensions import ParamSpec, assert_never
from typing_extensions import NotRequired, ParamSpec, TypedDict, assert_never

from pydantic_ai import (
AudioUrl,
Expand Down Expand Up @@ -87,6 +87,18 @@
_SUPPORTED_VIDEO_FORMATS = ('mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp')
_SUPPORTED_DOCUMENT_FORMATS = ('pdf', 'txt', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'md')

# Explicit TTL values accepted by Bedrock prompt caching ('5m' or '1h').
BedrockPromptCacheTTL = Literal['5m', '1h']
# Value accepted by the `bedrock_cache_*` model settings: `True` enables caching
# with the provider's default TTL, while a `BedrockPromptCacheTTL` string enables
# caching with that explicit TTL.
BedrockPromptCacheSetting: TypeAlias = bool | BedrockPromptCacheTTL


class _BedrockCachePointBlock(TypedDict):
    """Inner payload of a `cachePoint` block in a Bedrock Converse request.

    `ttl` is only present when the user requested an explicit TTL; omitting it
    lets Bedrock apply its default cache duration.
    """

    type: Literal['default']
    ttl: NotRequired[BedrockPromptCacheTTL]


class _BedrockCachePoint(TypedDict):
    """A `cachePoint` content block as sent to the Bedrock Converse API."""

    cachePoint: _BedrockCachePointBlock


def _make_image_block(format: str, source: DocumentSourceTypeDef) -> ContentBlockUnionTypeDef:
if format not in _SUPPORTED_IMAGE_FORMATS:
Expand Down Expand Up @@ -201,6 +213,7 @@ def _parse_s3_source(url: str) -> DocumentSourceTypeDef:

def _insert_cache_point_before_trailing_documents(
content: list[Any],
cache_point: _BedrockCachePoint,
*,
raise_if_cannot_insert: bool = False,
) -> bool:
Expand All @@ -212,6 +225,7 @@ def _insert_cache_point_before_trailing_documents(

Args:
content: The content list to modify in place.
cache_point: The cache point block to insert.
raise_if_cannot_insert: If True, raises UserError when cache point cannot be inserted
(e.g., when the message contains only documents/videos). If False, silently skips.

Expand All @@ -235,11 +249,11 @@ def _insert_cache_point_before_trailing_documents(
prev_block = content[trailing_start - 1]
if isinstance(prev_block, dict) and 'cachePoint' in prev_block:
return False
content.insert(trailing_start, {'cachePoint': {'type': 'default'}})
content.insert(trailing_start, cache_point)
return True
elif trailing_start is None:
# No trailing document/video content, append cache point at the end
content.append({'cachePoint': {'type': 'default'}})
content.append(cache_point)
return True
else:
# trailing_start == 0, can't insert at start
Expand Down Expand Up @@ -297,22 +311,22 @@ class BedrockModelSettings(ModelSettings, total=False):
See more about it on <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>.
"""

bedrock_cache_tool_definitions: bool
bedrock_cache_tool_definitions: BedrockPromptCacheSetting
"""Whether to add a cache point after the last tool definition.

When enabled, the last tool in the `tools` array will include a `cachePoint`, allowing Bedrock to cache tool
definitions and reduce costs for compatible models.
See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html for more information.
"""

bedrock_cache_instructions: bool
bedrock_cache_instructions: BedrockPromptCacheSetting
"""Whether to add a cache point after the system prompt blocks.

When enabled, an extra `cachePoint` is appended to the system prompt so Bedrock can cache system instructions.
See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html for more information.
"""

bedrock_cache_messages: bool
bedrock_cache_messages: BedrockPromptCacheSetting
"""Convenience setting to enable caching for the last user message.

When enabled, this automatically adds a cache point to the last content block
Expand Down Expand Up @@ -665,12 +679,9 @@ def _map_tool_config(
return None

profile = BedrockModelProfile.from_profile(self.profile)
if (
model_settings
and model_settings.get('bedrock_cache_tool_definitions')
and profile.bedrock_supports_tool_caching
):
tools.append({'cachePoint': {'type': 'default'}})
if cache_tool_definitions := (model_settings or {}).get('bedrock_cache_tool_definitions'):
if profile.bedrock_supports_tool_caching:
tools.append(cast('ToolTypeDef', self._get_cache_point(cache_tool_definitions)))

tool_choice: ToolChoiceTypeDef
if not model_request_parameters.allow_text_output:
Expand Down Expand Up @@ -706,9 +717,7 @@ async def _map_messages( # noqa: C901
if part.content: # pragma: no branch
system_prompt.append({'text': part.content})
elif isinstance(part, UserPromptPart):
bedrock_messages.extend(
await self._map_user_prompt(part, document_count, profile.bedrock_supports_prompt_caching)
)
bedrock_messages.extend(await self._map_user_prompt(part, document_count, profile))
elif isinstance(part, ToolReturnPart):
assert part.tool_call_id is not None
bedrock_messages.append(
Expand Down Expand Up @@ -827,14 +836,18 @@ async def _map_messages( # noqa: C901
if instructions := self._get_instructions(messages, model_request_parameters):
system_prompt.append({'text': instructions})

if system_prompt and settings.get('bedrock_cache_instructions') and profile.bedrock_supports_prompt_caching:
system_prompt.append({'cachePoint': {'type': 'default'}})
if system_prompt and (cache_instructions := settings.get('bedrock_cache_instructions')):
if profile.bedrock_supports_prompt_caching:
system_prompt.append(cast('SystemContentBlockTypeDef', self._get_cache_point(cache_instructions)))

if processed_messages and settings.get('bedrock_cache_messages') and profile.bedrock_supports_prompt_caching:
last_user_content = self._get_last_user_message_content(processed_messages)
if last_user_content is not None:
# Note: _get_last_user_message_content ensures content doesn't already end with a cachePoint.
_insert_cache_point_before_trailing_documents(last_user_content)
if processed_messages and (cache_messages := settings.get('bedrock_cache_messages')):
if profile.bedrock_supports_prompt_caching:
last_user_content = self._get_last_user_message_content(processed_messages)
if last_user_content is not None:
# Note: `_get_last_user_message_content` ensures content doesn't already end with a `cachePoint`.
_insert_cache_point_before_trailing_documents(
last_user_content, self._get_cache_point(cache_messages)
)

return system_prompt, processed_messages

Expand Down Expand Up @@ -868,7 +881,7 @@ async def _map_user_prompt( # noqa: C901
self,
part: UserPromptPart,
document_count: Iterator[int],
supports_prompt_caching: bool,
profile: BedrockModelProfile,
) -> list[MessageUnionTypeDef]:
content: list[ContentBlockUnionTypeDef] = []
if isinstance(part.content, str):
Expand Down Expand Up @@ -938,15 +951,19 @@ async def _map_user_prompt( # noqa: C901
else:
content.append(_make_document_block(f'Document {next(document_count)}', format, source))
elif isinstance(item, CachePoint):
if not supports_prompt_caching:
if not profile.bedrock_supports_prompt_caching:
# Silently skip CachePoint for models that don't support prompt caching
continue
if not content or 'cachePoint' in content[-1]:
raise UserError(
'CachePoint cannot be the first content in a user message - there must be previous content to cache when using Bedrock. '
'To cache system instructions or tool definitions, use the `bedrock_cache_instructions` or `bedrock_cache_tool_definitions` settings instead.'
)
_insert_cache_point_before_trailing_documents(content, raise_if_cannot_insert=True)
_insert_cache_point_before_trailing_documents(
content,
BedrockConverseModel._get_cache_point(item.ttl),
raise_if_cannot_insert=True,
)
else:
assert_never(item)
return [{'role': 'user', 'content': content}]
Expand All @@ -957,6 +974,13 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
}

@staticmethod
def _get_cache_point(cache_setting: BedrockPromptCacheSetting) -> _BedrockCachePoint:
    """Build a Bedrock `cachePoint` content block from a cache setting.

    A boolean setting yields a bare default cache point; a TTL string
    (`'5m'` or `'1h'`) is forwarded as the block's explicit `ttl`.
    """
    if isinstance(cache_setting, str):
        block: _BedrockCachePointBlock = {'type': 'default', 'ttl': cache_setting}
    else:
        block = {'type': 'default'}
    return {'cachePoint': block}

@staticmethod
def _limit_cache_points(
system_prompt: list[SystemContentBlockTypeDef],
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ xai = ["xai-sdk>=1.5.0"]
groq = ["groq>=0.25.0"]
openrouter = ["openai>=2.8.0"]
mistral = ["mistralai>=1.9.11"]
bedrock = ["boto3>=1.42.14"]
bedrock = ["boto3>=1.42.63"]
huggingface = ["huggingface-hub>=1.3.4,<2.0.0"]
sentence-transformers = ["sentence-transformers>=5.2.0; python_version < '3.14'"]
# 3.14 pin removable once voyageai 0.3.8 drops: https://pypi.org/project/voyageai/#history
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ dev = [
"pytest-pretty>=1.3.0",
"pytest-recording>=0.13.2",
"diff-cover>=9.2.0",
"boto3-stubs[bedrock-runtime]>=1.42.13",
"boto3-stubs[bedrock-runtime]>=1.42.63",
"strict-no-cover @ git+https://github.com/pydantic/strict-no-cover.git@7fc59da2c4dff919db2095a0f0e47101b657131d",
"pytest-xdist>=3.6.1",
# Needed for PyCharm users
Expand Down
92 changes: 84 additions & 8 deletions tests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2176,7 +2176,7 @@ async def test_bedrock_cache_messages_no_duplicate_with_explicit_cache_point(
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'Process this document:'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{
'document': {
'name': 'Document 1',
Expand Down Expand Up @@ -2213,7 +2213,7 @@ async def test_bedrock_cache_messages_no_duplicate_when_text_ends_with_cache_poi
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'Some text content'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
]
)

Expand All @@ -2238,7 +2238,7 @@ async def test_bedrock_cache_point_before_binary_content(allow_model_requests: N
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'Process the attached text file. Return the answer only.'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{
'document': {
'name': 'Document 1',
Expand Down Expand Up @@ -2274,7 +2274,7 @@ async def test_bedrock_cache_point_with_multiple_trailing_documents(
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'Process these documents.'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{
'document': {
'name': 'Document 1',
Expand Down Expand Up @@ -2333,7 +2333,7 @@ async def test_bedrock_cache_point_with_mixed_content_and_trailing_documents(
'source': {'bytes': b'\x89PNG\r\n\x1a\n'},
}
},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{
'document': {
'name': 'Document 2',
Expand Down Expand Up @@ -2427,7 +2427,7 @@ async def test_bedrock_cache_point_multiple_markers_with_documents_no_back_to_ba
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'Analyze these:'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{'document': {'name': 'Document 1', 'format': 'txt', 'source': {'bytes': b'Doc 1'}}},
{'document': {'name': 'Document 2', 'format': 'txt', 'source': {'bytes': b'Doc 2'}}},
]
Expand All @@ -2445,9 +2445,9 @@ async def test_bedrock_cache_point_multiple_markers(allow_model_requests: None,
assert bedrock_messages[0]['content'] == snapshot(
[
{'text': 'First chunk'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{'text': 'Second chunk'},
{'cachePoint': {'type': 'default'}},
{'cachePoint': {'type': 'default', 'ttl': '5m'}},
{'text': 'Question'},
]
)
Expand Down Expand Up @@ -2568,6 +2568,82 @@ async def test_bedrock_cache_messages(allow_model_requests: None, bedrock_provid
)


async def test_bedrock_cache_instructions_and_messages_with_explicit_ttl(
    allow_model_requests: None, bedrock_provider: BedrockProvider
):
    """Setting `bedrock_cache_instructions`/`bedrock_cache_messages` to a TTL string emits cache points carrying that `ttl`."""
    model = BedrockConverseModel('us.anthropic.claude-sonnet-4-5-20250929-v1:0', provider=bedrock_provider)
    messages: list[ModelMessage] = [
        ModelRequest(parts=[SystemPromptPart(content='System instructions to cache.'), UserPromptPart(content='Hi!')])
    ]
    system_prompt, bedrock_messages = await model._map_messages(  # pyright: ignore[reportPrivateUsage]
        messages,
        ModelRequestParameters(),
        BedrockModelSettings(bedrock_cache_instructions='1h', bedrock_cache_messages='1h'),
    )
    # The system prompt gets a trailing cache point with the explicit TTL.
    assert system_prompt == snapshot(
        [
            {'text': 'System instructions to cache.'},
            {'cachePoint': {'type': 'default', 'ttl': '1h'}},
        ]
    )
    # The last user message likewise ends with a TTL-carrying cache point.
    assert bedrock_messages == snapshot(
        [
            {
                'role': 'user',
                'content': [
                    {'text': 'Hi!'},
                    {'cachePoint': {'type': 'default', 'ttl': '1h'}},
                ],
            }
        ]
    )


async def test_bedrock_cache_tool_definitions_with_explicit_ttl(
    allow_model_requests: None, bedrock_provider: BedrockProvider
):
    """`bedrock_cache_tool_definitions='1h'` appends a cache point with the explicit TTL after the tool definitions."""
    model = BedrockConverseModel('anthropic.claude-sonnet-4-5-20250929-v1:0', provider=bedrock_provider)
    params = ModelRequestParameters(
        function_tools=[
            ToolDefinition(name='tool_one'),
            ToolDefinition(name='tool_two'),
        ]
    )
    params = model.customize_request_parameters(params)
    tool_config = model._map_tool_config(  # pyright: ignore[reportPrivateUsage]
        params,
        BedrockModelSettings(bedrock_cache_tool_definitions='1h'),
    )
    # Two tools plus one trailing cachePoint block.
    assert tool_config and len(tool_config['tools']) == 3
    assert tool_config['tools'][-1] == {'cachePoint': {'type': 'default', 'ttl': '1h'}}


async def test_bedrock_manual_cache_point_with_explicit_ttl(
    allow_model_requests: None, bedrock_provider: BedrockProvider
):
    """An explicit `CachePoint(ttl='1h')` marker in user content maps to a cache point carrying that TTL."""
    model = BedrockConverseModel('us.anthropic.claude-sonnet-4-5-20250929-v1:0', provider=bedrock_provider)
    messages: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content=['Context to cache', CachePoint(ttl='1h'), 'Question'])])
    ]
    _, bedrock_messages = await model._map_messages(  # pyright: ignore[reportPrivateUsage]
        messages,
        ModelRequestParameters(),
        BedrockModelSettings(),
    )
    # The cache point lands between the cached context and the following text.
    assert bedrock_messages == snapshot(
        [
            {
                'role': 'user',
                'content': [
                    {'text': 'Context to cache'},
                    {'cachePoint': {'type': 'default', 'ttl': '1h'}},
                    {'text': 'Question'},
                ],
            }
        ]
    )


async def test_bedrock_cache_messages_with_binary_content(
allow_model_requests: None, bedrock_provider: BedrockProvider
):
Expand Down
Loading