Skip to content

Commit ca1c347

Browse files
fix(openai-responses): surface cache read tokens in metadata chunk (#2555)
Co-authored-by: mehtarac <mehtarac@amazon.com>
1 parent 226b3ec commit ca1c347

1 file changed

Lines changed: 122 additions & 10 deletions

File tree

strands-py/tests/strands/models/test_openai_responses.py

Lines changed: 122 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -484,11 +484,13 @@ def test_format_request(model, messages, tool_specs, system_prompt):
484484
{"chunk_type": "message_stop", "data": "stop"},
485485
{"messageStop": {"stopReason": "end_turn"}},
486486
),
487-
# Metadata
487+
# Metadata - no cache tokens
488488
(
489489
{
490490
"chunk_type": "metadata",
491-
"data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150),
491+
"data": unittest.mock.Mock(
492+
input_tokens=100, output_tokens=50, total_tokens=150, input_tokens_details=None
493+
),
492494
},
493495
{
494496
"metadata": {
@@ -503,6 +505,31 @@ def test_format_request(model, messages, tool_specs, system_prompt):
503505
},
504506
},
505507
),
508+
# Metadata - with cache read tokens
509+
(
510+
{
511+
"chunk_type": "metadata",
512+
"data": unittest.mock.Mock(
513+
input_tokens=100,
514+
output_tokens=50,
515+
total_tokens=150,
516+
input_tokens_details=unittest.mock.Mock(cached_tokens=80),
517+
),
518+
},
519+
{
520+
"metadata": {
521+
"usage": {
522+
"inputTokens": 100,
523+
"outputTokens": 50,
524+
"totalTokens": 150,
525+
"cacheReadInputTokens": 80,
526+
},
527+
"metrics": {
528+
"latencyMs": 0,
529+
},
530+
},
531+
},
532+
),
506533
],
507534
)
508535
def test_format_chunk(event, exp_chunk, model):
@@ -596,7 +623,11 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
596623
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello")
597624
mock_complete_event = unittest.mock.Mock(
598625
type="response.completed",
599-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
626+
response=unittest.mock.Mock(
627+
usage=unittest.mock.Mock(
628+
input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None
629+
)
630+
),
600631
)
601632

602633
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -632,6 +663,67 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
632663
openai_client.responses.create.assert_called_once_with(**expected_request)
633664

634665

666+
@pytest.mark.asyncio
667+
async def test_stream_cache_tokens_propagated(openai_client, model, agenerator, alist):
668+
"""Cache read tokens from input_tokens_details are surfaced in the metadata event."""
669+
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi")
670+
mock_complete_event = unittest.mock.Mock(
671+
type="response.completed",
672+
response=unittest.mock.Mock(
673+
usage=unittest.mock.Mock(
674+
input_tokens=100,
675+
output_tokens=10,
676+
total_tokens=110,
677+
input_tokens_details=unittest.mock.Mock(cached_tokens=80),
678+
)
679+
),
680+
)
681+
682+
openai_client.responses.create = unittest.mock.AsyncMock(
683+
return_value=agenerator([mock_text_event, mock_complete_event])
684+
)
685+
686+
messages = [{"role": "user", "content": [{"text": "test"}]}]
687+
tru_events = await alist(model.stream(messages))
688+
689+
metadata_events = [e for e in tru_events if "metadata" in e]
690+
assert len(metadata_events) == 1
691+
usage = metadata_events[0]["metadata"]["usage"]
692+
assert usage["inputTokens"] == 100
693+
assert usage["outputTokens"] == 10
694+
assert usage["totalTokens"] == 110
695+
assert usage["cacheReadInputTokens"] == 80
696+
697+
698+
@pytest.mark.asyncio
699+
async def test_stream_no_cache_tokens_when_absent(openai_client, model, agenerator, alist):
700+
"""cacheReadInputTokens is omitted from metadata when input_tokens_details is absent."""
701+
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hi")
702+
mock_complete_event = unittest.mock.Mock(
703+
type="response.completed",
704+
response=unittest.mock.Mock(
705+
usage=unittest.mock.Mock(
706+
input_tokens=100,
707+
output_tokens=10,
708+
total_tokens=110,
709+
input_tokens_details=None,
710+
)
711+
),
712+
)
713+
714+
openai_client.responses.create = unittest.mock.AsyncMock(
715+
return_value=agenerator([mock_text_event, mock_complete_event])
716+
)
717+
718+
messages = [{"role": "user", "content": [{"text": "test"}]}]
719+
tru_events = await alist(model.stream(messages))
720+
721+
metadata_events = [e for e in tru_events if "metadata" in e]
722+
assert len(metadata_events) == 1
723+
usage = metadata_events[0]["metadata"]["usage"]
724+
assert "cacheReadInputTokens" not in usage
725+
726+
635727
@pytest.mark.asyncio
636728
async def test_stream_with_tool_calls(openai_client, model, agenerator, alist):
637729
# Mock tool call events
@@ -644,7 +736,11 @@ async def test_stream_with_tool_calls(openai_client, model, agenerator, alist):
644736
)
645737
mock_complete_event = unittest.mock.Mock(
646738
type="response.completed",
647-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
739+
response=unittest.mock.Mock(
740+
usage=unittest.mock.Mock(
741+
input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None
742+
)
743+
),
648744
)
649745

