Skip to content

Commit f54662d

Browse files
authored
Merge pull request #137 from jekalmin/v1.0.2
1.0.2
2 parents 3697240 + 43e465e commit f54662d

File tree

16 files changed

+394
-52
lines changed

16 files changed

+394
-52
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Extended OpenAI Conversation uses OpenAI API's feature of [function calling](htt
1616
Since "gpt-3.5-turbo" model already knows how to call service of Home Assistant in general, you just have to let model know what devices you have by [exposing entities](https://github.com/jekalmin/extended_openai_conversation#preparation)
1717

1818
## Installation
19-
1. Install via HACS or by copying `extended_openai_conversation` folder into `<config directory>/custom_components`
19+
1. Install via registering as a custom repository of HACS or by copying `extended_openai_conversation` folder into `<config directory>/custom_components`
2020
2. Restart Home Assistant
2121
3. Go to Settings > Devices & Services.
2222
4. In the bottom right corner, select the Add Integration button.

custom_components/extended_openai_conversation/__init__.py

+125-10
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
CONF_BASE_URL,
4747
CONF_API_VERSION,
4848
CONF_SKIP_AUTHENTICATION,
49+
CONF_USE_TOOLS,
50+
CONF_CONTEXT_THRESHOLD,
51+
CONF_CONTEXT_TRUNCATE_STRATEGY,
4952
DEFAULT_ATTACH_USERNAME,
5053
DEFAULT_CHAT_MODEL,
5154
DEFAULT_MAX_TOKENS,
@@ -55,6 +58,9 @@
5558
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
5659
DEFAULT_CONF_FUNCTIONS,
5760
DEFAULT_SKIP_AUTHENTICATION,
61+
DEFAULT_USE_TOOLS,
62+
DEFAULT_CONTEXT_THRESHOLD,
63+
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
5864
DOMAIN,
5965
)
6066

@@ -153,7 +159,6 @@ def supported_languages(self) -> list[str] | Literal["*"]:
153159
async def async_process(
154160
self, user_input: conversation.ConversationInput
155161
) -> conversation.ConversationResult:
156-
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
157162
exposed_entities = self.get_exposed_entities()
158163

