|
1 | | -"""The OpenAI Conversation integration.""" |
2 | | -from __future__ import annotations |
| 1 | +"""The Azure OpenAI Conversation integration.""" |
3 | 2 |
|
4 | | -from functools import partial |
5 | | -import logging |
6 | | -from typing import Literal |
| 3 | +from __future__ import annotations |
7 | 4 |
|
8 | 5 | import openai |
9 | | -from openai import error |
| 6 | +import voluptuous as vol |
10 | 7 |
|
11 | | -from homeassistant.components import conversation |
12 | 8 | from homeassistant.config_entries import ConfigEntry |
13 | | -from homeassistant.const import CONF_API_KEY, MATCH_ALL |
14 | | -from homeassistant.core import HomeAssistant |
15 | | -from homeassistant.exceptions import ConfigEntryNotReady, TemplateError |
16 | | -from homeassistant.helpers import intent, template |
17 | | -from homeassistant.util import ulid |
18 | | - |
19 | | -from .const import ( |
20 | | - CONF_CHAT_MODEL, |
21 | | - CONF_MAX_TOKENS, |
22 | | - CONF_PROMPT, |
23 | | - CONF_TEMPERATURE, |
24 | | - CONF_TOP_P, |
25 | | - DEFAULT_CHAT_MODEL, |
26 | | - DEFAULT_MAX_TOKENS, |
27 | | - DEFAULT_PROMPT, |
28 | | - DEFAULT_TEMPERATURE, |
29 | | - DEFAULT_TOP_P, |
30 | | - CONF_API_BASE, |
31 | | - CONF_API_VERSION, |
| 9 | +from homeassistant.const import CONF_API_KEY, Platform |
| 10 | +from homeassistant.core import ( |
| 11 | + HomeAssistant, |
| 12 | + ServiceCall, |
| 13 | + ServiceResponse, |
| 14 | + SupportsResponse, |
32 | 15 | ) |
| 16 | +from homeassistant.exceptions import ( |
| 17 | + ConfigEntryNotReady, |
| 18 | + HomeAssistantError, |
| 19 | + ServiceValidationError, |
| 20 | +) |
| 21 | +from homeassistant.helpers import config_validation as cv, selector |
| 22 | +from homeassistant.helpers.httpx_client import get_async_client |
| 23 | +from homeassistant.helpers.typing import ConfigType |
33 | 24 |
|
34 | | -_LOGGER = logging.getLogger(__name__) |
| 25 | +from .const import CONF_API_BASE, CONF_API_VERSION, DOMAIN, LOGGER |
35 | 26 |
|
| 27 | +SERVICE_GENERATE_IMAGE = "generate_image" |
| 28 | +PLATFORMS = (Platform.CONVERSATION,) |
| 29 | +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) |
| 30 | +type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient] |
36 | 31 |
|
37 | | -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: |
38 | | - """Set up OpenAI Conversation from a config entry.""" |
39 | | - openai.api_key = entry.data[CONF_API_KEY] |
40 | | - openai.api_type = "azure" |
41 | | - openai.api_base = entry.data[CONF_API_BASE] |
42 | | - openai.api_version = entry.data[CONF_API_VERSION] |
43 | 32 |
|
44 | | - try: |
45 | | - await hass.async_add_executor_job( |
46 | | - partial(openai.Model.list, request_timeout=10) |
47 | | - ) |
48 | | - except error.AuthenticationError as err: |
49 | | - _LOGGER.error("Invalid API key: %s", err) |
50 | | - return False |
51 | | - except error.OpenAIError as err: |
52 | | - raise ConfigEntryNotReady(err) from err |
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up Azure OpenAI Conversation (registers the generate_image service)."""

    # Input schema for the generate_image service call.
    image_schema = vol.Schema(
        {
            vol.Required("config_entry"): selector.ConfigEntrySelector(
                {
                    "integration": DOMAIN,
                }
            ),
            vol.Required("prompt"): cv.string,
            vol.Optional("size", default="1024x1024"): vol.In(
                ("1024x1024", "1024x1792", "1792x1024")
            ),
            vol.Optional("quality", default="standard"): vol.In(("standard", "hd")),
            vol.Optional("style", default="vivid"): vol.In(("vivid", "natural")),
        }
    )

    async def render_image(call: ServiceCall) -> ServiceResponse:
        """Render an image with dall-e."""
        entry_id = call.data["config_entry"]
        config_entry = hass.config_entries.async_get_entry(entry_id)

        # Guard: the selected entry must exist and belong to this integration.
        if config_entry is None or config_entry.domain != DOMAIN:
            raise ServiceValidationError(
                translation_domain=DOMAIN,
                translation_key="invalid_config_entry",
                translation_placeholders={"config_entry": entry_id},
            )

        client: openai.AsyncAzureOpenAI = config_entry.runtime_data

        try:
            result = await client.images.generate(
                model="dall-e-3",
                prompt=call.data["prompt"],
                size=call.data["size"],
                quality=call.data["quality"],
                style=call.data["style"],
                response_format="url",
                n=1,
            )
        except openai.OpenAIError as err:
            raise HomeAssistantError(f"Error generating image: {err}") from err

        # Return the first (only) image; drop the base64 field, the URL suffices.
        return result.data[0].model_dump(exclude={"b64_json"})

    hass.services.async_register(
        DOMAIN,
        SERVICE_GENERATE_IMAGE,
        render_image,
        schema=image_schema,
        supports_response=SupportsResponse.ONLY,
    )
    return True
63 | 87 |
|
64 | 88 |
|
65 | | -class OpenAIAgent(conversation.AbstractConversationAgent): |
66 | | - """OpenAI conversation agent.""" |
67 | | - |
68 | | - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: |
69 | | - """Initialize the agent.""" |
70 | | - self.hass = hass |
71 | | - self.entry = entry |
72 | | - self.history: dict[str, list[dict]] = {} |
73 | | - |
74 | | - @property |
75 | | - def attribution(self): |
76 | | - """Return the attribution.""" |
77 | | - return { |
78 | | - "name": "Powered by Azure OpenAI", |
79 | | - "url": "https://azure.microsoft.com/products/cognitive-services/openai-service", |
80 | | - } |
81 | | - |
82 | | - @property |
83 | | - def supported_languages(self) -> list[str] | Literal["*"]: |
84 | | - """Return a list of supported languages.""" |
85 | | - return MATCH_ALL |
86 | | - |
87 | | - async def async_process( |
88 | | - self, user_input: conversation.ConversationInput |
89 | | - ) -> conversation.ConversationResult: |
90 | | - """Process a sentence.""" |
91 | | - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) |
92 | | - model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) |
93 | | - max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) |
94 | | - top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) |
95 | | - temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) |
96 | | - |
97 | | - if user_input.conversation_id in self.history: |
98 | | - conversation_id = user_input.conversation_id |
99 | | - messages = self.history[conversation_id] |
100 | | - else: |
101 | | - conversation_id = ulid.ulid() |
102 | | - try: |
103 | | - prompt = self._async_generate_prompt(raw_prompt) |
104 | | - except TemplateError as err: |
105 | | - _LOGGER.error("Error rendering prompt: %s", err) |
106 | | - intent_response = intent.IntentResponse(language=user_input.language) |
107 | | - intent_response.async_set_error( |
108 | | - intent.IntentResponseErrorCode.UNKNOWN, |
109 | | - f"Sorry, I had a problem with my template: {err}", |
110 | | - ) |
111 | | - return conversation.ConversationResult( |
112 | | - response=intent_response, conversation_id=conversation_id |
113 | | - ) |
114 | | - messages = [{"role": "system", "content": prompt}] |
115 | | - |
116 | | - messages.append({"role": "user", "content": user_input.text}) |
117 | | - |
118 | | - _LOGGER.debug("Prompt for %s: %s", model, messages) |
async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
    """Set up Azure OpenAI Conversation from a config entry.

    Builds the async Azure client, validates credentials with a quick
    model-list call, stores the client on ``entry.runtime_data`` and
    forwards setup to the conversation platform.

    Returns False on bad credentials; raises ConfigEntryNotReady on
    transient API errors so Home Assistant retries later.
    """
    client = openai.AsyncAzureOpenAI(
        azure_endpoint=entry.data[CONF_API_BASE],
        api_version=entry.data[CONF_API_VERSION],
        api_key=entry.data[CONF_API_KEY],
        http_client=get_async_client(hass),
    )

    # Cache current platform data which gets added to each request; the
    # library computes it with blocking calls, so do it in the executor.
    _ = await hass.async_add_executor_job(client.platform_headers)

    try:
        # FIX: models.list() on the *async* client returns a coroutine, so
        # dispatching the bound method to the executor never performed the
        # request (and left the coroutine un-awaited). Await it directly.
        await client.with_options(timeout=10.0).models.list()
    except openai.AuthenticationError as err:
        LOGGER.error("Invalid API key: %s", err)
        return False
    except openai.OpenAIError as err:
        raise ConfigEntryNotReady(err) from err

    entry.runtime_data = client

    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

    return True
| 115 | + |
| 116 | + |
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload Azure OpenAI."""
    # (Removed a stray duplicate docstring that was a dead string statement.)
    return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
0 commit comments