|
4 | 4 | from copy import deepcopy |
5 | 5 | from typing import Any |
6 | 6 |
|
| 7 | +import pytest |
7 | 8 | import voluptuous as vol |
8 | 9 |
|
9 | 10 | from homeassistant.components import ai_task |
| 11 | +from homeassistant.components.open_responses.client import ( |
| 12 | + OpenResponsesInvalidModelError, |
| 13 | +) |
10 | 14 | from homeassistant.core import HomeAssistant |
| 15 | +from homeassistant.exceptions import HomeAssistantError |
11 | 16 | from homeassistant.helpers import selector |
12 | 17 |
|
13 | 18 | from tests.common import MockConfigEntry |
@@ -91,3 +96,26 @@ async def test_generate_structured_data( |
91 | 96 |
|
92 | 97 | assert result.data == {"characters": ["Mario", "Luigi"]} |
93 | 98 | assert calls[0]["text"]["format"]["type"] == "json_schema" |
| 99 | + |
| 100 | + |
| 101 | +async def test_generate_data_handles_invalid_model( |
| 102 | + hass: HomeAssistant, |
| 103 | + mock_config_entry: MockConfigEntry, |
| 104 | + mock_init_component: None, |
| 105 | +) -> None: |
| 106 | + """Test runtime model errors are surfaced as Home Assistant errors.""" |
| 107 | + |
| 108 | + async def stream_response(**params: Any) -> AsyncGenerator[dict[str, Any]]: |
| 109 | + if params["model"] == "open-responses-model": |
| 110 | + raise OpenResponsesInvalidModelError("missing model") |
| 111 | + yield {} |
| 112 | + |
| 113 | + mock_config_entry.runtime_data.stream_response = stream_response |
| 114 | + |
| 115 | + with pytest.raises(HomeAssistantError, match="Invalid Open Responses model"): |
| 116 | + await ai_task.async_generate_data( |
| 117 | + hass, |
| 118 | + task_name="Test Task", |
| 119 | + entity_id="ai_task.open_responses_ai_task", |
| 120 | + instructions="Generate test data", |
| 121 | + ) |
0 commit comments