Skip to content

Commit b789604

Browse files
committed
fix(open_responses): cover review edge cases
1 parent 9545b5e commit b789604

10 files changed

Lines changed: 414 additions & 27 deletions

File tree

homeassistant/components/open_responses/client.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
from typing import Any
66

7-
from httpx import AsyncClient, HTTPStatusError, RequestError
7+
from httpx import AsyncClient, HTTPStatusError, RequestError, Response
88
from openresponses_types.types import CreateResponseBody
99

1010

@@ -20,6 +20,10 @@ class OpenResponsesRateLimitError(OpenResponsesError):
2020
"""Open Responses rate limit error."""
2121

2222

23+
class OpenResponsesInvalidModelError(OpenResponsesError):
24+
"""Open Responses model validation error."""
25+
26+
2327
class OpenResponsesConnectionError(OpenResponsesError):
2428
"""Open Responses connection error."""
2529

@@ -129,4 +133,23 @@ def _raise_client_error(err: HTTPStatusError) -> None:
129133
raise OpenResponsesAuthError("Authentication failed")
130134
if status_code == 429:
131135
raise OpenResponsesRateLimitError("Rate limited")
136+
if status_code == 400 and _response_error_mentions_model(err.response):
137+
raise OpenResponsesInvalidModelError("Invalid model")
132138
raise OpenResponsesConnectionError("Open Responses endpoint error")
139+
140+
141+
def _response_error_mentions_model(response: Response) -> bool:
142+
"""Return whether an error response points at the requested model."""
143+
try:
144+
body = response.json()
145+
except ValueError:
146+
return False
147+
148+
error = body.get("error") if isinstance(body, dict) else None
149+
if not isinstance(error, dict):
150+
return False
151+
152+
return any(
153+
"model" in str(error.get(key, "")).lower()
154+
for key in ("code", "param", "message")
155+
)

homeassistant/components/open_responses/config_flow.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
OpenResponsesAuthError,
3535
OpenResponsesClient,
3636
OpenResponsesConnectionError,
37+
OpenResponsesInvalidModelError,
3738
)
3839
from .const import (
3940
CONF_BASE_URL,
@@ -76,6 +77,23 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
7677
)
7778

7879

80+
def _async_update_default_subentry_models(
81+
hass: HomeAssistant, entry: ConfigEntry, model: str
82+
) -> None:
83+
"""Update generated default subentries when reauth changes the model."""
84+
old_model = entry.data[CONF_MODEL]
85+
86+
for subentry in entry.subentries.values():
87+
if subentry.data.get(CONF_MODEL) != old_model:
88+
continue
89+
90+
hass.config_entries.async_update_subentry(
91+
entry,
92+
subentry,
93+
data={**subentry.data, CONF_MODEL: model},
94+
)
95+
96+
7997
class OpenResponsesConfigFlow(ConfigFlow, domain=DOMAIN):
8098
"""Handle a config flow for Open Responses."""
8199

