Skip to content

Commit 1b61a6f

Browse files
authored
Prevent sending thinking blocks with no signature with ChatBedrockConverse (#408)
Fixes #401. Claude 3.7 still works because signatures are always returned with those thinking blocks. DeepSeek thinking blocks do not contain signatures, and the ValidationError indicates that a signature must be included with thinking blocks. According to the Converse API docs: > "Within the reasoningText field, the text field describes the reasoning. The signature field is a hash of all the messages in the conversation and is a safeguard against tampering with the reasoning used by the model. You must include the signature and all previous messages in subsequent Converse requests. If any of the messages are changed, the response throws an error." https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html @3coins I posted a separate PR #407 to make reviewing easier. Happy to combine if preferred.
1 parent 6c67d96 commit 1b61a6f

File tree

2 files changed

+512
-124
lines changed

2 files changed

+512
-124
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,12 @@ class Joke(BaseModel):
460460
def set_disable_streaming(cls, values: Dict) -> Any:
461461
model_id = values.get("model_id", values.get("model"))
462462
model_parts = model_id.split(".")
463-
464-
# Extract provider from the model_id (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
465-
provider = values.get("provider") or (model_parts[-2] if len(model_parts) > 1 else model_parts[0])
463+
464+
# Extract provider from the model_id
465+
# (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
466+
provider = values.get("provider") or (
467+
model_parts[-2] if len(model_parts) > 1 else model_parts[0]
468+
)
466469
values["provider"] = provider
467470

468471
model_id_lower = model_id.lower()
@@ -471,26 +474,46 @@ def set_disable_streaming(cls, values: Dict) -> Any:
471474
# Here we check based on the updated AWS documentation.
472475
if (
473476
# AI21 Jamba 1.5 models
474-
(provider == "ai21" and "jamba-1-5" in model_id_lower) or
477+
(provider == "ai21" and "jamba-1-5" in model_id_lower)
478+
or
475479
# Some Amazon Nova models
476-
(provider == "amazon" and any(x in model_id_lower for x in ["nova-lite", "nova-micro", "nova-pro"])) or
480+
(
481+
provider == "amazon"
482+
and any(
483+
x in model_id_lower for x in ["nova-lite", "nova-micro", "nova-pro"]
484+
)
485+
)
486+
or
477487
# Anthropic Claude 3 and newer models
478-
(provider == "anthropic" and "claude-3" in model_id_lower) or
488+
(provider == "anthropic" and "claude-3" in model_id_lower)
489+
or
479490
# Cohere Command R models
480491
(provider == "cohere" and "command-r" in model_id_lower)
481492
):
482493
streaming_support = True
483494
elif (
484495
# AI21 Jamba-Instruct model
485-
(provider == "ai21" and "jamba-instruct" in model_id_lower) or
496+
(provider == "ai21" and "jamba-instruct" in model_id_lower)
497+
or
486498
# Amazon Titan Text models
487-
(provider == "amazon" and "titan-text" in model_id_lower) or
499+
(provider == "amazon" and "titan-text" in model_id_lower)
500+
or
488501
# Anthropic older Claude models (Claude 2, Claude 2.1, Claude Instant)
489-
(provider == "anthropic" and any(x in model_id_lower for x in ["claude-v2", "claude-instant"])) or
502+
(
503+
provider == "anthropic"
504+
and any(x in model_id_lower for x in ["claude-v2", "claude-instant"])
505+
)
506+
or
490507
# Cohere Command (non-R) models
491-
(provider == "cohere" and "command" in model_id_lower and "command-r" not in model_id_lower) or
508+
(
509+
provider == "cohere"
510+
and "command" in model_id_lower
511+
and "command-r" not in model_id_lower
512+
)
513+
or
492514
# All Meta Llama models
493-
(provider == "meta") or
515+
(provider == "meta")
516+
or
494517
# All Mistral models
495518
(provider == "mistral") or
496519
# DeepSeek-R1 models
@@ -501,8 +524,10 @@ def set_disable_streaming(cls, values: Dict) -> Any:
501524
streaming_support = False
502525

503526
# Set the disable_streaming flag accordingly:
504-
# - If streaming is supported (plain streaming), we want streaming enabled (i.e. disable_streaming == False).
505-
# - If the model supports streaming only in non-tool mode ("no_tools"), then we must force disable streaming when tools are used.
527+
# - If streaming is supported (plain streaming),
528+
# we want streaming enabled (i.e. disable_streaming == False).
529+
# - If the model supports streaming only in non-tool mode ("no_tools"),
530+
# then we must force disable streaming when tools are used.
506531
# - Otherwise, if streaming is not supported, we set disable_streaming to True.
507532
if "disable_streaming" not in values:
508533
if not streaming_support:
@@ -514,7 +539,6 @@ def set_disable_streaming(cls, values: Dict) -> Any:
514539

515540
return values
516541

517-
518542
@model_validator(mode="after")
519543
def validate_environment(self) -> Self:
520544
"""Validate that AWS credentials to and python package exists in environment."""
@@ -811,10 +835,13 @@ def _converse_params(
811835
"modelId": modelId or self.model_id,
812836
"inferenceConfig": inferenceConfig,
813837
"toolConfig": toolConfig,
814-
"additionalModelRequestFields": additionalModelRequestFields
815-
or self.additional_model_request_fields,
816-
"additionalModelResponseFieldPaths": additionalModelResponseFieldPaths
817-
or self.additional_model_response_field_paths,
838+
"additionalModelRequestFields": (
839+
additionalModelRequestFields or self.additional_model_request_fields
840+
),
841+
"additionalModelResponseFieldPaths": (
842+
additionalModelResponseFieldPaths
843+
or self.additional_model_response_field_paths
844+
),
818845
"guardrailConfig": guardrailConfig or self.guardrail_config,
819846
"performanceConfig": performanceConfig or self.performance_config,
820847
"requestMetadata": requestMetadata or self.request_metadata,
@@ -1082,28 +1109,30 @@ def _lc_content_to_bedrock(
10821109
elif block["type"] == "guard_content":
10831110
bedrock_content.append({"guardContent": {"text": {"text": block["text"]}}})
10841111
elif block["type"] == "thinking":
1085-
bedrock_content.append(
1086-
{
1087-
"reasoningContent": {
1088-
"reasoningText": {
1089-
"text": block.get("thinking", ""),
1090-
"signature": block.get("signature", "")
1112+
if block.get("signature", ""):
1113+
bedrock_content.append(
1114+
{
1115+
"reasoningContent": {
1116+
"reasoningText": {
1117+
"text": block.get("thinking", ""),
1118+
"signature": block.get("signature", ""),
1119+
}
10911120
}
10921121
}
1093-
}
1094-
)
1122+
)
10951123
elif block["type"] == "reasoning_content":
10961124
reasoning_content = block.get("reasoningContent", {})
1097-
bedrock_content.append(
1098-
{
1099-
"reasoningContent": {
1100-
"reasoningText": {
1101-
"text": reasoning_content.get("text", ""),
1102-
"signature": reasoning_content.get("signature", "")
1125+
if reasoning_content.get("signature", ""):
1126+
bedrock_content.append(
1127+
{
1128+
"reasoningContent": {
1129+
"reasoningText": {
1130+
"text": reasoning_content.get("text", ""),
1131+
"signature": reasoning_content.get("signature", ""),
1132+
}
11031133
}
11041134
}
1105-
}
1106-
)
1135+
)
11071136
else:
11081137
raise ValueError(f"Unsupported content block type:\n{block}")
11091138
# drop empty text blocks
@@ -1203,7 +1232,7 @@ def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
12031232
"type": "reasoning_content",
12041233
"reasoning_content": {
12051234
"type": "text",
1206-
"text": reasoning_dict.get("text")
1235+
"text": reasoning_dict.get("text"),
12071236
},
12081237
}
12091238
)
@@ -1213,7 +1242,7 @@ def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
12131242
"type": "reasoning_content",
12141243
"reasoning_content": {
12151244
"type": "signature",
1216-
"signature": reasoning_dict.get("signature")
1245+
"signature": reasoning_dict.get("signature"),
12171246
},
12181247
}
12191248
)

0 commit comments

Comments
 (0)