Skip to content

Commit 138dd7a

Browse files
committed
fix(bedrock): restore locking for extra_headers isolation
1 parent 1c8fb67 commit 138dd7a

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,8 @@ async def _call_bedrock(
673673
params: ConverseRequestTypeDef,
674674
stream: bool,
675675
) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
676-
return await self._call_bedrock_unlocked(params=params, stream=stream)
676+
async with self._extra_headers_lock:
677+
return await self._call_bedrock_unlocked(params=params, stream=stream)
677678

678679
async def _call_with_extra_headers(
679680
self,

tests/models/test_bedrock.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import date, datetime, timezone
55
from types import SimpleNamespace
66
from typing import Any
7+
from unittest.mock import AsyncMock, MagicMock
78

89
import pytest
910
from typing_extensions import TypedDict
@@ -124,6 +125,88 @@ def _bedrock_model_with_client_error(error: ClientError) -> BedrockConverseModel
124125
provider=_StubBedrockProvider(_StubBedrockClient(error)),
125126
)
126127

128+
129+
def test_register_extra_headers_injects_headers():
130+
client = MagicMock()
131+
client.meta.events = MagicMock()
132+
133+
headers = {'X-Test': '123'}
134+
135+
handler, events = BedrockConverseModel._register_extra_headers(client, headers, stream=False)
136+
137+
request = MagicMock()
138+
request.headers = {}
139+
140+
handler(request)
141+
142+
assert request.headers['X-Test'] == '123'
143+
client.meta.events.register.assert_called_once()
144+
assert events == ['before-send.bedrock-runtime.Converse']
145+
146+
147+
async def test_call_with_extra_headers_registers_and_unregisters():
148+
model = BedrockConverseModel.__new__(BedrockConverseModel)
149+
150+
model.client = MagicMock()
151+
152+
import anyio
153+
154+
model._extra_headers_lock = anyio.Lock()
155+
156+
model._register_extra_headers = MagicMock(return_value=('handler', ['event']))
157+
model._unregister_extra_headers = MagicMock()
158+
159+
model._call_bedrock_unlocked = AsyncMock(return_value='ok')
160+
161+
result = await model._call_with_extra_headers(
162+
params={'foo': 'bar'},
163+
extra_headers={'X-Test': '123'},
164+
stream=False,
165+
)
166+
167+
assert result == 'ok'
168+
model._register_extra_headers.assert_called_once()
169+
model._unregister_extra_headers.assert_called_once()
170+
171+
172+
async def test_call_with_extra_headers_unregisters_on_exception():
173+
model = BedrockConverseModel.__new__(BedrockConverseModel)
174+
175+
model.client = MagicMock()
176+
177+
import anyio
178+
179+
model._extra_headers_lock = anyio.Lock()
180+
181+
model._register_extra_headers = MagicMock(return_value=('handler', ['event']))
182+
model._unregister_extra_headers = MagicMock()
183+
184+
async def raise_error(*args, **kwargs):
185+
raise RuntimeError('fail')
186+
187+
model._call_bedrock_unlocked = raise_error
188+
189+
with pytest.raises(RuntimeError):
190+
await model._call_with_extra_headers(
191+
params={'foo': 'bar'},
192+
extra_headers={'X-Test': '123'},
193+
stream=False,
194+
)
195+
196+
model._unregister_extra_headers.assert_called_once()
197+
198+
199+
def test_register_extra_headers_stream_event_name():
200+
client = MagicMock()
201+
client.meta.events = MagicMock()
202+
203+
headers = {'X-Test': '123'}
204+
205+
handler, events = BedrockConverseModel._register_extra_headers(client, headers, stream=True)
206+
207+
assert events == ['before-send.bedrock-runtime.ConverseStream']
208+
209+
127210
async def test_bedrock_model(allow_model_requests: None, bedrock_provider: BedrockProvider):
128211
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
129212
assert model.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com'
@@ -168,6 +251,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro
168251
]
169252
)
170253

254+
171255
@pytest.mark.vcr()
172256
async def test_bedrock_model_usage_limit_exceeded(
173257
allow_model_requests: None,

0 commit comments

Comments
 (0)