Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion strands-py/src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent, Usage
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand All @@ -31,6 +31,15 @@
class LlamaAPIModel(Model):
"""Llama API model provider implementation."""

# Lowercase substrings used to recognise context-window overflow errors.
# stream() lowercases the text of a BadRequestError and raises
# ContextWindowOverflowException when any of these substrings occurs in it.
# NOTE(review): the provenance of these strings is not shown here - confirm
# them against real provider error text.
OVERFLOW_MESSAGES = {
"this model's maximum context length is",
"exceed context limit",
"model's maximum context limit",
"is longer than the model's context length",
"prompt is too long",
"too many tokens",
}

class LlamaConfig(BaseModelConfig, total=False):
"""Configuration options for Llama API models.

Expand Down Expand Up @@ -368,6 +377,10 @@ async def stream(
response = self.client.chat.completions.create(**request)
except llama_api_client.RateLimitError as e:
raise ModelThrottledException(str(e)) from e
except llama_api_client.BadRequestError as e:
if any(message in str(e).lower() for message in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(e)) from e
raise

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
Expand Down
11 changes: 10 additions & 1 deletion strands-py/src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._defaults import resolve_config_metadata
Expand All @@ -37,6 +37,13 @@ class MistralModel(Model):
- System prompts
"""

OVERFLOW_MESSAGES = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue (verification): The substring matching is the crux of this fix — if these strings don't match what the provider actually returns, the except branch becomes silent dead code and the bug persists undetected. The unit tests assert against these same hardcoded strings, so they'd pass even if the strings are wrong.

Suggestion: Confirm these strings against real provider error text (an integ test that intentionally overflows the context, or a captured real error in the PR description). Worth noting where each string came from, since several read like reasonable guesses (e.g. "input too large", "too many tokens").

"prompt is too long",
"too large for model",
"maximum context length",
"input too large",
}

class MistralConfig(BaseModelConfig, total=False):
"""Configuration parameters for Mistral models.

Expand Down Expand Up @@ -501,6 +508,8 @@ async def stream(
yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage})

except Exception as e:
if any(message in str(e).lower() for message in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(e)) from e
if "rate" in str(e).lower() or "429" in str(e):
raise ModelThrottledException(str(e)) from e
raise
Expand Down
37 changes: 26 additions & 11 deletions strands-py/src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand All @@ -34,6 +35,11 @@ class OllamaModel(Model):
- Tool/function calling
"""

# Lowercase substrings used to recognise context-window overflow errors.
# stream() lowercases the text of an ollama.ResponseError and raises
# ContextWindowOverflowException when any of these substrings occurs in it.
# NOTE(review): the provenance of these strings is not shown here - confirm
# them against real Ollama error text.
OVERFLOW_MESSAGES = {
"exceeds the available context size",
"requested context size too large",
}

class OllamaConfig(BaseModelConfig, total=False):
"""Configuration parameters for Ollama models.

Expand Down Expand Up @@ -322,20 +328,29 @@ async def stream(
event = None

client = ollama.AsyncClient(self.host, **self.client_args)
response = await client.chat(**request)

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
# Ollama issues the request lazily, so overflow can surface at chat() or during iteration.
try:
response = await client.chat(**request)

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

async for event in response:
for tool_call in event.message.tool_calls or []:
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call})
tool_requested = True
async for event in response:
for tool_call in event.message.tool_calls or []:
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call})
tool_requested = True

yield self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": event.message.content})
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": event.message.content}
)
except ollama.ResponseError as error:
if any(message in str(error).lower() for message in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(error)) from error
raise

stop_reason = "tool_use" if tool_requested else (event.done_reason if event else None)

Expand Down
15 changes: 14 additions & 1 deletion strands-py/src/strands/models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand All @@ -29,6 +29,15 @@
class WriterModel(Model):
"""Writer API model provider implementation."""

OVERFLOW_MESSAGES = {
"this model's maximum context length is",
"exceed context limit",
"model's maximum context limit",
"is longer than the model's context length",
"prompt is too long",
"too many tokens",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: This OVERFLOW_MESSAGES set is byte-for-byte identical to the one added in llamaapi.py. Two copies of the same substring list will drift over time as one provider's error strings are tuned and the other is forgotten.

Suggestion: If LlamaAPI and Writer genuinely share the same OpenAI-compatible error vocabulary, consider hoisting the shared strings to a single module-level constant (mirroring how bedrock.py and openai.py keep these lists as module constants) rather than duplicating per class.

}

class WriterConfig(BaseModelConfig, total=False):
"""Configuration options for Writer API.

