Commit e2c2f7c

Move yield of metrics chunk after generation chunk (#216)
When using Mistral with streaming enabled, the final chunk includes a `stop_reason`, and nothing guarantees that this final chunk doesn't also carry generated text. The existing implementation yielded the metrics chunk (and returned) first, so that final chunk's text was never sent back. This change:

- moves the yield of the metrics chunk to after the generation chunk
- also includes invocation metrics for Cohere models

Closes #215
1 parent 6ad78b7 commit e2c2f7c
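
The ordering bug described above is easiest to see in miniature. Below is a simplified, hypothetical sketch (toy names and shapes, not the actual `bedrock.py` internals): when the final chunk carries both text and a `stop_reason`, yielding metrics and returning first silently drops that text.

```python
# Toy reconstruction of the ordering bug; names and shapes are
# illustrative, not the real langchain_aws internals.
from typing import Dict, Iterator, List


def old_order(chunks: List[Dict]) -> Iterator[str]:
    for obj in chunks:
        if obj["outputs"][0].get("stop_reason") == "stop":
            yield "<metrics>"
            return  # bug: text in this final chunk is never yielded
        yield obj["outputs"][0]["text"]


def new_order(chunks: List[Dict]) -> Iterator[str]:
    for obj in chunks:
        # Yield the generated text first ...
        yield obj["outputs"][0]["text"]
        # ... then the metrics chunk, so the final text survives.
        if obj["outputs"][0].get("stop_reason") == "stop":
            yield "<metrics>"
            return


chunks = [
    {"outputs": [{"text": "Thank", "stop_reason": None}]},
    {"outputs": [{"text": " you.", "stop_reason": "stop"}]},
]
print(list(old_order(chunks)))  # ['Thank', '<metrics>'] -- ' you.' lost
print(list(new_order(chunks)))  # ['Thank', ' you.', '<metrics>']
```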

File tree

2 files changed (+34, -13 lines)


libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 11 additions & 13 deletions
```diff
@@ -369,7 +369,17 @@ def prepare_output_stream(
             ):
                 return
 
-            elif (
+            generation_chunk = _stream_response_to_generation_chunk(
+                chunk_obj,
+                provider=provider,
+                output_key=output_key,
+                messages_api=messages_api,
+                coerce_content_to_string=coerce_content_to_string,
+            )
+            if generation_chunk:
+                yield generation_chunk
+
+            if (
                 provider == "mistral"
                 and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
             ):
@@ -384,18 +394,6 @@ def prepare_output_stream(
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
-            generation_chunk = _stream_response_to_generation_chunk(
-                chunk_obj,
-                provider=provider,
-                output_key=output_key,
-                messages_api=messages_api,
-                coerce_content_to_string=coerce_content_to_string,
-            )
-            if generation_chunk:
-                yield generation_chunk
-            else:
-                continue
-
     @classmethod
     async def aprepare_output_stream(
         cls,
```
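
With the fix applied, a caller draining the stream receives the final Mistral text chunk before the metrics chunk. A minimal consumption sketch follows; it reuses the mock payload shape from the unit tests below rather than a live Bedrock response, and assumes langchain_aws is installed.

```python
# Sketch only: mirrors the unit-test fixture below, not a live
# Bedrock response; assumes langchain_aws is installed.
from langchain_aws.llms.bedrock import LLMInputOutputAdapter

response = {
    "body": [
        {"chunk": {"bytes": b'{"outputs": [{"text": "Thank","stop_reason": null}]}'}},
        {"chunk": {"bytes": b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'}},
    ]
}

texts = [
    chunk.text
    for chunk in LLMInputOutputAdapter.prepare_output_stream("mistral", response)
]
# With the old ordering, "you." was dropped; now the first two chunks
# are "Thank" and "you.", with the metrics chunk yielded afterwards.
print(texts)
```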

libs/aws/tests/unit_tests/llms/test_bedrock.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -270,6 +270,11 @@ def test__human_assistant_format() -> None:
     {"chunk": {"bytes": b'{"text": " you"}'}},
 ]
 
+MOCK_STREAMING_RESPONSE_MISTRAL = [
+    {"chunk": {"bytes": b'{"outputs": [{"text": "Thank","stop_reason": null}]}'}},
+    {"chunk": {"bytes": b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'}},
+]
+
 
 async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
     for item in MOCK_STREAMING_RESPONSE:
@@ -331,6 +336,12 @@ def mistral_response():
     return response
 
 
+@pytest.fixture
+def mistral_streaming_response():
+    response = dict(body=MOCK_STREAMING_RESPONSE_MISTRAL)
+    return response
+
+
 @pytest.fixture
 def cohere_response():
     body = MagicMock()
@@ -412,6 +423,18 @@ def test_prepare_output_for_mistral(mistral_response):
     assert result["stop_reason"] is None
 
 
+def test_prepare_output_stream_for_mistral(mistral_streaming_response) -> None:
+    results = [
+        chunk.text
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            "mistral", mistral_streaming_response
+        )
+    ]
+
+    assert results[0] == "Thank"
+    assert results[1] == "you."
+
+
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."
```
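
For reference, each mock chunk's bytes decode to the Mistral streaming shape the adapter parses; a quick standalone check using only the standard library:

```python
import json

# Decode the final mock chunk: it carries both text and a stop_reason.
raw = b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'
obj = json.loads(raw)
assert obj["outputs"][0]["text"] == "you."
assert obj["outputs"][0]["stop_reason"] == "stop"
# Pre-fix, a chunk like this had its text dropped because the metrics
# chunk was yielded (and the generator returned) first.
```

Under the previous ordering, `results[1]` in the new test would have been the metrics chunk rather than "you.", so the test guards against this regression.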
