
Commit c475761

feat: Add matching support for responses API
Also send both completion and response outputs as GalileoMessages to fix validation error message
1 parent 18aef69 commit c475761
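
What this enables, end to end: calls made through the OpenAI Responses API are now intercepted and logged just like chat completions. A minimal sketch of the intended usage (assuming the galileo-wrapped openai module re-exported by this package; model name and prompt are illustrative):

    # Sketch only: assumes the galileo-wrapped openai module
    from galileo.openai import openai

    client = openai.OpenAI()

    # Intercepted via the new openai.resources.responses.Responses.create wrapper
    response = client.responses.create(model="gpt-4o-mini", input="Say hello")
    print(response.output_text)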

File tree

4 files changed: +643 -15 lines changed


src/galileo/openai.py

Lines changed: 208 additions & 10 deletions
@@ -50,7 +50,7 @@
 from pydantic import BaseModel
 from wrapt import wrap_function_wrapper  # type: ignore[import-untyped]
 
-from galileo import GalileoLogger
+from galileo import GalileoLogger, Message, MessageRole, ToolCall, ToolCallFunction
 from galileo.decorator import galileo_context
 from galileo.utils import _get_timestamp
 from galileo.utils.serialization import serialize_to_str
@@ -103,7 +103,10 @@ class OpenAiInputData:
 OPENAI_CLIENT_METHODS = [
     OpenAiModuleDefinition(
         module="openai.resources.chat.completions", object="Completions", method="create", type="chat", sync=True
-    )
+    ),
+    OpenAiModuleDefinition(
+        module="openai.resources.responses", object="Responses", method="create", type="response", sync=True
+    ),
     # Eventually add more OpenAI client library methods here
 ]

@@ -153,6 +156,77 @@ def wrapper(wrapped: Callable, instance: Any, args: dict, kwargs: dict) -> Any:
     return _with_galileo
 
 
+def _convert_to_galileo_message(data: Any, default_role: str = "user") -> Message:
+    """Convert OpenAI response data to a Galileo Message object."""
+    if hasattr(data, "type") and data.type == "function_call":
+        tool_call = ToolCall(
+            id=getattr(data, "call_id", ""),
+            function=ToolCallFunction(name=getattr(data, "name", ""), arguments=getattr(data, "arguments", "")),
+        )
+        return Message(content="", role=MessageRole.assistant, tool_calls=[tool_call])
+
+    if isinstance(data, dict) and data.get("type") == "function_call_output":
+        output = data.get("output", "")
+        if isinstance(output, dict):
+            import json
+
+            content = json.dumps(output)
+        else:
+            content = str(output)
+
+        return Message(content=content, role=MessageRole.tool, tool_call_id=data.get("call_id", ""))
+
+    # Handle ChatCompletionMessage objects (from completion API) and dictionary messages
+    if (hasattr(data, "role") and hasattr(data, "content")) or isinstance(data, dict):
+        # Extract role and content from either object type
+        if hasattr(data, "role"):
+            # ChatCompletionMessage object
+            role = getattr(data, "role", default_role)
+            content = getattr(data, "content", "")
+            tool_calls = getattr(data, "tool_calls", None)
+            tool_call_id = getattr(data, "tool_call_id", None)
+        else:
+            # Dictionary message
+            role = data.get("role", default_role)
+            content = data.get("content", "")
+            tool_calls = data.get("tool_calls")
+            tool_call_id = data.get("tool_call_id")
+
+        # Handle tool calls if present
+        galileo_tool_calls = None
+        if tool_calls:
+            galileo_tool_calls = []
+            for tc in tool_calls:
+                if hasattr(tc, "function"):
+                    # ChatCompletionMessageFunctionToolCall object
+                    galileo_tool_calls.append(
+                        ToolCall(
+                            id=getattr(tc, "id", ""),
+                            function=ToolCallFunction(
+                                name=getattr(tc.function, "name", ""), arguments=getattr(tc.function, "arguments", "")
+                            ),
+                        )
+                    )
+                elif isinstance(tc, dict) and "function" in tc:
+                    # Dictionary tool call
+                    galileo_tool_calls.append(
+                        ToolCall(
+                            id=tc.get("id", ""),
+                            function=ToolCallFunction(
+                                name=tc["function"].get("name", ""), arguments=tc["function"].get("arguments", "")
+                            ),
+                        )
+                    )
+
+        return Message(
+            content=str(content) if content is not None else "",
+            role=MessageRole(role),
+            tool_calls=galileo_tool_calls,
+            tool_call_id=tool_call_id,
+        )
+    return Message(content=str(data), role=MessageRole(default_role))
 
 
 def _extract_chat_response(kwargs: dict) -> dict:
     """Extracts the llm output from the response."""
     response = {"role": kwargs.get("role")}
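
To make the branches above concrete, a sketch of what _convert_to_galileo_message would return for typical inputs (expected values are inferred from the code above, not from test fixtures):

    # Plain dict message -> Message with the given role
    _convert_to_galileo_message({"role": "user", "content": "hi"})
    # -> Message(content="hi", role=MessageRole.user)

    # Responses API function_call_output item -> tool message
    _convert_to_galileo_message(
        {"type": "function_call_output", "call_id": "call_1", "output": {"ok": True}}
    )
    # -> Message(content='{"ok": true}', role=MessageRole.tool, tool_call_id="call_1")

    # Anything unrecognized falls through to a stringified message
    _convert_to_galileo_message(42)
    # -> Message(content="42", role=MessageRole.user)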
@@ -213,6 +287,8 @@ def _extract_input_data_from_kwargs(
         prompt = kwargs.get("prompt")
     elif resource.type == "chat":
         prompt = kwargs.get("messages", [])
+    elif resource.type == "response":
+        prompt = kwargs.get("input", "")
 
     parsed_temperature = float(
         kwargs.get("temperature", 1) if not isinstance(kwargs.get("temperature", 1), NotGiven) else 1
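
Note the asymmetry with chat: the Responses API passes its prompt as input, which may be a plain string or a list of input items, so both shapes have to flow through _convert_to_galileo_message later. For reference, both public call forms (model name illustrative):

    client.responses.create(model="gpt-4o-mini", input="What is 2 + 2?")
    client.responses.create(
        model="gpt-4o-mini",
        input=[{"role": "user", "content": "What is 2 + 2?"}],
    )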
@@ -283,6 +359,17 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:
 
     usage_dict = usage.copy() if isinstance(usage, dict) else usage.__dict__
 
+    # Handle Responses API field names (input_tokens/output_tokens) vs Chat Completions (prompt_tokens/completion_tokens)
+    if "input_tokens" in usage_dict:
+        usage_dict["prompt_tokens"] = usage_dict.pop("input_tokens")
+    if "output_tokens" in usage_dict:
+        usage_dict["completion_tokens"] = usage_dict.pop("output_tokens")
+
+    if "input_tokens_details" in usage_dict:
+        usage_dict["prompt_tokens_details"] = usage_dict.pop("input_tokens_details")
+    if "output_tokens_details" in usage_dict:
+        usage_dict["completion_tokens_details"] = usage_dict.pop("output_tokens_details")
+
     for tokens_details in ["prompt_tokens_details", "completion_tokens_details"]:
         if tokens_details in usage_dict and usage_dict[tokens_details] is not None:
             tokens_details_dict = (
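
The effect of the renaming, sketched on a hand-written usage dict (the Responses API reports input_tokens/output_tokens where Chat Completions reports prompt_tokens/completion_tokens):

    _parse_usage({"input_tokens": 12, "output_tokens": 34, "total_tokens": 46})
    # -> {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}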
@@ -295,6 +382,44 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:
     return usage_dict
 
 
+def _extract_responses_output(output_items: list) -> dict:
+    """Extract the final message and tool calls from Responses API output items."""
+    final_message = None
+    tool_calls = []
+
+    for item in output_items:
+        if hasattr(item, "type") and item.type == "message":
+            final_message = {"role": getattr(item, "role", "assistant"), "content": ""}
+
+            content = getattr(item, "content", [])
+            if isinstance(content, list):
+                text_parts = []
+                for content_item in content:
+                    if hasattr(content_item, "text"):
+                        text_parts.append(content_item.text)
+                    elif isinstance(content_item, dict) and "text" in content_item:
+                        text_parts.append(content_item["text"])
+                final_message["content"] = "".join(text_parts)
+            else:
+                final_message["content"] = str(content)
+
+        elif hasattr(item, "type") and item.type == "function_call":
+            tool_call = {
+                "id": getattr(item, "id", ""),
+                "function": {"name": getattr(item, "name", ""), "arguments": getattr(item, "arguments", "")},
+                "type": "function",
+            }
+            tool_calls.append(tool_call)
+
+    if final_message:
+        if tool_calls:
+            final_message["tool_calls"] = tool_calls
+        return final_message
+    if tool_calls:
+        return {"role": "assistant", "tool_calls": tool_calls}
+    return {"role": "assistant", "content": ""}
 
 
 def _extract_data_from_default_response(resource: OpenAiModuleDefinition, response: dict[str, Any]) -> Any:
     if response is None:
         return None, "<NoneType response returned from OpenAI>", None
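
A sketch of what the extractor yields for a typical Responses API output list, using SimpleNamespace objects to stand in for the SDK's output-item types (field values are illustrative):

    from types import SimpleNamespace as NS

    items = [
        NS(type="function_call", id="fc_1", name="get_weather", arguments='{"city": "Paris"}'),
        NS(type="message", role="assistant", content=[NS(text="It is sunny.")]),
    ]
    _extract_responses_output(items)
    # -> {"role": "assistant", "content": "It is sunny.",
    #     "tool_calls": [{"id": "fc_1",
    #                     "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    #                     "type": "function"}]}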
@@ -325,6 +450,10 @@ def _extract_data_from_default_response(resource: OpenAiModuleDefinition, respon
         completion = (
             _extract_chat_response(choice.message.__dict__) if _is_openai_v1() else choice.get("message", None)
         )
+    elif resource.type == "response":
+        # Handle Responses API structure
+        output = response.get("output", [])
+        completion = _extract_responses_output(output)
 
     usage = _parse_usage(response.get("usage"))

@@ -335,10 +464,27 @@ def _extract_streamed_openai_response(resource: OpenAiModuleDefinition, chunks:
     completion = defaultdict(str) if resource.type == "chat" else ""
     model, usage = None, None
 
+    # For Responses API, we just need to find the final completed event
+    if resource.type == "response":
+        final_response = None
+
     for chunk in chunks:
         if _is_openai_v1():
             chunk = chunk.__dict__
 
+        if resource.type == "response":
+            chunk_type = chunk.get("type", "")
+
+            if chunk_type == "response.completed":
+                final_response = chunk.get("response")
+                if final_response:
+                    model = getattr(final_response, "model", None)
+                    usage_obj = getattr(final_response, "usage", None)
+                    if usage_obj:
+                        usage = _parse_usage(usage_obj.__dict__ if hasattr(usage_obj, "__dict__") else usage_obj)
+
+            continue
+
         model = model or chunk.get("model", None) or None
         usage = chunk.get("usage", None)

@@ -414,7 +560,15 @@ def get_response_for_chat() -> Any:
             or None
         )
 
-    return model, get_response_for_chat() if resource.type == "chat" else completion, usage
+    if resource.type == "chat":
+        return model, get_response_for_chat(), usage
+    if resource.type == "response":
+        if final_response:
+            output_items = getattr(final_response, "output", [])
+            response_message = _extract_responses_output(output_items)
+            return model, response_message, usage
+        return model, {"role": "assistant", "content": ""}, usage
+    return model, completion, usage
 
 
 def _is_openai_v1() -> bool:
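
In other words, for streamed Responses API calls the chunk-accumulation path is bypassed entirely: every event is skipped via continue until the terminal response.completed event arrives, whose response attribute already carries the assembled output, model, and usage. Roughly the event flow this relies on (event names per the public Responses streaming API):

    stream = client.responses.create(model="gpt-4o-mini", input="hi", stream=True)
    for event in stream:
        # deltas such as "response.output_text.delta" are ignored by the wrapper;
        # only the final event is mined for output, model, and usage
        if event.type == "response.completed":
            final = event.response  # full Response, including .output and .usage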
@@ -442,7 +596,14 @@ def _wrap(
     else:
         # If we don't have an active trace, start a new trace
         # We will conclude it at the end
-        galileo_logger.start_trace(input=serialize_to_str(input_data.input), name=input_data.name)
+        if isinstance(input_data.input, list):
+            trace_input_messages = [_convert_to_galileo_message(msg) for msg in input_data.input]
+        else:
+            trace_input_messages = [_convert_to_galileo_message(input_data.input)]
+
+        # Serialize with "messages" wrapper for UI compatibility
+        trace_input = {"messages": [msg.model_dump(exclude_none=True) for msg in trace_input_messages]}
+        galileo_logger.start_trace(input=serialize_to_str(trace_input), name=input_data.name)
         should_complete_trace = True
 
     try:
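
With this change the serialized trace input always has the {"messages": [...]} shape the Galileo UI validates, whether the call supplied a chat messages list or a bare Responses API string. For a plain-string input the serialized payload looks roughly like (exact fields depend on the Message model's non-None attributes):

    {"messages": [{"role": "user", "content": "Say hello"}]}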
@@ -476,10 +637,17 @@ def _wrap(
 
     duration_ns = round((end_time - start_time).total_seconds() * 1e9)
 
+    if isinstance(input_data.input, list):
+        span_input = [_convert_to_galileo_message(msg) for msg in input_data.input]
+    else:
+        span_input = [_convert_to_galileo_message(input_data.input)]
+
+    span_output = _convert_to_galileo_message(completion, "assistant")
+
     # Add a span to the current trace or span (if this is a nested trace)
     galileo_logger.add_llm_span(
-        input=input_data.input,
-        output=completion,
+        input=span_input,
+        output=span_output,
         tools=input_data.tools,
         name=input_data.name,
         model=model,
@@ -496,8 +664,19 @@ def _wrap(
 
     # Conclude the trace if this is the top-level call
     if should_complete_trace:
+        full_conversation = []
+
+        if isinstance(input_data.input, list):
+            full_conversation.extend([_convert_to_galileo_message(msg) for msg in input_data.input])
+        else:
+            full_conversation.append(_convert_to_galileo_message(input_data.input))
+
+        full_conversation.append(span_output)
+
+        # Serialize with "messages" wrapper for UI compatibility
+        trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
         galileo_logger.conclude(
-            output=serialize_to_str(completion), duration_ns=duration_ns, status_code=status_code
+            output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=status_code
         )
 
     # we want to re-raise exception after we process openai_response
@@ -593,10 +772,17 @@ def _finalize(self) -> None:
         # TODO: make sure completion_start_time what we want
         duration_ns = round((end_time - self.completion_start_time).total_seconds() * 1e9)
 
+        if isinstance(self.input_data.input, list):
+            span_input = [_convert_to_galileo_message(msg) for msg in self.input_data.input]
+        else:
+            span_input = [_convert_to_galileo_message(self.input_data.input)]
+
+        span_output = _convert_to_galileo_message(completion, "assistant")
+
         # Add a span to the current trace or span (if this is a nested trace)
         self.logger.add_llm_span(
-            input=self.input_data.input,
-            output=completion,
+            input=span_input,
+            output=span_output,
             tools=self.input_data.tools,
             name=self.input_data.name,
             model=model,
@@ -611,7 +797,19 @@ def _finalize(self) -> None:
 
         # Conclude the trace if this is the top-level call
         if self.should_complete_trace:
-            self.logger.conclude(output=completion, duration_ns=duration_ns, status_code=self.status_code)
+            full_conversation = []
+
+            if isinstance(self.input_data.input, list):
+                full_conversation.extend([_convert_to_galileo_message(msg) for msg in self.input_data.input])
+            else:
+                full_conversation.append(_convert_to_galileo_message(self.input_data.input))
+
+            full_conversation.append(span_output)
+
+            trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
+            self.logger.conclude(
+                output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=self.status_code
+            )
 
 
 class OpenAIGalileo:
