Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 74 additions & 3 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from ..types.event_loop import Usage
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import MetadataEvent, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ..types.tools import ToolChoice, ToolSpec, ToolUse
from ._validation import validate_config_keys
from .openai import OpenAIModel

logger = logging.getLogger(__name__)

# Separator used by LiteLLM to embed thought signatures inside tool call IDs.
# See: https://ai.google.dev/gemini-api/docs/thought-signatures
_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"

T = TypeVar("T", bound=BaseModel)


Expand Down Expand Up @@ -114,6 +118,61 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->

return super().format_request_message_content(content)

@override
@classmethod
def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
    """Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID.

    Gemini thinking models attach a thought_signature to each function call. LiteLLM's
    OpenAI-compatible interface carries this signature inside the tool call ID, joined by the
    ``__thought__`` separator. When the tool use carries a ``reasoningSignature`` and the ID is
    not already encoded, append the signature so LiteLLM can rebuild the Gemini-native format
    on the next request.

    Args:
        tool_use: Tool use requested by the model.
        **kwargs: Additional keyword arguments for future extensibility.

    Returns:
        LiteLLM compatible tool call dict with thought signature encoded in the ID when present.
    """
    formatted = super().format_request_message_tool_call(tool_use, **kwargs)

    signature = tool_use.get("reasoningSignature")
    already_encoded = _THOUGHT_SIGNATURE_SEPARATOR in formatted["id"]
    if signature and not already_encoded:
        # Join rather than double-encode; an ID that round-tripped keeps its original form.
        formatted["id"] = _THOUGHT_SIGNATURE_SEPARATOR.join((formatted["id"], signature))

    return formatted

@staticmethod
def _extract_thought_signature(data: Any) -> str | None:
    """Extract a Gemini thought signature from a streamed tool call delta.

    LiteLLM surfaces the signature in up to three places depending on the provider and
    LiteLLM version, so each location is checked in priority order:

    1. ``data.provider_specific_fields["thought_signature"]`` (top-level).
    2. ``data.function.provider_specific_fields["thought_signature"]``.
    3. Encoded into ``data.id`` after the ``__thought__`` separator — lowest priority,
       used only when provider_specific_fields do not carry it.

    Args:
        data: Tool call event data object.

    Returns:
        The extracted thought signature, or None if not present.
    """
    # Top-level provider_specific_fields (highest priority). The `or {}` guard also
    # normalizes a None attribute; non-dict values (e.g. Mock auto-attributes) are skipped.
    psf = getattr(data, "provider_specific_fields", None) or {}
    if isinstance(psf, dict) and psf.get("thought_signature"):
        return str(psf["thought_signature"])

    # Function-level provider_specific_fields.
    func = getattr(data, "function", None)
    func_psf = getattr(func, "provider_specific_fields", None) or {}
    if isinstance(func_psf, dict) and func_psf.get("thought_signature"):
        return str(func_psf["thought_signature"])

    # Extract from the encoded ID (lowest priority — used only when
    # provider_specific_fields don't carry it). Split once so signatures containing the
    # separator themselves are preserved intact.
    tool_call_id = getattr(data, "id", None) or ""
    if isinstance(tool_call_id, str) and _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id:
        _, signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)
        return signature

    return None

