diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 0425145..0173d66 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -369,7 +369,17 @@ def prepare_output_stream(
             ):
                 return
 
-            elif (
+            generation_chunk = _stream_response_to_generation_chunk(
+                chunk_obj,
+                provider=provider,
+                output_key=output_key,
+                messages_api=messages_api,
+                coerce_content_to_string=coerce_content_to_string,
+            )
+            if generation_chunk:
+                yield generation_chunk
+
+            if (
                 provider == "mistral"
                 and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
             ):
@@ -384,18 +394,6 @@
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
-            generation_chunk = _stream_response_to_generation_chunk(
-                chunk_obj,
-                provider=provider,
-                output_key=output_key,
-                messages_api=messages_api,
-                coerce_content_to_string=coerce_content_to_string,
-            )
-            if generation_chunk:
-                yield generation_chunk
-            else:
-                continue
-
     @classmethod
     async def aprepare_output_stream(
         cls,
diff --git a/libs/aws/tests/unit_tests/llms/test_bedrock.py b/libs/aws/tests/unit_tests/llms/test_bedrock.py
index 955fc51..ff15e6f 100644
--- a/libs/aws/tests/unit_tests/llms/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/llms/test_bedrock.py
@@ -270,6 +270,11 @@ def test__human_assistant_format() -> None:
     {"chunk": {"bytes": b'{"text": " you"}'}},
 ]
 
+MOCK_STREAMING_RESPONSE_MISTRAL = [
+    {"chunk": {"bytes": b'{"outputs": [{"text": "Thank","stop_reason": null}]}'}},
+    {"chunk": {"bytes": b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'}},
+]
+
 
 async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
     for item in MOCK_STREAMING_RESPONSE:
@@ -331,6 +336,12 @@ def mistral_response():
     return response
 
 
+@pytest.fixture
+def mistral_streaming_response():
+    response = dict(body=MOCK_STREAMING_RESPONSE_MISTRAL)
+    return response
+
+
 @pytest.fixture
 def cohere_response():
     body = MagicMock()
@@ -412,6 +423,18 @@ def test_prepare_output_for_mistral(mistral_response):
     assert result["stop_reason"] is None
 
 
+def test_prepare_output_stream_for_mistral(mistral_streaming_response) -> None:
+    results = [
+        chunk.text
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            "mistral", mistral_streaming_response
+        )
+    ]
+
+    assert results[0] == "Thank"
+    assert results[1] == "you."
+
+
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."