@@ -1,9 +1,144 @@
from datetime import datetime
from typing import Any

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.chat import openai as openai_module
from haystack.components.generators.utils import _serialize_object
from haystack.dataclasses import (
    ComponentInfo,
    FinishReason,
    StreamingCallbackT,
    StreamingChunk,
    ToolCallDelta,
)
from haystack.tools import (
    Tool,
    Toolset,
)
from haystack.utils import Secret
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice


# TODO: remove the following function and the monkey patch after upgrading to haystack-ai==2.23.0, which ships this
# fix in the base class
def _convert_chat_completion_chunk_to_streaming_chunk(
    chunk: ChatCompletionChunk, previous_chunks: list[StreamingChunk], component_info: ComponentInfo | None = None
) -> StreamingChunk:
    """
    Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

    :param chunk: The chunk returned by the OpenAI API.
    :param previous_chunks: A list of previously received StreamingChunks.
    :param component_info: An optional `ComponentInfo` object containing information about the component that
        generated the chunk, such as the component name and type.

    :returns:
        A StreamingChunk object representing the content of the chunk from the OpenAI API.
    """
    finish_reason_mapping: dict[str, FinishReason] = {
        "stop": "stop",
        "length": "length",
        "content_filter": "content_filter",
        "tool_calls": "tool_calls",
        "function_call": "tool_calls",
    }
    # On the very first chunk (len(previous_chunks) == 0) the choices field only provides role info (e.g. "assistant").
    # choices is empty when include_usage is set to True, in which case the chunk carries the usage information.
    if len(chunk.choices) == 0:
        return StreamingChunk(
            content="",
            component_info=component_info,
            # Index is None since it's only set to an int when a content block is present
            index=None,
            finish_reason=None,
            meta={
                "model": chunk.model,
                "received_at": datetime.now().isoformat(),  # noqa: DTZ005
                "usage": _serialize_object(chunk.usage),
            },
        )

    choice: ChunkChoice = chunk.choices[0]

    # Create a list of ToolCallDelta objects from the tool calls
    if choice.delta and choice.delta.tool_calls:
        tool_calls_deltas = []
        for tool_call in choice.delta.tool_calls:
            function = tool_call.function
            tool_calls_deltas.append(
                ToolCallDelta(
                    index=tool_call.index,
                    id=tool_call.id,
                    tool_name=function.name if function else None,
                    arguments=function.arguments if function and function.arguments else None,
                )
            )
        chunk_message = StreamingChunk(
            content=choice.delta.content or "",
            component_info=component_info,
            # We adopt the index of the first ToolCallDelta as the overall index of the chunk.
            index=tool_calls_deltas[0].index,
            tool_calls=tool_calls_deltas,
            start=tool_calls_deltas[0].tool_name is not None,
            finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
            meta={
                "model": chunk.model,
                "index": choice.index,
                "tool_calls": choice.delta.tool_calls,
                "finish_reason": choice.finish_reason,
                "received_at": datetime.now().isoformat(),  # noqa: DTZ005
                "usage": _serialize_object(chunk.usage),
            },
        )
        return chunk_message

    # On the very first chunk the choice field only provides role info (e.g. "assistant"), so we set index to None.
    # We also set index to None for any chunk missing the content field, e.g. a chunk that only contains a finish
    # reason.
    if choice.delta and (choice.delta.content is None or choice.delta.role is not None):
        resolved_index = None
    else:
        # We set the index to 0 since, when text content is being streamed, no tool calls are being streamed.
        # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
        # Anthropic Claude does.
        resolved_index = 0

    # Initialize the meta dictionary
    meta = {
        "model": chunk.model,
        "index": choice.index,
        "tool_calls": choice.delta.tool_calls if choice.delta and choice.delta.tool_calls else None,
        "finish_reason": choice.finish_reason,
        "received_at": datetime.now().isoformat(),  # noqa: DTZ005
        "usage": _serialize_object(chunk.usage),
    }

    # Check if logprobs are present; logprobs are returned only for text content
    logprobs = _serialize_object(choice.logprobs) if choice.logprobs else None
    if logprobs:
        meta["logprobs"] = logprobs

    content = ""
    if choice.delta and choice.delta.content:
        content = choice.delta.content

    chunk_message = StreamingChunk(
        content=content,
        component_info=component_info,
        index=resolved_index,
        # The first chunk is always a start message chunk that only contains role information, so if we reach here
        # and previous_chunks has length 1, this is the start of text content.
        start=len(previous_chunks) == 1,
        finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
        meta=meta,
    )
    return chunk_message
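
# For illustration, a minimal sketch of the tool-call branch above (the payload values
# below are invented for demonstration, not taken from this PR): a chunk whose delta
# carries a tool call is converted into a StreamingChunk holding one ToolCallDelta per
# streamed fragment.
#
#     from openai.types.chat.chat_completion_chunk import (
#         ChoiceDelta,
#         ChoiceDeltaToolCall,
#         ChoiceDeltaToolCallFunction,
#     )
#
#     tool_chunk = ChatCompletionChunk(
#         id="chunk-1",
#         object="chat.completion.chunk",
#         created=1719000000,
#         model="gpt-5-mini",
#         choices=[
#             ChunkChoice(
#                 index=0,
#                 finish_reason=None,
#                 delta=ChoiceDelta(
#                     tool_calls=[
#                         ChoiceDeltaToolCall(
#                             index=0,
#                             id="call_1",
#                             function=ChoiceDeltaToolCallFunction(name="weather", arguments=""),
#                         )
#                     ]
#                 ),
#             )
#         ],
#     )
#     converted = _convert_chat_completion_chunk_to_streaming_chunk(tool_chunk, previous_chunks=[])
#     assert converted.tool_calls[0].tool_name == "weather"
#     assert converted.start is True  # a named tool call marks the start of a content block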


# Monkey-patch haystack's openai module so OpenAIChatGenerator uses our own _convert_chat_completion_chunk_to_streaming_chunk
openai_module._convert_chat_completion_chunk_to_streaming_chunk = _convert_chat_completion_chunk_to_streaming_chunk
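
# A quick, optional sanity check (a sketch, not part of this PR) that the patch took
# effect and that a usage-only chunk, i.e. one with an empty choices list as produced
# with stream_options={"include_usage": True}, comes back as an empty-content
# StreamingChunk:
#
#     assert (
#         openai_module._convert_chat_completion_chunk_to_streaming_chunk
#         is _convert_chat_completion_chunk_to_streaming_chunk
#     )
#     usage_chunk = ChatCompletionChunk(
#         id="chunk-2",
#         object="chat.completion.chunk",
#         created=1719000000,
#         model="gpt-5-mini",
#         choices=[],
#     )
#     converted = _convert_chat_completion_chunk_to_streaming_chunk(usage_chunk, previous_chunks=[])
#     assert converted.content == "" and converted.index is None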


class CometAPIChatGenerator(OpenAIChatGenerator):
11 changes: 0 additions & 11 deletions integrations/cometapi/tests/test_cometapi_chat_generator.py
@@ -11,7 +11,6 @@
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
from haystack.tools import Tool
from haystack.utils.auth import Secret
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
@@ -257,16 +256,6 @@ def test_live_run(self):
assert "gpt-5-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

@pytest.mark.skipif(
not os.environ.get("COMET_API_KEY", None),
reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = CometAPIChatGenerator(model="something-obviously-wrong")
with pytest.raises(OpenAIError):
component.run(chat_messages)

@pytest.mark.skipif(
not os.environ.get("COMET_API_KEY", None),
reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.",