Skip to content

Commit 195bc2b

Browse files
authored
fix{drivers-prompt-amazon-bedrock): add support for reasoningContent (#2043)
* fix{drivers-prompt-amazon-bedrock): add support for `reasoningContent` * remove from contentStartBlock * tests
1 parent 4ecbadb commit 195bc2b

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,5 +272,10 @@ def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessa
272272
index=content_block_delta["contentBlockIndex"],
273273
partial_input=content_block_delta["delta"]["toolUse"]["input"],
274274
)
275+
if "reasoningContent" in content_block_delta["delta"]:
276+
return TextDeltaMessageContent(
277+
content_block_delta["delta"]["reasoningContent"]["text"],
278+
index=content_block_delta["contentBlockIndex"],
279+
)
275280
raise ValueError(f"Unsupported message content type: {event}")
276281
raise ValueError(f"Unsupported message content type: {event}")

tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,109 @@ def test_verify_structured_output_strategy(self):
492492
ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output strategy."
493493
):
494494
AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native")
495+
496+
def test_try_run_with_reasoning_content(self, mocker):
497+
mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse
498+
mock_converse.return_value = {
499+
"output": {
500+
"message": {
501+
"content": [
502+
{"text": "model-output"},
503+
{"reasoningContent": {"reasoningText": {"text": "thinking process"}}},
504+
]
505+
}
506+
},
507+
"usage": {"inputTokens": 5, "outputTokens": 10},
508+
}
509+
510+
driver = AmazonBedrockPromptDriver(model="ai21.j2")
511+
prompt_stack = PromptStack()
512+
prompt_stack.add_user_message("test")
513+
message = driver.try_run(prompt_stack)
514+
515+
assert len(message.value) == 2
516+
517+
assert isinstance(message.value[0], TextArtifact)
518+
assert message.value[0].value == "thinking process"
519+
assert isinstance(message.value[1], TextArtifact)
520+
assert message.value[1].value == "model-output"
521+
522+
def test_try_stream_with_reasoning_content(self, mocker):
523+
# Given
524+
mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream
525+
mock_converse_stream.return_value = {
526+
"stream": [
527+
{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"reasoningContent": {"text": "thinking"}}}},
528+
{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"reasoningContent": {"text": " process"}}}},
529+
{"contentBlockStart": {"contentBlockIndex": 1, "start": {"text": ""}}},
530+
{"contentBlockDelta": {"contentBlockIndex": 1, "delta": {"text": "model-output"}}},
531+
{"metadata": {"usage": {"inputTokens": 5, "outputTokens": 10}}},
532+
]
533+
}
534+
535+
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True)
536+
prompt_stack = PromptStack()
537+
prompt_stack.add_user_message("test")
538+
539+
# When
540+
stream = driver.try_stream(prompt_stack)
541+
events = list(stream)
542+
543+
# Then
544+
assert len(events) == 5 # 2 reasoning deltas + text start + text delta + metadata
545+
546+
# First event is reasoning content delta
547+
assert isinstance(events[0].content, TextDeltaMessageContent)
548+
assert events[0].content.text == "thinking"
549+
assert events[0].content.index == 0
550+
551+
# Second event is reasoning content delta
552+
assert isinstance(events[1].content, TextDeltaMessageContent)
553+
assert events[1].content.text == " process"
554+
assert events[1].content.index == 0
555+
556+
# Third event is text content start
557+
assert isinstance(events[2].content, TextDeltaMessageContent)
558+
assert events[2].content.text == ""
559+
assert events[2].content.index == 1
560+
561+
# Fourth event is text content delta
562+
assert isinstance(events[3].content, TextDeltaMessageContent)
563+
assert events[3].content.text == "model-output"
564+
assert events[3].content.index == 1
565+
566+
# Fifth event is metadata with usage
567+
assert events[4].usage.input_tokens == 5
568+
assert events[4].usage.output_tokens == 10
569+
570+
def test_try_stream_unsupported_content_block_delta_type(self, mocker):
571+
mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream
572+
mock_converse_stream.return_value = {
573+
"stream": [
574+
{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"unsupportedType": {"data": "value"}}}},
575+
]
576+
}
577+
578+
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True)
579+
prompt_stack = PromptStack()
580+
prompt_stack.add_user_message("test")
581+
582+
stream = driver.try_stream(prompt_stack)
583+
with pytest.raises(ValueError, match="Unsupported message content type"):
584+
list(stream)
585+
586+
def test_try_stream_unsupported_content_block_start_type(self, mocker):
587+
mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream
588+
mock_converse_stream.return_value = {
589+
"stream": [
590+
{"contentBlockStart": {"contentBlockIndex": 0, "start": {"unsupportedType": {"data": "value"}}}},
591+
]
592+
}
593+
594+
driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True)
595+
prompt_stack = PromptStack()
596+
prompt_stack.add_user_message("test")
597+
598+
stream = driver.try_stream(prompt_stack)
599+
with pytest.raises(ValueError, match="Unsupported message content type"):
600+
list(stream)

0 commit comments

Comments
 (0)