Skip to content

Commit 7558720

Browse files
committed
fix{drivers-prompt-amazon-bedrock): add support for reasoningContent
1 parent 4ecbadb commit 7558720

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessa
258258
content_block["text"],
259259
index=event["contentBlockStart"]["contentBlockIndex"],
260260
)
261+
if "reasoningContent" in content_block:
262+
return TextDeltaMessageContent(
263+
"",
264+
index=event["contentBlockStart"]["contentBlockIndex"],
265+
)
261266
raise ValueError(f"Unsupported message content type: {event}")
262267
if "contentBlockDelta" in event:
263268
content_block_delta = event["contentBlockDelta"]
@@ -272,5 +277,10 @@ def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessa
272277
index=content_block_delta["contentBlockIndex"],
273278
partial_input=content_block_delta["delta"]["toolUse"]["input"],
274279
)
280+
if "reasoningContent" in content_block_delta["delta"]:
281+
return TextDeltaMessageContent(
282+
content_block_delta["delta"]["reasoningContent"]["text"],
283+
index=content_block_delta["contentBlockIndex"],
284+
)
275285
raise ValueError(f"Unsupported message content type: {event}")
276286
raise ValueError(f"Unsupported message content type: {event}")

tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,87 @@ def test_verify_structured_output_strategy(self):
492492
ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output strategy."
493493
):
494494
AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native")
495+
496+
def test_try_run_with_reasoning_content(self, mocker):
497+
# Given
498+
mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse
499+
mock_converse.return_value = {
500+
"output": {
501+
"message": {
502+
"content": [
503+
{"text": "model-output"},
504+
{"reasoningContent": {"reasoningText": {"text": "thinking process"}}},
505+
]
506+
}
507+
},
508+
"usage": {"inputTokens": 5, "outputTokens": 10},
509+
}
510+
511+
driver = AmazonBedrockPromptDriver(model="ai21.j2")
512+
prompt_stack = PromptStack()
513+
prompt_stack.add_user_message("test")
514+
515+
# When
516+
message = driver.try_run(prompt_stack)
517+
518+
# Then
519+
assert len(message.value) == 2
520+
# Reasoning content should be moved to the beginning
521+
assert isinstance(message.value[0], TextArtifact)
522+
assert message.value[0].value == "thinking process"
523+
assert isinstance(message.value[1], TextArtifact)
524+
assert message.value[1].value == "model-output"
525+
526+
def test_try_stream_with_reasoning_content(self, mocker):
527+
# Given
528+
mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream
529+
mock_converse_stream.return_value = {
530+
"stream": [
531+
{"contentBlockStart": {"contentBlockIndex": 0, "start": {"reasoningContent": {}}}},
532+
{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"reasoningContent": {"text": "thinking"}}}},
533+
{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"reasoningContent": {"text": " process"}}}},
534+
{"contentBlockStart": {"contentBlockIndex": 1, "start": {"text": ""}}},
535+
{"contentBlockDelta": {"contentBlockIndex": 1, "delta": {"text": "model-output"}}},
536+
{"metadata": {"usage": {"inputTokens": 5, "outputTokens": 10}}},
537+
]
538+
}
539+
540+
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True)
541+
prompt_stack = PromptStack()
542+
prompt_stack.add_user_message("test")
543+
544+
# When
545+
stream = driver.try_stream(prompt_stack)
546+
events = list(stream)
547+
548+
# Then
549+
assert len(events) == 6 # reasoning start + 2 reasoning deltas + text start + text delta + metadata
550+
551+
# First event is the reasoning content start
552+
assert isinstance(events[0].content, TextDeltaMessageContent)
553+
assert events[0].content.text == ""
554+
assert events[0].content.index == 0
555+
556+
# Second event is reasoning content delta
557+
assert isinstance(events[1].content, TextDeltaMessageContent)
558+
assert events[1].content.text == "thinking"
559+
assert events[1].content.index == 0
560+
561+
# Third event is reasoning content delta
562+
assert isinstance(events[2].content, TextDeltaMessageContent)
563+
assert events[2].content.text == " process"
564+
assert events[2].content.index == 0
565+
566+
# Fourth event is text content start
567+
assert isinstance(events[3].content, TextDeltaMessageContent)
568+
assert events[3].content.text == ""
569+
assert events[3].content.index == 1
570+
571+
# Fifth event is text content delta
572+
assert isinstance(events[4].content, TextDeltaMessageContent)
573+
assert events[4].content.text == "model-output"
574+
assert events[4].content.index == 1
575+
576+
# Sixth event is metadata with usage
577+
assert events[5].usage.input_tokens == 5
578+
assert events[5].usage.output_tokens == 10

0 commit comments

Comments
 (0)