Expand Down Expand Up @@ -397,6 +406,10 @@ async def stream(
response = await self.client.chat.chat(**request)
except writerai.RateLimitError as e:
raise ModelThrottledException(str(e)) from e
except writerai.BadRequestError as e:
if any(message in str(e).lower() for message in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(e)) from e
raise

yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})
Expand Down
29 changes: 29 additions & 0 deletions strands-py/tests/strands/models/test_llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import logging
import unittest.mock

import llama_api_client
import pytest

import strands
from strands.models.llamaapi import LlamaAPIModel
from strands.types.exceptions import ContextWindowOverflowException


@pytest.fixture
Expand Down Expand Up @@ -547,3 +549,30 @@ def test_format_request_filters_location_source_document(model, caplog):
assert len(user_content) == 1
assert user_content[0]["type"] == "text"
assert "Location sources are not supported by LlamaAPI" in caplog.text


@pytest.mark.parametrize(
    "overflow_message",
    [
        "This model's maximum context length is 128000 tokens",
        "prompt is too long",
        "too many tokens in request",
    ],
)
@pytest.mark.asyncio
async def test_stream_context_overflow_error(overflow_message, model, messages, alist):
    """An overflow-style BadRequestError from the client becomes ContextWindowOverflowException."""
    bad_request = llama_api_client.BadRequestError(overflow_message, response=unittest.mock.Mock(), body=None)
    create_patch = unittest.mock.patch.object(model.client.chat.completions, "create", side_effect=bad_request)

    # The patch must stay active while the stream is consumed, so both managers share one block.
    with create_patch, pytest.raises(ContextWindowOverflowException) as raised:
        await alist(model.stream(messages))

    # The original provider message is preserved and the client error is chained as the cause.
    assert overflow_message in str(raised.value)
    assert raised.value.__cause__ == bad_request


@pytest.mark.asyncio
async def test_stream_non_overflow_bad_request_propagates(model, messages, alist):
    """A BadRequestError whose text is not an overflow message still surfaces as BadRequestError."""
    unrelated_error = llama_api_client.BadRequestError(
        "invalid 'model' parameter",
        response=unittest.mock.Mock(),
        body=None,
    )
    create_patch = unittest.mock.patch.object(model.client.chat.completions, "create", side_effect=unrelated_error)

    with create_patch, pytest.raises(llama_api_client.BadRequestError, match="invalid 'model' parameter"):
        await alist(model.stream(messages))
23 changes: 22 additions & 1 deletion strands-py/tests/strands/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import strands
from strands.models.mistral import MistralModel
from strands.types.exceptions import ModelThrottledException
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


@pytest.fixture
Expand Down Expand Up @@ -643,6 +643,27 @@ async def test_stream_other_error(mistral_client, model, alist):
await alist(model.stream(messages))


@pytest.mark.parametrize(
    "overflow_message",
    [
        "The prompt is too long: 219245, model maximum context length: 196608",
        "Prompt contains 152960 tokens and 0 draft tokens, too large for model with 131072 maximum context length",
        "Input too large: couldn't fit with truncation",
    ],
)
@pytest.mark.asyncio
async def test_stream_context_overflow_error(overflow_message, mistral_client, model, alist):
    """An exception carrying an overflow message from the Mistral client becomes ContextWindowOverflowException."""
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    provider_error = Exception(overflow_message)
    mistral_client.chat.stream_async.side_effect = provider_error

    with pytest.raises(ContextWindowOverflowException) as raised:
        await alist(model.stream(request_messages))

    # The provider text is kept in the new exception and the original error is chained as the cause.
    assert overflow_message in str(raised.value)
    assert raised.value.__cause__ == provider_error


@pytest.mark.asyncio
async def test_structured_output_success(mistral_client, model, test_output_model_cls, alist):
messages = [{"role": "user", "content": [{"text": "Extract data"}]}]
Expand Down
43 changes: 43 additions & 0 deletions strands-py/tests/strands/models/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import re
import unittest.mock

import ollama
import pydantic
import pytest