@@ -107,6 +125,8 @@ async def async_step_user(
107125
await validate_input(self.hass, user_input)
108126
except OpenResponsesAuthError:
109127
errors["base"] = "invalid_auth"
128+
except OpenResponsesInvalidModelError:
129+
errors[CONF_MODEL] = "invalid_model"
110130
except OpenResponsesConnectionError:
111131
errors["base"] = "cannot_connect"
112132
else:
@@ -119,8 +139,12 @@ async def async_step_user(
119139
CONF_MODEL: user_input[CONF_MODEL],
120140
}
121141
if self.source == SOURCE_REAUTH:
142+
reauth_entry = self._get_reauth_entry()
143+
_async_update_default_subentry_models(
144+
self.hass, reauth_entry, user_input[CONF_MODEL]
145+
)
122146
return self.async_update_reload_and_abort(
123-
self._get_reauth_entry(), data_updates=user_input
147+
reauth_entry, data_updates=user_input
124148
)
125149
return self.async_create_entry(
126150
title="Open Responses",

homeassistant/components/open_responses/entity.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,58 @@ def _convert_content_to_param(
168168
return messages
169169

170170

171+
async def _async_prepare_message_attachments(
172+
hass: HomeAssistant,
173+
chat_content: Iterable[conversation.Content],
174+
messages: ResponseInputParam,
175+
) -> None:
176+
"""Attach files to all matching user messages."""
177+
message_index = 0
178+
179+
for content in chat_content:
180+
if isinstance(content, conversation.ToolResultContent):
181+
message_index += 1
182+
continue
183+
184+
if (
185+
isinstance(content, conversation.AssistantContent)
186+
and isinstance(content.native, dict)
187+
and content.native.get("type")
188+
):
189+
message_index += 1
190+
continue
191+
192+
if content.content or (
193+
isinstance(content, conversation.UserContent) and content.attachments
194+
):
195+
if isinstance(content, conversation.UserContent) and content.attachments:
196+
files = await async_prepare_files_for_prompt(
197+
hass,
198+
[
199+
(attachment.path, attachment.mime_type)
200+
for attachment in content.attachments
201+
],
202+
)
203+
message = messages[message_index]
204+
assert (
205+
message["type"] == "message"
206+
and message["role"] == "user"
207+
and isinstance(message["content"], str)
208+
)
209+
message_content: ResponseInputMessageContentListParam = []
210+
if message["content"]:
211+
message_content.append(
212+
{"type": "input_text", "text": message["content"]}
213+
)
214+
message_content.extend(files)
215+
message["content"] = message_content
216+
217+
message_index += 1
218+
219+
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
220+
message_index += len(content.tool_calls)
221+
222+
171223
async def _transform_stream(
172224
chat_log: conversation.ChatLog,
173225
stream: AsyncIterable[dict[str, Any]],
@@ -307,7 +359,9 @@ async def _async_handle_chat_log(
307359
) -> None:
308360
"""Generate an answer for the chat log."""
309361
options = self.subentry.data
310-
messages = _convert_content_to_param(chat_log.content)
362+
chat_content = list(chat_log.content)
363+
messages = _convert_content_to_param(chat_content)
364+
await _async_prepare_message_attachments(self.hass, chat_content, messages)
311365

312366
model_args: dict[str, Any] = {
313367
"model": options.get(CONF_MODEL, self.entry.data[CONF_MODEL]),
@@ -329,26 +383,6 @@ async def _async_handle_chat_log(
329383
if tools:
330384
model_args["tools"] = tools
331385

332-
last_content = chat_log.content[-1]
333-
if last_content.role == "user" and last_content.attachments:
334-
files = await async_prepare_files_for_prompt(
335-
self.hass,
336-
[(a.path, a.mime_type) for a in last_content.attachments],
337-
)
338-
last_message = messages[-1]
339-
assert (
340-
last_message["type"] == "message"
341-
and last_message["role"] == "user"
342-
and isinstance(last_message["content"], str)
343-
)
344-
last_message_content: ResponseInputMessageContentListParam = []
345-
if last_message["content"]:
346-
last_message_content.append(
347-
{"type": "input_text", "text": last_message["content"]}
348-
)
349-
last_message_content.extend(files)
350-
last_message["content"] = last_message_content
351-
352386
if structure and structure_name:
353387
model_args["text"] = {
354388
"format": {

homeassistant/components/open_responses/strings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"error": {
88
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
99
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
10+
"invalid_model": "The endpoint rejected the configured model.",
1011
"unknown": "[%key:common::config_flow::error::unknown%]"
1112
},
1213
"step": {

tests/components/open_responses/conftest.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,19 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
3232
version=1,
3333
subentries_data=[
3434
ConfigSubentryData(
35-
data=RECOMMENDED_CONVERSATION_OPTIONS,
35+
data={
36+
**RECOMMENDED_CONVERSATION_OPTIONS,
37+
CONF_MODEL: "open-responses-model",
38+
},
3639
subentry_type="conversation",
3740
title=DEFAULT_CONVERSATION_NAME,
3841
unique_id=None,
3942
),
4043
ConfigSubentryData(
41-
data=RECOMMENDED_AI_TASK_OPTIONS,
44+
data={
45+
**RECOMMENDED_AI_TASK_OPTIONS,
46+
CONF_MODEL: "open-responses-model",
47+
},
4248
subentry_type="ai_task_data",
4349
title=DEFAULT_AI_TASK_NAME,
4450
unique_id=None,
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Tests for Open Responses AI tasks."""
2+
3+
from collections.abc import AsyncGenerator
4+
from copy import deepcopy
5+
from typing import Any
6+
7+
import voluptuous as vol
8+
9+
from homeassistant.components import ai_task
10+
from homeassistant.core import HomeAssistant
11+
from homeassistant.helpers import selector
12+
13+
from tests.common import MockConfigEntry
14+
15+
16+
def _mock_response_stream(calls: list[dict[str, Any]], text: str) -> Any:
17+
"""Return a streaming response mock."""
18+
19+
async def stream_response(**params: Any) -> AsyncGenerator[dict[str, Any]]:
20+
calls.append(deepcopy(params))
21+
yield {
22+
"type": "response.output_item.added",
23+
"item": {
24+
"type": "message",
25+
"id": "msg_1",
26+
"role": "assistant",
27+
"content": [],
28+
"status": "in_progress",
29+
},
30+
}
31+
yield {"type": "response.output_text.delta", "delta": text}
32+
yield {
33+
"type": "response.completed",
34+
"response": {"usage": {"input_tokens": 1, "output_tokens": 1}},
35+
}
36+
37+
return stream_response
38+
39+
40+
async def test_generate_data(
41+
hass: HomeAssistant,
42+
mock_config_entry: MockConfigEntry,
43+
mock_init_component: None,
44+
) -> None:
45+
"""Test plain-text AI task data generation reaches the client."""
46+
calls: list[dict[str, Any]] = []
47+
mock_config_entry.runtime_data.stream_response = _mock_response_stream(
48+
calls, "The test data"
49+
)
50+
51+
result = await ai_task.async_generate_data(
52+
hass,
53+
task_name="Test Task",
54+
entity_id="ai_task.open_responses_ai_task",
55+
instructions="Generate test data",
56+
)
57+
58+
assert result.data == "The test data"
59+
assert calls[0]["model"] == "open-responses-model"
60+
assert calls[0]["input"][-1] == {
61+
"type": "message",
62+
"role": "user",
63+
"content": "Generate test data",
64+
}
65+
66+
67+
async def test_generate_structured_data(
68+
hass: HomeAssistant,
69+
mock_config_entry: MockConfigEntry,
70+
mock_init_component: None,
71+
) -> None:
72+
"""Test structured AI task data generation reaches the client."""
73+
calls: list[dict[str, Any]] = []
74+
mock_config_entry.runtime_data.stream_response = _mock_response_stream(
75+
calls, '{"characters":["Mario","Luigi"]}'
76+
)
77+
78+
result = await ai_task.async_generate_data(
79+
hass,
80+
task_name="Character Task",
81+
entity_id="ai_task.open_responses_ai_task",
82+
instructions="Generate character data",
83+
structure=vol.Schema(
84+
{
85+
vol.Required("characters"): selector.selector(
86+
{"text": {"multiple": True}}
87+
)
88+
}
89+
),
90+
)
91+
92+
assert result.data == {"characters": ["Mario", "Luigi"]}
93+
assert calls[0]["text"]["format"]["type"] == "json_schema"

tests/components/open_responses/test_client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Tests for the Open Responses client."""
22

3+
import httpx
34
from pydantic import ValidationError
45
import pytest
56

6-
from homeassistant.components.open_responses.client import _format_request_body
7+
from homeassistant.components.open_responses.client import (
8+
OpenResponsesInvalidModelError,
9+
_format_request_body,
10+
_raise_client_error,
11+
)
712

813

914
def test_format_request_body_preserves_tool_parameters() -> None:
@@ -47,3 +52,19 @@ def test_format_request_body_validates_response_body() -> None:
4752
"stream": False,
4853
}
4954
)
55+
56+
57+
def test_raise_client_error_detects_invalid_model() -> None:
58+
"""Test model validation errors are separated from endpoint failures."""
59+
response = httpx.Response(
60+
400,
61+
json={"error": {"message": "Unknown model", "param": "model"}},
62+
request=httpx.Request("POST", "https://example.local/v1/responses"),
63+
)
64+
65+
with pytest.raises(OpenResponsesInvalidModelError):
66+
_raise_client_error(
67+
httpx.HTTPStatusError(
68+
"bad request", request=response.request, response=response
69+
)
70+
)

0 commit comments

Comments
 (0)