159164
if user_input.conversation_id in self.history:
@@ -163,7 +168,9 @@ async def async_process(
163168
conversation_id = ulid.ulid()
164169
user_input.conversation_id = conversation_id
165170
try:
166-
prompt = self._async_generate_prompt(raw_prompt, exposed_entities)
171+
system_message = self._generate_system_message(
172+
exposed_entities, user_input
173+
)
167174
except TemplateError as err:
168175
_LOGGER.error("Error rendering prompt: %s", err)
169176
intent_response = intent.IntentResponse(language=user_input.language)
@@ -174,7 +181,7 @@ async def async_process(
174181
return conversation.ConversationResult(
175182
response=intent_response, conversation_id=conversation_id
176183
)
177-
messages = [{"role": "system", "content": prompt}]
184+
messages = [system_message]
178185
user_message = {"role": "user", "content": user_input.text}
179186
if self.entry.options.get(CONF_ATTACH_USERNAME, DEFAULT_ATTACH_USERNAME):
180187
user = await self.hass.auth.async_get_user(user_input.context.user_id)
@@ -215,12 +222,25 @@ async def async_process(
215222
response=intent_response, conversation_id=conversation_id
216223
)
217224

218-
def _async_generate_prompt(self, raw_prompt: str, exposed_entities) -> str:
225+
def _generate_system_message(
226+
self, exposed_entities, user_input: conversation.ConversationInput
227+
):
228+
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
229+
prompt = self._async_generate_prompt(raw_prompt, exposed_entities, user_input)
230+
return {"role": "system", "content": prompt}
231+
232+
def _async_generate_prompt(
233+
self,
234+
raw_prompt: str,
235+
exposed_entities,
236+
user_input: conversation.ConversationInput,
237+
) -> str:
219238
"""Generate a prompt for the user."""
220239
return template.Template(raw_prompt, self.hass).async_render(
221240
{
222241
"ha_name": self.hass.config.location_name,
223242
"exposed_entities": exposed_entities,
243+
"current_device_id": user_input.device_id,
224244
},
225245
parse_result=False,
226246
)
@@ -269,6 +289,28 @@ def get_functions(self):
269289
except:
270290
raise FunctionLoadFailed()
271291

292+
async def truncate_message_history(
293+
self, messages, exposed_entities, user_input: conversation.ConversationInput
294+
):
295+
"""Truncate message history."""
296+
strategy = self.entry.options.get(
297+
CONF_CONTEXT_TRUNCATE_STRATEGY, DEFAULT_CONTEXT_TRUNCATE_STRATEGY
298+
)
299+
300+
if strategy == "clear":
301+
last_user_message_index = None
302+
for i in reversed(range(len(messages))):
303+
if messages[i]["role"] == "user":
304+
last_user_message_index = i
305+
break
306+
307+
if last_user_message_index is not None:
308+
del messages[1:last_user_message_index]
309+
# refresh system prompt when all messages are deleted
310+
messages[0] = self._generate_system_message(
311+
exposed_entities, user_input
312+
)
313+
272314
async def query(
273315
self,
274316
user_input: conversation.ConversationInput,
@@ -281,16 +323,27 @@ async def query(
281323
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
282324
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
283325
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
326+
use_tools = self.entry.options.get(CONF_USE_TOOLS, DEFAULT_USE_TOOLS)
327+
context_threshold = self.entry.options.get(
328+
CONF_CONTEXT_THRESHOLD, DEFAULT_CONTEXT_THRESHOLD
329+
)
284330
functions = list(map(lambda s: s["spec"], self.get_functions()))
285331
function_call = "auto"
286332
if n_requests == self.entry.options.get(
287333
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
288334
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
289335
):
290336
function_call = "none"
337+
338+
tool_kwargs = {"functions": functions, "function_call": function_call}
339+
if use_tools:
340+
tool_kwargs = {
341+
"tools": [{"type": "function", "function": func} for func in functions],
342+
"tool_choice": function_call,
343+
}
344+
291345
if len(functions) == 0:
292-
functions = None
293-
function_call = None
346+
tool_kwargs = {}
294347

295348
_LOGGER.info("Prompt for %s: %s", model, messages)
296349

@@ -301,20 +354,28 @@ async def query(
301354
top_p=top_p,
302355
temperature=temperature,
303356
user=user_input.conversation_id,
304-
functions=functions,
305-
function_call=function_call,
357+
**tool_kwargs,
306358
)
307359

308360
_LOGGER.info("Response %s", response.model_dump(exclude_none=True))
361+
362+
if response.usage.total_tokens > context_threshold:
363+
await self.truncate_message_history(messages, exposed_entities, user_input)
364+
309365
choice: Choice = response.choices[0]
310366
message = choice.message
367+
311368
if choice.finish_reason == "function_call":
312369
message = await self.execute_function_call(
313370
user_input, messages, message, exposed_entities, n_requests + 1
314371
)
372+
if choice.finish_reason == "tool_calls":
373+
message = await self.execute_tool_calls(
374+
user_input, messages, message, exposed_entities, n_requests + 1
375+
)
315376
return message
316377

317-
def execute_function_call(
378+
async def execute_function_call(
318379
self,
319380
user_input: conversation.ConversationInput,
320381
messages,
@@ -328,7 +389,7 @@ def execute_function_call(
328389
None,
329390
)
330391
if function is not None:
331-
return self.execute_function(
392+
return await self.execute_function(
332393
user_input,
333394
messages,
334395
message,
@@ -366,3 +427,57 @@ async def execute_function(
366427
}
367428
)
368429
return await self.query(user_input, messages, exposed_entities, n_requests)
430+
431+
async def execute_tool_calls(
432+
self,
433+
user_input: conversation.ConversationInput,
434+
messages,
435+
message: ChatCompletionMessage,
436+
exposed_entities,
437+
n_requests,
438+
):
439+
messages.append(message.model_dump(exclude_none=True))
440+
for tool in message.tool_calls:
441+
function_name = tool.function.name
442+
function = next(
443+
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
444+
None,
445+
)
446+
if function is not None:
447+
result = await self.execute_tool_function(
448+
user_input,
449+
tool,
450+
exposed_entities,
451+
function,
452+
)
453+
454+
messages.append(
455+
{
456+
"tool_call_id": tool.id,
457+
"role": "tool",
458+
"name": function_name,
459+
"content": str(result),
460+
}
461+
)
462+
else:
463+
raise FunctionNotFound(function_name)
464+
return await self.query(user_input, messages, exposed_entities, n_requests)
465+
466+
async def execute_tool_function(
467+
self,
468+
user_input: conversation.ConversationInput,
469+
tool,
470+
exposed_entities,
471+
function,
472+
):
473+
function_executor = get_function_executor(function["function"]["type"])
474+
475+
try:
476+
arguments = json.loads(tool.function.arguments)
477+
except json.decoder.JSONDecodeError as err:
478+
raise ParseArgumentsFailed(tool.function.arguments) from err
479+
480+
result = await function_executor.execute(
481+
self.hass, function["function"], arguments, user_input, exposed_entities
482+
)
483+
return result

custom_components/extended_openai_conversation/config_flow.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
NumberSelector,
2020
NumberSelectorConfig,
2121
TemplateSelector,
22-
AttributeSelector,
2322
SelectSelector,
2423
SelectSelectorConfig,
2524
SelectOptionDict,
@@ -40,6 +39,9 @@
4039
CONF_BASE_URL,
4140
CONF_API_VERSION,
4241
CONF_SKIP_AUTHENTICATION,
42+
CONF_USE_TOOLS,
43+
CONF_CONTEXT_THRESHOLD,
44+
CONF_CONTEXT_TRUNCATE_STRATEGY,
4345
DEFAULT_ATTACH_USERNAME,
4446
DEFAULT_CHAT_MODEL,
4547
DEFAULT_MAX_TOKENS,
@@ -50,6 +52,10 @@
5052
DEFAULT_CONF_FUNCTIONS,
5153
DEFAULT_CONF_BASE_URL,
5254
DEFAULT_SKIP_AUTHENTICATION,
55+
DEFAULT_USE_TOOLS,
56+
DEFAULT_CONTEXT_THRESHOLD,
57+
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
58+
CONTEXT_TRUNCATE_STRATEGIES,
5359
DOMAIN,
5460
DEFAULT_NAME,
5561
)
@@ -80,6 +86,9 @@
8086
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
8187
CONF_FUNCTIONS: DEFAULT_CONF_FUNCTIONS_STR,
8288
CONF_ATTACH_USERNAME: DEFAULT_ATTACH_USERNAME,
89+
CONF_USE_TOOLS: DEFAULT_USE_TOOLS,
90+
CONF_CONTEXT_THRESHOLD: DEFAULT_CONTEXT_THRESHOLD,
91+
CONF_CONTEXT_TRUNCATE_STRATEGY: DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
8392
}
8493
)
8594

@@ -222,4 +231,29 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
222231
description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
223232
default=DEFAULT_ATTACH_USERNAME,
224233
): BooleanSelector(),
234+
vol.Optional(
235+
CONF_USE_TOOLS,
236+
description={"suggested_value": options.get(CONF_USE_TOOLS)},
237+
default=DEFAULT_USE_TOOLS,
238+
): BooleanSelector(),
239+
vol.Optional(
240+
CONF_CONTEXT_THRESHOLD,
241+
description={"suggested_value": options.get(CONF_CONTEXT_THRESHOLD)},
242+
default=DEFAULT_CONTEXT_THRESHOLD,
243+
): int,
244+
vol.Optional(
245+
CONF_CONTEXT_TRUNCATE_STRATEGY,
246+
description={
247+
"suggested_value": options.get(CONF_CONTEXT_TRUNCATE_STRATEGY)
248+
},
249+
default=DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
250+
): SelectSelector(
251+
SelectSelectorConfig(
252+
options=[
253+
SelectOptionDict(value=strategy["key"], label=strategy["label"])
254+
for strategy in CONTEXT_TRUNCATE_STRATEGIES
255+
],
256+
mode=SelectSelectorMode.DROPDOWN,
257+
)
258+
),
225259
}

custom_components/extended_openai_conversation/const.py

+7
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,12 @@
8484
]
8585
CONF_ATTACH_USERNAME = "attach_username"
8686
DEFAULT_ATTACH_USERNAME = False
87+
CONF_USE_TOOLS = "use_tools"
88+
DEFAULT_USE_TOOLS = False
89+
CONF_CONTEXT_THRESHOLD = "context_threshold"
90+
DEFAULT_CONTEXT_THRESHOLD = 13000
91+
CONTEXT_TRUNCATE_STRATEGIES = [{"key": "clear", "label": "Clear All Messages"}]
92+
CONF_CONTEXT_TRUNCATE_STRATEGY = "context_truncate_strategy"
93+
DEFAULT_CONTEXT_TRUNCATE_STRATEGY = CONTEXT_TRUNCATE_STRATEGIES[0]["key"]
8794

8895
SERVICE_QUERY_IMAGE = "query_image"

0 commit comments

Comments (0)