import strands
from strands.models.ollama import OllamaModel
from strands.types.content import Messages
from strands.types.exceptions import ContextWindowOverflowException


@pytest.fixture
Expand Down Expand Up @@ -741,3 +743,44 @@ def test_format_request_uses_tool_name_not_tool_use_id(model, model_id):
# The function name in the request must come from "name", not "toolUseId"
assert tool_call["function"]["name"] == "calculator"
assert tool_call["function"]["name"] != "unique-id-abc-123"


@pytest.mark.asyncio
async def test_stream_context_overflow_error_on_request(ollama_client, model, alist):
    """An overflow ResponseError raised by chat() itself becomes ContextWindowOverflowException."""
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    response_error = ollama.ResponseError(
        "request exceeds the available context size (2048 tokens), try increasing it"
    )
    ollama_client.chat = unittest.mock.AsyncMock(side_effect=response_error)

    with pytest.raises(ContextWindowOverflowException) as raised:
        await alist(model.stream(request_messages))

    assert "exceeds the available context size" in str(raised.value)
    assert raised.value.__cause__ == response_error


@pytest.mark.asyncio
async def test_stream_context_overflow_error_during_iteration(ollama_client, model, alist):
    """An overflow ResponseError raised while iterating the response becomes ContextWindowOverflowException."""
    response_error = ollama.ResponseError("requested context size too large for model")

    async def failing_stream():
        # Fails on the first iteration step; the unreachable yield only makes this an async generator.
        raise response_error
        yield  # pragma: no cover - generator never yields

    # chat() succeeds and hands back the stream; the failure only appears once it is consumed.
    ollama_client.chat = unittest.mock.AsyncMock(return_value=failing_stream())
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    with pytest.raises(ContextWindowOverflowException) as raised:
        await alist(model.stream(request_messages))

    assert "requested context size too large" in str(raised.value)
    assert raised.value.__cause__ == response_error


@pytest.mark.asyncio
async def test_stream_non_overflow_response_error_propagates(ollama_client, model, alist):
    """A ResponseError whose text is not an overflow message still surfaces as ollama.ResponseError."""
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    unrelated_error = ollama.ResponseError("model 'm1' not found")
    ollama_client.chat = unittest.mock.AsyncMock(side_effect=unrelated_error)

    with pytest.raises(ollama.ResponseError, match="not found"):
        await alist(model.stream(request_messages))
33 changes: 33 additions & 0 deletions strands-py/tests/strands/models/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Any

import pytest
import writerai

import strands
from strands.models.writer import WriterModel
from strands.types.exceptions import ContextWindowOverflowException


@pytest.fixture
Expand Down Expand Up @@ -566,3 +568,34 @@ def test_format_request_filters_location_source_document(model, caplog):
assert len(user_content) == 1
assert user_content[0]["type"] == "text"
assert "Location sources are not supported by Writer" in caplog.text


@pytest.mark.parametrize(
    "overflow_message",
    [
        "This model's maximum context length is 32768 tokens",
        "prompt is too long",
        "too many tokens in request",
    ],
)
@pytest.mark.asyncio
async def test_stream_context_overflow_error(overflow_message, writer_client, model, alist):
    """An overflow-style BadRequestError from the Writer client becomes ContextWindowOverflowException."""
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    bad_request = writerai.BadRequestError(overflow_message, response=unittest.mock.Mock(), body=None)
    writer_client.chat.chat = unittest.mock.AsyncMock(side_effect=bad_request)

    with pytest.raises(ContextWindowOverflowException) as raised:
        await alist(model.stream(request_messages))

    # The provider text is kept in the new exception and the client error is chained as the cause.
    assert overflow_message in str(raised.value)
    assert raised.value.__cause__ == bad_request


@pytest.mark.asyncio
async def test_stream_non_overflow_bad_request_propagates(writer_client, model, alist):
    """A BadRequestError whose text is not an overflow message still surfaces as BadRequestError."""
    request_messages = [{"role": "user", "content": [{"text": "test"}]}]

    unrelated_error = writerai.BadRequestError(
        "invalid 'model' parameter",
        response=unittest.mock.Mock(),
        body=None,
    )
    writer_client.chat.chat = unittest.mock.AsyncMock(side_effect=unrelated_error)

    with pytest.raises(writerai.BadRequestError, match="invalid 'model' parameter"):
        await alist(model.stream(request_messages))
Loading