Skip to content

Commit 27e005f

Browse files
committed
fix: address review comments and fix mypy errors
1 parent 5fa8688 commit 27e005f

File tree

2 files changed

+54
-31
lines changed

2 files changed

+54
-31
lines changed

src/strands/models/litellm.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->
118118

119119
return super().format_request_message_content(content)
120120

121-
@classmethod
122121
@override
122+
@classmethod
123123
def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]:
124124
"""Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID.
125125
@@ -143,6 +143,36 @@ def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> d
143143

144144
return tool_call
145145

146+
@staticmethod
def _extract_thought_signature(data: Any) -> str | None:
    """Pull a Gemini thought signature out of tool call event data, if one exists.

    Candidate locations, checked in priority order:

    1. ``provider_specific_fields`` on the event data object itself.
    2. ``provider_specific_fields`` on the nested ``function`` object.
    3. The tool call ID, into which LiteLLM may encode the signature after
       the ``__thought__`` separator.

    Args:
        data: Tool call event data object.

    Returns:
        The extracted thought signature, or None if not present.
    """
    # Probe both provider_specific_fields carriers, top-level before function-level.
    for carrier in (data, getattr(data, "function", None)):
        fields = getattr(carrier, "provider_specific_fields", None) or {}
        if isinstance(fields, dict):
            value = fields.get("thought_signature")
            if value:
                return str(value)

    # Lowest priority: the signature encoded into the tool call ID, used only
    # when provider_specific_fields did not carry it.
    call_id = getattr(data, "id", None) or ""
    if isinstance(call_id, str) and _THOUGHT_SIGNATURE_SEPARATOR in call_id:
        return call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)[1]

    return None
175+
146176
def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
147177
"""Handle switching to a new content stream.
148178
@@ -269,39 +299,13 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
269299
)
270300

271301
# Extract thought signature from tool call content_start events.
272-
# LiteLLM embeds Gemini thought signatures in the tool call ID using the __thought__ separator.
273-
# We extract it into reasoningSignature so the streaming layer can preserve it through to
274-
# the internal ToolUse representation. The full encoded ID is kept in toolUseId so that
275-
# tool result messages (which reference toolUseId) continue to match the assistant message.
302+
# The full encoded ID is kept in toolUseId so that tool result messages continue to match.
276303
if event["chunk_type"] == "content_start" and event.get("data_type") == "tool":
277-
data = event.get("data")
278-
tool_call_id = getattr(data, "id", None) or ""
279-
if not isinstance(tool_call_id, str):
280-
tool_call_id = ""
281-
# Also check provider_specific_fields for the signature (non-streaming responses)
282-
psf = getattr(data, "provider_specific_fields", None) or {}
283-
if isinstance(psf, dict):
284-
psf_signature = psf.get("thought_signature")
285-
else:
286-
psf_signature = None
287-
# Extract from encoded ID as fallback
288-
id_signature = None
289-
if _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id:
290-
_, id_signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1)
291-
# Also check function-level provider_specific_fields
292-
func = getattr(data, "function", None)
293-
func_psf = getattr(func, "provider_specific_fields", None) or {}
294-
if isinstance(func_psf, dict):
295-
func_signature = func_psf.get("thought_signature")
296-
else:
297-
func_signature = None
298-
299-
signature = psf_signature or func_signature or id_signature
300-
304+
signature = self._extract_thought_signature(event.get("data"))
301305
chunk = super().format_chunk(event, **kwargs)
302306
if signature:
303-
tool_use = chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {})
304-
tool_use["reasoningSignature"] = signature
307+
tool_use_dict = cast(dict, chunk.get("contentBlockStart", {}).get("start", {}).get("toolUse", {}))
308+
tool_use_dict["reasoningSignature"] = signature
305309
return chunk
306310

307311
# For all other cases, use the parent implementation

tests/strands/models/test_litellm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,25 @@ def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specif
891891
assert tool_use["toolUseId"] == "call_abc123"
892892

893893

894+
def test_format_chunk_tool_start_extracts_thought_signature_from_function_provider_specific_fields():
    """Test that format_chunk extracts thought_signature from function.provider_specific_fields."""
    model = LiteLLMModel(model_id="test")

    # Signature lives only on function.provider_specific_fields; the ID carries
    # no __thought__ separator and the top-level fields are absent.
    fake_tool_call = unittest.mock.Mock()
    fake_tool_call.id = "call_abc123"  # No __thought__ in ID
    fake_tool_call.provider_specific_fields = None
    fake_tool_call.function = unittest.mock.Mock()
    fake_tool_call.function.name = "get_weather"
    fake_tool_call.function.provider_specific_fields = {"thought_signature": "ZnVuYy1zaWc="}

    chunk = model.format_chunk(
        {"chunk_type": "content_start", "data_type": "tool", "data": fake_tool_call}
    )

    tool_use = chunk["contentBlockStart"]["start"]["toolUse"]
    assert tool_use["toolUseId"] == "call_abc123"
    assert tool_use["reasoningSignature"] == "ZnVuYy1zaWc="
911+
912+
894913
def test_format_chunk_tool_start_no_thought_signature():
895914
"""Test that format_chunk works normally when no thought_signature is present."""
896915
model = LiteLLMModel(model_id="test")

0 commit comments

Comments
 (0)