650746
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -677,7 +773,11 @@ async def test_stream_with_tool_calls_done_event(openai_client, model, agenerato
677773
)
678774
mock_complete_event = unittest.mock.Mock(
679775
type="response.completed",
680-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
776+
response=unittest.mock.Mock(
777+
usage=unittest.mock.Mock(
778+
input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None
779+
)
780+
),
681781
)
682782

683783
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -700,7 +800,7 @@ async def test_stream_response_incomplete(openai_client, model, agenerator, alis
700800
mock_incomplete_event = unittest.mock.Mock(
701801
type="response.incomplete",
702802
response=unittest.mock.Mock(
703-
usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110),
803+
usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110, input_tokens_details=None),
704804
incomplete_details=unittest.mock.Mock(reason="max_output_tokens"),
705805
),
706806
)
@@ -734,7 +834,11 @@ async def test_stream_reasoning_content(openai_client, model, agenerator, alist,
734834
mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42")
735835
mock_complete_event = unittest.mock.Mock(
736836
type="response.completed",
737-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30)),
837+
response=unittest.mock.Mock(
838+
usage=unittest.mock.Mock(
839+
input_tokens=10, output_tokens=20, total_tokens=30, input_tokens_details=None
840+
)
841+
),
738842
)
739843

740844
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -778,7 +882,11 @@ async def test_stream_citation_annotations(openai_client, model, agenerator, ali
778882
)
779883
mock_complete_event = unittest.mock.Mock(
780884
type="response.completed",
781-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
885+
response=unittest.mock.Mock(
886+
usage=unittest.mock.Mock(
887+
input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None
888+
)
889+
),
782890
)
783891

784892
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -814,7 +922,11 @@ async def test_stream_unsupported_annotation_type(openai_client, model, agenerat
814922
)
815923
mock_complete_event = unittest.mock.Mock(
816924
type="response.completed",
817-
response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)),
925+
response=unittest.mock.Mock(
926+
usage=unittest.mock.Mock(
927+
input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None
928+
)
929+
),
818930
)
819931

820932
openai_client.responses.create = unittest.mock.AsyncMock(
@@ -1283,7 +1395,7 @@ async def test_stream_stateful(openai_client, model_id, agenerator, alist):
12831395
type="response.completed",
12841396
response=unittest.mock.Mock(
12851397
id="resp_abc123",
1286-
usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15),
1398+
usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15, input_tokens_details=None),
12871399
),
12881400
),
12891401
]

0 commit comments

Comments
 (0)