Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/api/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,18 @@ def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str
for message in chat_request.messages:
if message.role not in ("system", "developer"):
continue
if not isinstance(message.content, str):
raise TypeError(f"System message content must be a string, got {type(message.content).__name__}")
system_prompts.append({"text": message.content})
if isinstance(message.content, str):
system_prompts.append({"text": message.content})
elif isinstance(message.content, list):
# Handle list-format content (e.g., from prompt caching)
for part in message.content:
if hasattr(part, "text"):
system_prompts.append({"text": part.text})
# Bedrock tagged unions require cachePoint as a SEPARATE block
if hasattr(part, "cache_control") and part.cache_control is not None:
system_prompts.append({"cachePoint": {"type": "default"}})
else:
raise TypeError(f"System message content must be a string or list, got {type(message.content).__name__}")

if not system_prompts:
return system_prompts
Expand Down Expand Up @@ -1190,11 +1199,10 @@ def _parse_content_parts(
content_parts = []
for part in message.content:
if isinstance(part, TextContent):
content_parts.append(
{
"text": part.text,
}
)
content_parts.append({"text": part.text})
# Bedrock tagged unions require cachePoint as a SEPARATE block
if hasattr(part, "cache_control") and part.cache_control is not None:
content_parts.append({"cachePoint": {"type": "default"}})
elif isinstance(part, ImageContent):
if not self.is_supported_modality(model_id, modality="IMAGE"):
raise HTTPException(
Expand Down
9 changes: 7 additions & 2 deletions src/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ class ToolCall(BaseModel):
function: ResponseFunction


class CacheControl(BaseModel):
type: str = "ephemeral"


class TextContent(BaseModel):
type: Literal["text"] = "text"
text: str
cache_control: CacheControl | None = None


class ImageUrl(BaseModel):
Expand All @@ -53,7 +58,7 @@ class ToolContent(BaseModel):
class SystemMessage(BaseModel):
name: str | None = None
role: Literal["system"] = "system"
content: str
content: str | list[TextContent]


class UserMessage(BaseModel):
Expand All @@ -78,7 +83,7 @@ class ToolMessage(BaseModel):
class DeveloperMessage(BaseModel):
name: str | None = None
role: Literal["developer"] = "developer"
content: str
content: str | list[TextContent]


class Function(BaseModel):
Expand Down