Skip to content

Commit 45be8f9

Browse files
authored
Merge pull request #153 from jekalmin/v1.0.3
1.0.3
2 parents f54662d + 690125d commit 45be8f9

File tree

17 files changed

+359
-138
lines changed

17 files changed

+359
-138
lines changed

custom_components/extended_openai_conversation/__init__.py

+67-42
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,83 @@
11
"""The OpenAI Conversation integration."""
22
from __future__ import annotations
33

4+
import json
45
import logging
56
from typing import Literal
6-
import json
7-
import yaml
87

9-
from openai import AsyncOpenAI, AsyncAzureOpenAI
8+
from openai import AsyncAzureOpenAI, AsyncOpenAI
9+
from openai._exceptions import AuthenticationError, OpenAIError
1010
from openai.types.chat.chat_completion import (
11-
Choice,
1211
ChatCompletion,
1312
ChatCompletionMessage,
13+
Choice,
1414
)
15-
from openai._exceptions import OpenAIError, AuthenticationError
15+
import yaml
1616

1717
from homeassistant.components import conversation
18+
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
1819
from homeassistant.config_entries import ConfigEntry
19-
from homeassistant.const import CONF_API_KEY, MATCH_ALL, ATTR_NAME
20+
from homeassistant.const import ATTR_NAME, CONF_API_KEY, MATCH_ALL
2021
from homeassistant.core import HomeAssistant
21-
from homeassistant.helpers.typing import ConfigType
22-
from homeassistant.util import ulid
23-
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
2422
from homeassistant.exceptions import (
2523
ConfigEntryNotReady,
2624
HomeAssistantError,
2725
TemplateError,
2826
)
29-
3027
from homeassistant.helpers import (
3128
config_validation as cv,
29+
entity_registry as er,
3230
intent,
3331
template,
34-
entity_registry as er,
3532
)
33+
from homeassistant.helpers.typing import ConfigType
34+
from homeassistant.util import ulid
3635

3736
from .const import (
37+
CONF_API_VERSION,
3838
CONF_ATTACH_USERNAME,
39+
CONF_BASE_URL,
3940
CONF_CHAT_MODEL,
41+
CONF_CONTEXT_THRESHOLD,
42+
CONF_CONTEXT_TRUNCATE_STRATEGY,
43+
CONF_FUNCTIONS,
44+
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
4045
CONF_MAX_TOKENS,
46+
CONF_ORGANIZATION,
4147
CONF_PROMPT,
48+
CONF_SKIP_AUTHENTICATION,
4249
CONF_TEMPERATURE,
4350
CONF_TOP_P,
44-
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
45-
CONF_FUNCTIONS,
46-
CONF_BASE_URL,
47-
CONF_API_VERSION,
48-
CONF_SKIP_AUTHENTICATION,
4951
CONF_USE_TOOLS,
50-
CONF_CONTEXT_THRESHOLD,
51-
CONF_CONTEXT_TRUNCATE_STRATEGY,
5252
DEFAULT_ATTACH_USERNAME,
5353
DEFAULT_CHAT_MODEL,
54+
DEFAULT_CONF_FUNCTIONS,
55+
DEFAULT_CONTEXT_THRESHOLD,
56+
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
57+
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
5458
DEFAULT_MAX_TOKENS,
5559
DEFAULT_PROMPT,
60+
DEFAULT_SKIP_AUTHENTICATION,
5661
DEFAULT_TEMPERATURE,
5762
DEFAULT_TOP_P,
58-
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
59-
DEFAULT_CONF_FUNCTIONS,
60-
DEFAULT_SKIP_AUTHENTICATION,
6163
DEFAULT_USE_TOOLS,
62-
DEFAULT_CONTEXT_THRESHOLD,
63-
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
6464
DOMAIN,
65+
EVENT_CONVERSATION_FINISHED,
6566
)
66-
6767
from .exceptions import (
68-
FunctionNotFound,
6968
FunctionLoadFailed,
70-
ParseArgumentsFailed,
69+
FunctionNotFound,
7170
InvalidFunction,
71+
ParseArgumentsFailed,
72+
TokenLengthExceededError,
7273
)
73-
7474
from .helpers import (
75-
validate_authentication,
7675
get_function_executor,
7776
is_azure,
77+
validate_authentication,
7878
)
79-
8079
from .services import async_setup_services
8180

82-
8381
_LOGGER = logging.getLogger(__name__)
8482

8583
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@@ -104,6 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
104102
api_key=entry.data[CONF_API_KEY],
105103
base_url=entry.data.get(CONF_BASE_URL),
106104
api_version=entry.data.get(CONF_API_VERSION),
105+
organization=entry.data.get(CONF_ORGANIZATION),
107106
skip_authentication=entry.data.get(
108107
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
109108
),
@@ -145,10 +144,13 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
145144
api_key=entry.data[CONF_API_KEY],
146145
azure_endpoint=base_url,
147146
api_version=entry.data.get(CONF_API_VERSION),
147+
organization=entry.data.get(CONF_ORGANIZATION),
148148
)
149149
else:
150150
self.client = AsyncOpenAI(
151-
api_key=entry.data[CONF_API_KEY], base_url=base_url
151+
api_key=entry.data[CONF_API_KEY],
152+
base_url=base_url,
153+
organization=entry.data.get(CONF_ORGANIZATION),
152154
)
153155

