Skip to content

Commit

Permalink
Move yield of metrics chunk after generation chunk (#216)
Browse files Browse the repository at this point in the history
## Move yield of metrics chunk after generation chunk
- when using mistral and streaming is enabled, the final chunk includes a
stop_reason. Nothing guarantees that this final chunk doesn't also
include some generated text. The existing implementation would result in
that final chunk never getting sent back
- this update moves the yield of the metrics chunk after the generation
chunk
- also included a change to include invocation metrics for cohere models

Closes #215
  • Loading branch information
ihmaws authored Oct 2, 2024
1 parent 6ad78b7 commit e2c2f7c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
24 changes: 11 additions & 13 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,17 @@ def prepare_output_stream(
):
return

elif (
generation_chunk = _stream_response_to_generation_chunk(
chunk_obj,
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk

if (
provider == "mistral"
and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
):
Expand All @@ -384,18 +394,6 @@ def prepare_output_stream(
yield _get_invocation_metrics_chunk(chunk_obj)
return

generation_chunk = _stream_response_to_generation_chunk(
chunk_obj,
provider=provider,
output_key=output_key,
messages_api=messages_api,
coerce_content_to_string=coerce_content_to_string,
)
if generation_chunk:
yield generation_chunk
else:
continue

@classmethod
async def aprepare_output_stream(
cls,
Expand Down
23 changes: 23 additions & 0 deletions libs/aws/tests/unit_tests/llms/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def test__human_assistant_format() -> None:
{"chunk": {"bytes": b'{"text": " you"}'}},
]

# Mock Mistral streaming body. The final chunk deliberately carries BOTH
# generated text ("you.") and a terminal stop_reason ("stop"), so the stream
# handler must yield the generation chunk before emitting stop/metrics
# handling — otherwise that trailing text is lost (see #215).
MOCK_STREAMING_RESPONSE_MISTRAL = [
    {"chunk": {"bytes": b'{"outputs": [{"text": "Thank","stop_reason": null}]}'}},
    {"chunk": {"bytes": b'{"outputs": [{"text": "you.","stop_reason": "stop"}]}'}},
]


async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
for item in MOCK_STREAMING_RESPONSE:
Expand Down Expand Up @@ -331,6 +336,12 @@ def mistral_response():
return response


@pytest.fixture
def mistral_streaming_response():
    """Bedrock streaming response payload for the Mistral provider."""
    return {"body": MOCK_STREAMING_RESPONSE_MISTRAL}


@pytest.fixture
def cohere_response():
body = MagicMock()
Expand Down Expand Up @@ -412,6 +423,18 @@ def test_prepare_output_for_mistral(mistral_response):
assert result["stop_reason"] is None


def test_prepare_output_stream_for_mistral(mistral_streaming_response) -> None:
    """Every generated text piece is yielded, including the text that
    arrives in the same chunk as the terminal stop_reason."""
    stream = LLMInputOutputAdapter.prepare_output_stream(
        "mistral", mistral_streaming_response
    )
    results = [generation.text for generation in stream]

    assert results[0] == "Thank"
    assert results[1] == "you."


def test_prepare_output_for_cohere(cohere_response):
result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
assert result["text"] == "This is the Cohere output text."
Expand Down

0 comments on commit e2c2f7c

Please sign in to comment.