forked from strands-agents/sdk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanthropic.py
More file actions
479 lines (395 loc) · 18 KB
/
anthropic.py
File metadata and controls
479 lines (395 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
"""Anthropic Claude model provider.
- Docs: https://docs.anthropic.com/claude/reference/getting-started-with-the-api
"""
import base64
import json
import logging
import mimetypes
from collections.abc import AsyncGenerator
from typing import Any, TypedDict, TypeVar, cast
import anthropic
from pydantic import BaseModel
from typing_extensions import Required, Unpack, override
from ..event_loop.streaming import process_stream
from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
from ._validation import _has_location_source, validate_config_keys
from .model import Model
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable for structured-output result types; bound to pydantic BaseModel subclasses.
T = TypeVar("T", bound=BaseModel)
class AnthropicModel(Model):
"""Anthropic model provider implementation."""
EVENT_TYPES = {
"message_start",
"content_block_start",
"content_block_delta",
"content_block_stop",
"message_stop",
}
OVERFLOW_MESSAGES = {
"prompt is too long:",
"input is too long",
"input length exceeds context window",
"input and output tokens exceed your context limit",
}
class AnthropicConfig(TypedDict, total=False):
"""Configuration options for Anthropic models.
Attributes:
max_tokens: Maximum number of tokens to generate.
model_id: Calude model ID (e.g., "claude-3-7-sonnet-latest").
For a complete list of supported models, see
https://docs.anthropic.com/en/docs/about-claude/models/all-models.
params: Additional model parameters (e.g., temperature).
For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
"""
max_tokens: Required[int]
model_id: Required[str]
params: dict[str, Any] | None
def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]):
"""Initialize provider instance.
Args:
client_args: Arguments for the underlying Anthropic client (e.g., api_key).
For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks.
**model_config: Configuration options for the Anthropic model.
"""
validate_config_keys(model_config, self.AnthropicConfig)
self.config = AnthropicModel.AnthropicConfig(**model_config)
logger.debug("config=<%s> | initializing", self.config)
client_args = client_args or {}
self.client = anthropic.AsyncAnthropic(**client_args)
@override
def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # type: ignore[override]
"""Update the Anthropic model configuration with the provided arguments.
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.AnthropicConfig)
self.config.update(model_config)
@override
def get_config(self) -> AnthropicConfig:
"""Get the Anthropic model configuration.
Returns:
The Anthropic model configuration.
"""
return self.config
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
"""Format an Anthropic content block.
Args:
content: Message content.
Returns:
Anthropic formatted content block.
Raises:
TypeError: If the content block type cannot be converted to an Anthropic-compatible format.
"""
if "document" in content:
mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
return {
"source": {
"data": (
content["document"]["source"]["bytes"].decode("utf-8")
if mime_type == "text/plain"
else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8")
),
"media_type": mime_type,
"type": "text" if mime_type == "text/plain" else "base64",
},
"title": content["document"]["name"],
"type": "document",
}
if "image" in content:
return {
"source": {
"data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
"media_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"),
"type": "base64",
},
"type": "image",
}
if "reasoningContent" in content:
return {
"signature": content["reasoningContent"]["reasoningText"]["signature"],
"thinking": content["reasoningContent"]["reasoningText"]["text"],
"type": "thinking",
}
if "text" in content:
return {"text": content["text"], "type": "text"}
if "toolUse" in content:
return {
"id": content["toolUse"]["toolUseId"],
"input": content["toolUse"]["input"],
"name": content["toolUse"]["name"],
"type": "tool_use",
}
if "toolResult" in content:
return {
"content": [
self._format_request_message_content(
{"text": json.dumps(tool_result_content["json"])}
if "json" in tool_result_content
else cast(ContentBlock, tool_result_content)
)
for tool_result_content in content["toolResult"]["content"]
],
"is_error": content["toolResult"]["status"] == "error",
"tool_use_id": content["toolResult"]["toolUseId"],
"type": "tool_result",
}
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
"""Format an Anthropic messages array.
Args:
messages: List of message objects to be processed by the model.
Returns:
An Anthropic messages array.
"""
formatted_messages = []
for message in messages:
formatted_contents: list[dict[str, Any]] = []
for content in message["content"]:
if "cachePoint" in content:
formatted_contents[-1]["cache_control"] = {"type": "ephemeral"}
continue
# Check for location sources in image, document, or video content
if _has_location_source(content):
logger.warning("Location sources are not supported by Anthropic | skipping content block")
continue
formatted_contents.append(self._format_request_message_content(content))
if formatted_contents:
formatted_messages.append({"content": formatted_contents, "role": message["role"]})
return formatted_messages
def format_request(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Format an Anthropic streaming request.
Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
Returns:
An Anthropic streaming request.
Raises:
TypeError: If a message contains a content block type that cannot be converted to an Anthropic-compatible
format.
"""
return {
"max_tokens": self.config["max_tokens"],
"messages": self._format_request_messages(messages),
"model": self.config["model_id"],
"tools": [
{
"name": tool_spec["name"],
"description": tool_spec["description"],
"input_schema": tool_spec["inputSchema"]["json"],
}
for tool_spec in tool_specs or []
],
**(self._format_tool_choice(tool_choice)),
**({"system": system_prompt} if system_prompt else {}),
**(self.config.get("params") or {}),
}
@staticmethod
def _format_tool_choice(tool_choice: ToolChoice | None) -> dict:
if tool_choice is None:
return {}
if "any" in tool_choice:
return {"tool_choice": {"type": "any"}}
elif "auto" in tool_choice:
return {"tool_choice": {"type": "auto"}}
elif "tool" in tool_choice:
return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}}
else:
return {}
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Anthropic response events into standardized message chunks.
Args:
event: A response event from the Anthropic model.
Returns:
The formatted chunk.
Raises:
RuntimeError: If chunk_type is not recognized.
This error should never be encountered as we control chunk_type in the stream method.
"""
match event["type"]:
case "message_start":
return {"messageStart": {"role": "assistant"}}
case "content_block_start":
content = event["content_block"]
if content["type"] == "tool_use":
return {
"contentBlockStart": {
"contentBlockIndex": event["index"],
"start": {
"toolUse": {
"name": content["name"],
"toolUseId": content["id"],
}
},
}
}
return {"contentBlockStart": {"contentBlockIndex": event["index"], "start": {}}}
case "content_block_delta":
delta = event["delta"]
match delta["type"]:
case "signature_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"reasoningContent": {
"signature": delta["signature"],
},
},
},
}
case "thinking_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"reasoningContent": {
"text": delta["thinking"],
},
},
},
}
case "input_json_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"toolUse": {
"input": delta["partial_json"],
},
},
},
}
case "text_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"text": delta["text"],
},
},
}
case _:
raise RuntimeError(
f"event_type=<content_block_delta>, delta_type=<{delta['type']}> | unknown type"
)
case "content_block_stop":
return {"contentBlockStop": {"contentBlockIndex": event["index"]}}
case "message_stop":
message = event["message"]
return {"messageStop": {"stopReason": message["stop_reason"]}}
case "metadata":
usage = event["usage"]
return {
"metadata": {
"usage": {
"inputTokens": usage["input_tokens"],
"outputTokens": usage["output_tokens"],
"totalTokens": usage["input_tokens"] + usage["output_tokens"],
},
"metrics": {
"latencyMs": 0, # TODO
},
}
}
case _:
raise RuntimeError(f"event_type=<{event['type']} | unknown type")
@override
async def stream(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
*,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Anthropic model.
Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments for future extensibility.
Yields:
Formatted message chunks from the model.
Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by Anthropic.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
logger.debug("request=<%s>", request)
logger.debug("invoking model")
try:
async with self.client.messages.stream(**request) as stream:
logger.debug("got response from model")
async for event in stream:
if event.type in AnthropicModel.EVENT_TYPES:
yield self.format_chunk(event.model_dump())
# Prefer get_final_message() which safely handles early stream
# termination (e.g. network timeout before message_stop).
# Fall back to the last event for mock/test streams that lack
# the get_final_message API.
try:
final_message = await stream.get_final_message()
usage = final_message.usage
except (AttributeError, Exception):
usage = event.message.usage # type: ignore
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
except anthropic.RateLimitError as error:
raise ModelThrottledException(str(error)) from error
except anthropic.BadRequestError as error:
if any(overflow_message in str(error).lower() for overflow_message in AnthropicModel.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(error)) from error
raise error
logger.debug("finished streaming response from model")
@override
async def structured_output(
self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any
) -> AsyncGenerator[dict[str, T | Any], None]:
"""Get structured output from the model.
Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.
Yields:
Model events with the last being the structured output.
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=cast(ToolChoice, {"any": {}}),
**kwargs,
)
async for event in process_stream(response):
yield event
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
for block in content:
# if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
# if the tool use name never matches, raise an error.
if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
output_response = block["toolUse"]["input"]
else:
continue
if output_response is None:
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
yield {"output": output_model(**output_response)}