154156
@property
@@ -191,7 +193,7 @@ async def async_process(
191193
messages.append(user_message)
192194

193195
try:
194-
response = await self.query(user_input, messages, exposed_entities, 0)
196+
query_response = await self.query(user_input, messages, exposed_entities, 0)
195197
except OpenAIError as err:
196198
_LOGGER.error(err)
197199
intent_response = intent.IntentResponse(language=user_input.language)
@@ -213,11 +215,20 @@ async def async_process(
213215
response=intent_response, conversation_id=conversation_id
214216
)
215217

216-
messages.append(response.model_dump(exclude_none=True))
218+
messages.append(query_response.message.model_dump(exclude_none=True))
217219
self.history[conversation_id] = messages
218220

221+
self.hass.bus.async_fire(
222+
EVENT_CONVERSATION_FINISHED,
223+
{
224+
"response": query_response.response.model_dump(),
225+
"user_input": user_input,
226+
"messages": messages,
227+
},
228+
)
229+
219230
intent_response = intent.IntentResponse(language=user_input.language)
220-
intent_response.async_set_speech(response.content)
231+
intent_response.async_set_speech(query_response.message.content)
221232
return conversation.ConversationResult(
222233
response=intent_response, conversation_id=conversation_id
223234
)
@@ -317,7 +328,7 @@ async def query(
317328
messages,
318329
exposed_entities,
319330
n_requests,
320-
):
331+
) -> OpenAIQueryResponse:
321332
"""Process a sentence."""
322333
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
323334
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -366,14 +377,17 @@ async def query(
366377
message = choice.message
367378

368379
if choice.finish_reason == "function_call":
369-
message = await self.execute_function_call(
380+
return await self.execute_function_call(
370381
user_input, messages, message, exposed_entities, n_requests + 1
371382
)
372383
if choice.finish_reason == "tool_calls":
373-
message = await self.execute_tool_calls(
384+
return await self.execute_tool_calls(
374385
user_input, messages, message, exposed_entities, n_requests + 1
375386
)
376-
return message
387+
if choice.finish_reason == "length":
388+
raise TokenLengthExceededError(response.usage.completion_tokens)
389+
390+
return OpenAIQueryResponse(response=response, message=message)
377391

378392
async def execute_function_call(
379393
self,
@@ -382,7 +396,7 @@ async def execute_function_call(
382396
message: ChatCompletionMessage,
383397
exposed_entities,
384398
n_requests,
385-
):
399+
) -> OpenAIQueryResponse:
386400
function_name = message.function_call.name
387401
function = next(
388402
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
@@ -407,7 +421,7 @@ async def execute_function(
407421
exposed_entities,
408422
n_requests,
409423
function,
410-
):
424+
) -> OpenAIQueryResponse:
411425
function_executor = get_function_executor(function["function"]["type"])
412426

413427
try:
@@ -435,7 +449,7 @@ async def execute_tool_calls(
435449
message: ChatCompletionMessage,
436450
exposed_entities,
437451
n_requests,
438-
):
452+
) -> OpenAIQueryResponse:
439453
messages.append(message.model_dump(exclude_none=True))
440454
for tool in message.tool_calls:
441455
function_name = tool.function.name
@@ -469,7 +483,7 @@ async def execute_tool_function(
469483
tool,
470484
exposed_entities,
471485
function,
472-
):
486+
) -> OpenAIQueryResponse:
473487
function_executor = get_function_executor(function["function"]["type"])
474488

475489
try:
@@ -481,3 +495,14 @@ async def execute_tool_function(
481495
self.hass, function["function"], arguments, user_input, exposed_entities
482496
)
483497
return result
498+
499+
500+
class OpenAIQueryResponse:
501+
"""OpenAI query response value object."""
502+
503+
def __init__(
504+
self, response: ChatCompletion, message: ChatCompletionMessage
505+
) -> None:
506+
"""Initialize OpenAI query response value object."""
507+
self.response = response
508+
self.message = message

custom_components/extended_openai_conversation/config_flow.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,62 @@
33

44
import logging
55
import types
6-
import yaml
76
from types import MappingProxyType
87
from typing import Any
98

109
from openai._exceptions import APIConnectionError, AuthenticationError
1110
import voluptuous as vol
11+
import yaml
1212

1313
from homeassistant import config_entries
14-
from homeassistant.const import CONF_NAME, CONF_API_KEY
14+
from homeassistant.const import CONF_API_KEY, CONF_NAME
1515
from homeassistant.core import HomeAssistant
1616
from homeassistant.data_entry_flow import FlowResult
1717
from homeassistant.helpers.selector import (
1818
BooleanSelector,
1919
NumberSelector,
2020
NumberSelectorConfig,
21-
TemplateSelector,
21+
SelectOptionDict,
2222
SelectSelector,
2323
SelectSelectorConfig,
24-
SelectOptionDict,
2524
SelectSelectorMode,
25+
TemplateSelector,
2626
)
2727

28-
from .helpers import validate_authentication
29-
3028
from .const import (
29+
CONF_API_VERSION,
3130
CONF_ATTACH_USERNAME,
31+
CONF_BASE_URL,
3232
CONF_CHAT_MODEL,
33+
CONF_CONTEXT_THRESHOLD,
34+
CONF_CONTEXT_TRUNCATE_STRATEGY,
35+
CONF_FUNCTIONS,
36+
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
3337
CONF_MAX_TOKENS,
38+
CONF_ORGANIZATION,
3439
CONF_PROMPT,
40+
CONF_SKIP_AUTHENTICATION,
3541
CONF_TEMPERATURE,
3642
CONF_TOP_P,
37-
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
38-
CONF_FUNCTIONS,
39-
CONF_BASE_URL,
40-
CONF_API_VERSION,
41-
CONF_SKIP_AUTHENTICATION,
4243
CONF_USE_TOOLS,
43-
CONF_CONTEXT_THRESHOLD,
44-
CONF_CONTEXT_TRUNCATE_STRATEGY,
44+
CONTEXT_TRUNCATE_STRATEGIES,
4545
DEFAULT_ATTACH_USERNAME,
4646
DEFAULT_CHAT_MODEL,
47+
DEFAULT_CONF_BASE_URL,
48+
DEFAULT_CONF_FUNCTIONS,
49+
DEFAULT_CONTEXT_THRESHOLD,
50+
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
51+
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
4752
DEFAULT_MAX_TOKENS,
53+
DEFAULT_NAME,
4854
DEFAULT_PROMPT,
55+
DEFAULT_SKIP_AUTHENTICATION,
4956
DEFAULT_TEMPERATURE,
5057
DEFAULT_TOP_P,
51-
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
52-
DEFAULT_CONF_FUNCTIONS,
53-
DEFAULT_CONF_BASE_URL,
54-
DEFAULT_SKIP_AUTHENTICATION,
5558
DEFAULT_USE_TOOLS,
56-
DEFAULT_CONTEXT_THRESHOLD,
57-
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
58-
CONTEXT_TRUNCATE_STRATEGIES,
5959
DOMAIN,
60-
DEFAULT_NAME,
6160
)
61+
from .helpers import validate_authentication
6262

6363
_LOGGER = logging.getLogger(__name__)
6464

@@ -68,6 +68,7 @@
6868
vol.Required(CONF_API_KEY): str,
6969
vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str,
7070
vol.Optional(CONF_API_VERSION): str,
71+
vol.Optional(CONF_ORGANIZATION): str,
7172
vol.Optional(
7273
CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION
7374
): bool,
@@ -101,6 +102,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
101102
api_key = data[CONF_API_KEY]
102103
base_url = data.get(CONF_BASE_URL)
103104
api_version = data.get(CONF_API_VERSION)
105+
organization = data.get(CONF_ORGANIZATION)
104106
skip_authentication = data.get(CONF_SKIP_AUTHENTICATION)
105107

106108
if base_url == DEFAULT_CONF_BASE_URL:
@@ -113,6 +115,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
113115
api_key=api_key,
114116
base_url=base_url,
115117
api_version=api_version,
118+
organization=organization,
116119
skip_authentication=skip_authentication,
117120
)
118121

custom_components/extended_openai_conversation/const.py

+4
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
DOMAIN = "extended_openai_conversation"
44
DEFAULT_NAME = "Extended OpenAI Conversation"
5+
CONF_ORGANIZATION = "organization"
56
CONF_BASE_URL = "base_url"
67
DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1"
78
CONF_API_VERSION = "api_version"
89
CONF_SKIP_AUTHENTICATION = "skip_authentication"
910
DEFAULT_SKIP_AUTHENTICATION = False
1011

1112
EVENT_AUTOMATION_REGISTERED = "automation_registered_via_extended_openai_conversation"
13+
EVENT_CONVERSATION_FINISHED = "extended_openai_conversation.conversation.finished"
1214

1315
CONF_PROMPT = "prompt"
1416
DEFAULT_PROMPT = """I want you to act as smart home manager of Home Assistant.
@@ -93,3 +95,5 @@
9395
DEFAULT_CONTEXT_TRUNCATE_STRATEGY = CONTEXT_TRUNCATE_STRATEGIES[0]["key"]
9496

9597
SERVICE_QUERY_IMAGE = "query_image"
98+
99+
CONF_PAYLOAD_TEMPLATE = "payload_template"

0 commit comments

Comments
 (0)