1
1
"""The OpenAI Conversation integration."""
2
2
from __future__ import annotations
3
3
4
+ import json
4
5
import logging
5
6
from typing import Literal
6
- import json
7
- import yaml
8
7
9
- from openai import AsyncOpenAI , AsyncAzureOpenAI
8
+ from openai import AsyncAzureOpenAI , AsyncOpenAI
9
+ from openai ._exceptions import AuthenticationError , OpenAIError
10
10
from openai .types .chat .chat_completion import (
11
- Choice ,
12
11
ChatCompletion ,
13
12
ChatCompletionMessage ,
13
+ Choice ,
14
14
)
15
- from openai . _exceptions import OpenAIError , AuthenticationError
15
+ import yaml
16
16
17
17
from homeassistant .components import conversation
18
+ from homeassistant .components .homeassistant .exposed_entities import async_should_expose
18
19
from homeassistant .config_entries import ConfigEntry
19
- from homeassistant .const import CONF_API_KEY , MATCH_ALL , ATTR_NAME
20
+ from homeassistant .const import ATTR_NAME , CONF_API_KEY , MATCH_ALL
20
21
from homeassistant .core import HomeAssistant
21
- from homeassistant .helpers .typing import ConfigType
22
- from homeassistant .util import ulid
23
- from homeassistant .components .homeassistant .exposed_entities import async_should_expose
24
22
from homeassistant .exceptions import (
25
23
ConfigEntryNotReady ,
26
24
HomeAssistantError ,
27
25
TemplateError ,
28
26
)
29
-
30
27
from homeassistant .helpers import (
31
28
config_validation as cv ,
29
+ entity_registry as er ,
32
30
intent ,
33
31
template ,
34
- entity_registry as er ,
35
32
)
33
+ from homeassistant .helpers .typing import ConfigType
34
+ from homeassistant .util import ulid
36
35
37
36
from .const import (
37
+ CONF_API_VERSION ,
38
38
CONF_ATTACH_USERNAME ,
39
+ CONF_BASE_URL ,
39
40
CONF_CHAT_MODEL ,
41
+ CONF_CONTEXT_THRESHOLD ,
42
+ CONF_CONTEXT_TRUNCATE_STRATEGY ,
43
+ CONF_FUNCTIONS ,
44
+ CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
40
45
CONF_MAX_TOKENS ,
46
+ CONF_ORGANIZATION ,
41
47
CONF_PROMPT ,
48
+ CONF_SKIP_AUTHENTICATION ,
42
49
CONF_TEMPERATURE ,
43
50
CONF_TOP_P ,
44
- CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
45
- CONF_FUNCTIONS ,
46
- CONF_BASE_URL ,
47
- CONF_API_VERSION ,
48
- CONF_SKIP_AUTHENTICATION ,
49
51
CONF_USE_TOOLS ,
50
- CONF_CONTEXT_THRESHOLD ,
51
- CONF_CONTEXT_TRUNCATE_STRATEGY ,
52
52
DEFAULT_ATTACH_USERNAME ,
53
53
DEFAULT_CHAT_MODEL ,
54
+ DEFAULT_CONF_FUNCTIONS ,
55
+ DEFAULT_CONTEXT_THRESHOLD ,
56
+ DEFAULT_CONTEXT_TRUNCATE_STRATEGY ,
57
+ DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
54
58
DEFAULT_MAX_TOKENS ,
55
59
DEFAULT_PROMPT ,
60
+ DEFAULT_SKIP_AUTHENTICATION ,
56
61
DEFAULT_TEMPERATURE ,
57
62
DEFAULT_TOP_P ,
58
- DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
59
- DEFAULT_CONF_FUNCTIONS ,
60
- DEFAULT_SKIP_AUTHENTICATION ,
61
63
DEFAULT_USE_TOOLS ,
62
- DEFAULT_CONTEXT_THRESHOLD ,
63
- DEFAULT_CONTEXT_TRUNCATE_STRATEGY ,
64
64
DOMAIN ,
65
+ EVENT_CONVERSATION_FINISHED ,
65
66
)
66
-
67
67
from .exceptions import (
68
- FunctionNotFound ,
69
68
FunctionLoadFailed ,
70
- ParseArgumentsFailed ,
69
+ FunctionNotFound ,
71
70
InvalidFunction ,
71
+ ParseArgumentsFailed ,
72
+ TokenLengthExceededError ,
72
73
)
73
-
74
74
from .helpers import (
75
- validate_authentication ,
76
75
get_function_executor ,
77
76
is_azure ,
77
+ validate_authentication ,
78
78
)
79
-
80
79
from .services import async_setup_services
81
80
82
-
83
81
_LOGGER = logging .getLogger (__name__ )
84
82
85
83
CONFIG_SCHEMA = cv .config_entry_only_config_schema (DOMAIN )
@@ -104,6 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
104
102
api_key = entry .data [CONF_API_KEY ],
105
103
base_url = entry .data .get (CONF_BASE_URL ),
106
104
api_version = entry .data .get (CONF_API_VERSION ),
105
+ organization = entry .data .get (CONF_ORGANIZATION ),
107
106
skip_authentication = entry .data .get (
108
107
CONF_SKIP_AUTHENTICATION , DEFAULT_SKIP_AUTHENTICATION
109
108
),
@@ -145,10 +144,13 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
145
144
api_key = entry .data [CONF_API_KEY ],
146
145
azure_endpoint = base_url ,
147
146
api_version = entry .data .get (CONF_API_VERSION ),
147
+ organization = entry .data .get (CONF_ORGANIZATION ),
148
148
)
149
149
else :
150
150
self .client = AsyncOpenAI (
151
- api_key = entry .data [CONF_API_KEY ], base_url = base_url
151
+ api_key = entry .data [CONF_API_KEY ],
152
+ base_url = base_url ,
153
+ organization = entry .data .get (CONF_ORGANIZATION ),
152
154
)
153
155
154
156
@property
@@ -191,7 +193,7 @@ async def async_process(
191
193
messages .append (user_message )
192
194
193
195
try :
194
- response = await self .query (user_input , messages , exposed_entities , 0 )
196
+ query_response = await self .query (user_input , messages , exposed_entities , 0 )
195
197
except OpenAIError as err :
196
198
_LOGGER .error (err )
197
199
intent_response = intent .IntentResponse (language = user_input .language )
@@ -213,11 +215,20 @@ async def async_process(
213
215
response = intent_response , conversation_id = conversation_id
214
216
)
215
217
216
- messages .append (response .model_dump (exclude_none = True ))
218
+ messages .append (query_response . message .model_dump (exclude_none = True ))
217
219
self .history [conversation_id ] = messages
218
220
221
+ self .hass .bus .async_fire (
222
+ EVENT_CONVERSATION_FINISHED ,
223
+ {
224
+ "response" : query_response .response .model_dump (),
225
+ "user_input" : user_input ,
226
+ "messages" : messages ,
227
+ },
228
+ )
229
+
219
230
intent_response = intent .IntentResponse (language = user_input .language )
220
- intent_response .async_set_speech (response .content )
231
+ intent_response .async_set_speech (query_response . message .content )
221
232
return conversation .ConversationResult (
222
233
response = intent_response , conversation_id = conversation_id
223
234
)
@@ -317,7 +328,7 @@ async def query(
317
328
messages ,
318
329
exposed_entities ,
319
330
n_requests ,
320
- ):
331
+ ) -> OpenAIQueryResponse :
321
332
"""Process a sentence."""
322
333
model = self .entry .options .get (CONF_CHAT_MODEL , DEFAULT_CHAT_MODEL )
323
334
max_tokens = self .entry .options .get (CONF_MAX_TOKENS , DEFAULT_MAX_TOKENS )
@@ -366,14 +377,17 @@ async def query(
366
377
message = choice .message
367
378
368
379
if choice .finish_reason == "function_call" :
369
- message = await self .execute_function_call (
380
+ return await self .execute_function_call (
370
381
user_input , messages , message , exposed_entities , n_requests + 1
371
382
)
372
383
if choice .finish_reason == "tool_calls" :
373
- message = await self .execute_tool_calls (
384
+ return await self .execute_tool_calls (
374
385
user_input , messages , message , exposed_entities , n_requests + 1
375
386
)
376
- return message
387
+ if choice .finish_reason == "length" :
388
+ raise TokenLengthExceededError (response .usage .completion_tokens )
389
+
390
+ return OpenAIQueryResponse (response = response , message = message )
377
391
378
392
async def execute_function_call (
379
393
self ,
@@ -382,7 +396,7 @@ async def execute_function_call(
382
396
message : ChatCompletionMessage ,
383
397
exposed_entities ,
384
398
n_requests ,
385
- ):
399
+ ) -> OpenAIQueryResponse :
386
400
function_name = message .function_call .name
387
401
function = next (
388
402
(s for s in self .get_functions () if s ["spec" ]["name" ] == function_name ),
@@ -407,7 +421,7 @@ async def execute_function(
407
421
exposed_entities ,
408
422
n_requests ,
409
423
function ,
410
- ):
424
+ ) -> OpenAIQueryResponse :
411
425
function_executor = get_function_executor (function ["function" ]["type" ])
412
426
413
427
try :
@@ -435,7 +449,7 @@ async def execute_tool_calls(
435
449
message : ChatCompletionMessage ,
436
450
exposed_entities ,
437
451
n_requests ,
438
- ):
452
+ ) -> OpenAIQueryResponse :
439
453
messages .append (message .model_dump (exclude_none = True ))
440
454
for tool in message .tool_calls :
441
455
function_name = tool .function .name
@@ -469,7 +483,7 @@ async def execute_tool_function(
469
483
tool ,
470
484
exposed_entities ,
471
485
function ,
472
- ):
486
+ ) -> OpenAIQueryResponse :
473
487
function_executor = get_function_executor (function ["function" ]["type" ])
474
488
475
489
try :
@@ -481,3 +495,14 @@ async def execute_tool_function(
481
495
self .hass , function ["function" ], arguments , user_input , exposed_entities
482
496
)
483
497
return result
498
+
499
+
500
+ class OpenAIQueryResponse :
501
+ """OpenAI query response value object."""
502
+
503
+ def __init__ (
504
+ self , response : ChatCompletion , message : ChatCompletionMessage
505
+ ) -> None :
506
+ """Initialize OpenAI query response value object."""
507
+ self .response = response
508
+ self .message = message
0 commit comments