|
4 | 4 | from datetime import date, datetime, timezone |
5 | 5 | from types import SimpleNamespace |
6 | 6 | from typing import Any |
| 7 | +from unittest.mock import AsyncMock, MagicMock |
7 | 8 |
|
8 | 9 | import pytest |
9 | 10 | from typing_extensions import TypedDict |
@@ -124,6 +125,88 @@ def _bedrock_model_with_client_error(error: ClientError) -> BedrockConverseModel |
124 | 125 | provider=_StubBedrockProvider(_StubBedrockClient(error)), |
125 | 126 | ) |
126 | 127 |
|
| 128 | + |
| 129 | +def test_register_extra_headers_injects_headers(): |
| 130 | + client = MagicMock() |
| 131 | + client.meta.events = MagicMock() |
| 132 | + |
| 133 | + headers = {'X-Test': '123'} |
| 134 | + |
| 135 | + handler, events = BedrockConverseModel._register_extra_headers(client, headers, stream=False) |
| 136 | + |
| 137 | + request = MagicMock() |
| 138 | + request.headers = {} |
| 139 | + |
| 140 | + handler(request) |
| 141 | + |
| 142 | + assert request.headers['X-Test'] == '123' |
| 143 | + client.meta.events.register.assert_called_once() |
| 144 | + assert events == ['before-send.bedrock-runtime.Converse'] |
| 145 | + |
| 146 | + |
| 147 | +async def test_call_with_extra_headers_registers_and_unregisters(): |
| 148 | + model = BedrockConverseModel.__new__(BedrockConverseModel) |
| 149 | + |
| 150 | + model.client = MagicMock() |
| 151 | + |
| 152 | + import anyio |
| 153 | + |
| 154 | + model._extra_headers_lock = anyio.Lock() |
| 155 | + |
| 156 | + model._register_extra_headers = MagicMock(return_value=('handler', ['event'])) |
| 157 | + model._unregister_extra_headers = MagicMock() |
| 158 | + |
| 159 | + model._call_bedrock_unlocked = AsyncMock(return_value='ok') |
| 160 | + |
| 161 | + result = await model._call_with_extra_headers( |
| 162 | + params={'foo': 'bar'}, |
| 163 | + extra_headers={'X-Test': '123'}, |
| 164 | + stream=False, |
| 165 | + ) |
| 166 | + |
| 167 | + assert result == 'ok' |
| 168 | + model._register_extra_headers.assert_called_once() |
| 169 | + model._unregister_extra_headers.assert_called_once() |
| 170 | + |
| 171 | + |
| 172 | +async def test_call_with_extra_headers_unregisters_on_exception(): |
| 173 | + model = BedrockConverseModel.__new__(BedrockConverseModel) |
| 174 | + |
| 175 | + model.client = MagicMock() |
| 176 | + |
| 177 | + import anyio |
| 178 | + |
| 179 | + model._extra_headers_lock = anyio.Lock() |
| 180 | + |
| 181 | + model._register_extra_headers = MagicMock(return_value=('handler', ['event'])) |
| 182 | + model._unregister_extra_headers = MagicMock() |
| 183 | + |
| 184 | + async def raise_error(*args, **kwargs): |
| 185 | + raise RuntimeError('fail') |
| 186 | + |
| 187 | + model._call_bedrock_unlocked = raise_error |
| 188 | + |
| 189 | + with pytest.raises(RuntimeError): |
| 190 | + await model._call_with_extra_headers( |
| 191 | + params={'foo': 'bar'}, |
| 192 | + extra_headers={'X-Test': '123'}, |
| 193 | + stream=False, |
| 194 | + ) |
| 195 | + |
| 196 | + model._unregister_extra_headers.assert_called_once() |
| 197 | + |
| 198 | + |
| 199 | +def test_register_extra_headers_stream_event_name(): |
| 200 | + client = MagicMock() |
| 201 | + client.meta.events = MagicMock() |
| 202 | + |
| 203 | + headers = {'X-Test': '123'} |
| 204 | + |
| 205 | + handler, events = BedrockConverseModel._register_extra_headers(client, headers, stream=True) |
| 206 | + |
| 207 | + assert events == ['before-send.bedrock-runtime.ConverseStream'] |
| 208 | + |
| 209 | + |
127 | 210 | async def test_bedrock_model(allow_model_requests: None, bedrock_provider: BedrockProvider): |
128 | 211 | model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) |
129 | 212 | assert model.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com' |
@@ -168,6 +251,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro |
168 | 251 | ] |
169 | 252 | ) |
170 | 253 |
|
| 254 | + |
171 | 255 | @pytest.mark.vcr() |
172 | 256 | async def test_bedrock_model_usage_limit_exceeded( |
173 | 257 | allow_model_requests: None, |
|
0 commit comments