def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
"""Handle switching to a new content stream.

Expand Down Expand Up @@ -200,8 +259,9 @@ def format_request_messages(
def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format a LiteLLM response event into a standardized message chunk.

This method overrides OpenAI's format_chunk to handle the metadata case
with prompt caching support. All other chunk types use the parent implementation.
Extends OpenAI's format_chunk to:
1. Handle metadata with prompt caching support.
2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models.

Args:
event: A response event from the LiteLLM model.
Expand Down Expand Up @@ -237,6 +297,17 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
usage=usage_data,
)
)

# Extract thought signature from tool call content_start events.
# The full encoded ID is kept in toolUseId so that tool result messages continue to match.
if event["chunk_type"] == "content_start" and event.get("data_type") == "tool":
signature = self._extract_thought_signature(event.get("data"))
chunk = super().format_chunk(event, **kwargs)
if signature:
tool_use_dict = cast(dict, chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {}))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Does this line need to cast and do a chained get? Since we know the shape of chunk, can we just do this?

Suggested change
tool_use_dict = cast(dict, chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {}))
tool_use_dict: dict = chunk["contentBlockStart"]["start"]["toolUse"]

tool_use_dict["reasoningSignature"] = signature
return chunk

# For all other cases, use the parent implementation
return super().format_chunk(event)

Expand Down
157 changes: 157 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,160 @@ def test_format_request_messages_with_tool_calls_no_content():
},
]
assert tru_result == exp_result


# --- Thought Signature Tests ---


def test_format_chunk_tool_start_extracts_thought_signature_from_id():
    """Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID."""
    model = LiteLLMModel(model_id="test")

    tool_call = unittest.mock.Mock()
    tool_call.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
    tool_call.function = unittest.mock.Mock()
    tool_call.function.name = "get_weather"
    tool_call.provider_specific_fields = None

    chunk = model.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})

    tool_use = chunk["contentBlockStart"]["start"]["toolUse"]
    assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw=="
    # toolUseId keeps the full encoded string so tool result IDs match
    assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="


def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields():
    """Test that format_chunk extracts thought_signature from provider_specific_fields."""
    model = LiteLLMModel(model_id="test")

    tool_call = unittest.mock.Mock()
    tool_call.id = "call_abc123"  # No __thought__ in ID
    tool_call.function = unittest.mock.Mock()
    tool_call.function.name = "get_weather"
    tool_call.function.provider_specific_fields = None
    tool_call.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="}

    event = {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
    tool_use = model.format_chunk(event)["contentBlockStart"]["start"]["toolUse"]

    assert tool_use["reasoningSignature"] == "cHNmLXNpZw=="
    assert tool_use["toolUseId"] == "call_abc123"


def test_format_chunk_tool_start_extracts_thought_signature_from_function_provider_specific_fields():
    """Test that format_chunk extracts thought_signature from function.provider_specific_fields."""
    model = LiteLLMModel(model_id="test")

    tool_call = unittest.mock.Mock()
    tool_call.id = "call_abc123"  # No __thought__ in ID
    tool_call.function = unittest.mock.Mock()
    tool_call.function.name = "get_weather"
    tool_call.provider_specific_fields = None
    tool_call.function.provider_specific_fields = {"thought_signature": "ZnVuYy1zaWc="}

    event = {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
    tool_use = model.format_chunk(event)["contentBlockStart"]["start"]["toolUse"]

    assert tool_use["reasoningSignature"] == "ZnVuYy1zaWc="
    assert tool_use["toolUseId"] == "call_abc123"


def test_format_chunk_tool_start_no_thought_signature():
    """Test that format_chunk works normally when no thought_signature is present."""
    model = LiteLLMModel(model_id="test")

    tool_call = unittest.mock.Mock()
    tool_call.id = "call_plain123"
    tool_call.function = unittest.mock.Mock()
    tool_call.function.name = "get_weather"
    tool_call.provider_specific_fields = None
    tool_call.function.provider_specific_fields = None

    event = {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
    tool_use = model.format_chunk(event)["contentBlockStart"]["start"]["toolUse"]

    assert tool_use["toolUseId"] == "call_plain123"
    assert "reasoningSignature" not in tool_use


def test_format_request_message_tool_call_encodes_thought_signature():
    """Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID."""
    tool_call = LiteLLMModel.format_request_message_tool_call(
        {
            "toolUseId": "call_abc123",
            "name": "get_weather",
            "input": {"city": "Seattle"},
            "reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
        }
    )

    assert tool_call["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
    assert tool_call["function"]["name"] == "get_weather"
    assert tool_call["function"]["arguments"] == '{"city": "Seattle"}'


def test_format_request_message_tool_call_skips_encoding_when_already_present():
    """Test that format_request_message_tool_call does not double-encode the signature."""
    encoded_id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw=="
    tool_call = LiteLLMModel.format_request_message_tool_call(
        {
            "toolUseId": encoded_id,
            "name": "get_weather",
            "input": {"city": "Seattle"},
            "reasoningSignature": "dGhpcy1pcy1hLXNpZw==",
        }
    )

    # Should NOT double-encode
    assert tool_call["id"] == encoded_id


def test_format_request_message_tool_call_no_reasoning_signature():
    """Test that format_request_message_tool_call works normally without reasoningSignature."""
    tool_call = LiteLLMModel.format_request_message_tool_call(
        {
            "toolUseId": "call_plain123",
            "name": "get_weather",
            "input": {"city": "Seattle"},
        }
    )

    assert tool_call["id"] == "call_plain123"
    assert "__thought__" not in tool_call["id"]


def test_thought_signature_round_trip():
    """Test that thought signature is preserved through a full response -> internal -> request cycle."""
    model = LiteLLMModel(model_id="test")
    signature = "R2VtaW5pVGhvdWdodFNpZw=="
    encoded_id = f"call_xyz789__thought__{signature}"

    # 1. Response path: format_chunk extracts the signature
    streamed_call = unittest.mock.Mock()
    streamed_call.id = encoded_id
    streamed_call.function = unittest.mock.Mock()
    streamed_call.function.name = "current_time"
    streamed_call.provider_specific_fields = None
    streamed_call.function.provider_specific_fields = None

    chunk = model.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": streamed_call})
    tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"]
    assert tool_use_data["reasoningSignature"] == signature

    # 2. Simulate internal storage (streaming layer stores reasoningSignature)
    stored_tool_use = {
        "toolUseId": tool_use_data["toolUseId"],
        "name": tool_use_data["name"],
        "input": {"timezone": "UTC"},
        "reasoningSignature": tool_use_data["reasoningSignature"],
    }

    # 3. Request path: format_request_message_tool_call re-encodes the signature
    outgoing_call = LiteLLMModel.format_request_message_tool_call(stored_tool_use)
    assert "__thought__" in outgoing_call["id"]
    assert signature in outgoing_call["id"]
Loading