46
46
CONF_BASE_URL ,
47
47
CONF_API_VERSION ,
48
48
CONF_SKIP_AUTHENTICATION ,
49
+ CONF_USE_TOOLS ,
50
+ CONF_CONTEXT_THRESHOLD ,
51
+ CONF_CONTEXT_TRUNCATE_STRATEGY ,
49
52
DEFAULT_ATTACH_USERNAME ,
50
53
DEFAULT_CHAT_MODEL ,
51
54
DEFAULT_MAX_TOKENS ,
55
58
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
56
59
DEFAULT_CONF_FUNCTIONS ,
57
60
DEFAULT_SKIP_AUTHENTICATION ,
61
+ DEFAULT_USE_TOOLS ,
62
+ DEFAULT_CONTEXT_THRESHOLD ,
63
+ DEFAULT_CONTEXT_TRUNCATE_STRATEGY ,
58
64
DOMAIN ,
59
65
)
60
66
@@ -153,7 +159,6 @@ def supported_languages(self) -> list[str] | Literal["*"]:
153
159
async def async_process (
154
160
self , user_input : conversation .ConversationInput
155
161
) -> conversation .ConversationResult :
156
- raw_prompt = self .entry .options .get (CONF_PROMPT , DEFAULT_PROMPT )
157
162
exposed_entities = self .get_exposed_entities ()
158
163
159
164
if user_input .conversation_id in self .history :
@@ -163,7 +168,9 @@ async def async_process(
163
168
conversation_id = ulid .ulid ()
164
169
user_input .conversation_id = conversation_id
165
170
try :
166
- prompt = self ._async_generate_prompt (raw_prompt , exposed_entities )
171
+ system_message = self ._generate_system_message (
172
+ exposed_entities , user_input
173
+ )
167
174
except TemplateError as err :
168
175
_LOGGER .error ("Error rendering prompt: %s" , err )
169
176
intent_response = intent .IntentResponse (language = user_input .language )
@@ -174,7 +181,7 @@ async def async_process(
174
181
return conversation .ConversationResult (
175
182
response = intent_response , conversation_id = conversation_id
176
183
)
177
- messages = [{ "role" : "system" , "content" : prompt } ]
184
+ messages = [system_message ]
178
185
user_message = {"role" : "user" , "content" : user_input .text }
179
186
if self .entry .options .get (CONF_ATTACH_USERNAME , DEFAULT_ATTACH_USERNAME ):
180
187
user = await self .hass .auth .async_get_user (user_input .context .user_id )
@@ -215,12 +222,25 @@ async def async_process(
215
222
response = intent_response , conversation_id = conversation_id
216
223
)
217
224
218
- def _async_generate_prompt (self , raw_prompt : str , exposed_entities ) -> str :
225
+ def _generate_system_message (
226
+ self , exposed_entities , user_input : conversation .ConversationInput
227
+ ):
228
+ raw_prompt = self .entry .options .get (CONF_PROMPT , DEFAULT_PROMPT )
229
+ prompt = self ._async_generate_prompt (raw_prompt , exposed_entities , user_input )
230
+ return {"role" : "system" , "content" : prompt }
231
+
232
+ def _async_generate_prompt (
233
+ self ,
234
+ raw_prompt : str ,
235
+ exposed_entities ,
236
+ user_input : conversation .ConversationInput ,
237
+ ) -> str :
219
238
"""Generate a prompt for the user."""
220
239
return template .Template (raw_prompt , self .hass ).async_render (
221
240
{
222
241
"ha_name" : self .hass .config .location_name ,
223
242
"exposed_entities" : exposed_entities ,
243
+ "current_device_id" : user_input .device_id ,
224
244
},
225
245
parse_result = False ,
226
246
)
@@ -269,6 +289,28 @@ def get_functions(self):
269
289
except :
270
290
raise FunctionLoadFailed ()
271
291
292
+ async def truncate_message_history (
293
+ self , messages , exposed_entities , user_input : conversation .ConversationInput
294
+ ):
295
+ """Truncate message history."""
296
+ strategy = self .entry .options .get (
297
+ CONF_CONTEXT_TRUNCATE_STRATEGY , DEFAULT_CONTEXT_TRUNCATE_STRATEGY
298
+ )
299
+
300
+ if strategy == "clear" :
301
+ last_user_message_index = None
302
+ for i in reversed (range (len (messages ))):
303
+ if messages [i ]["role" ] == "user" :
304
+ last_user_message_index = i
305
+ break
306
+
307
+ if last_user_message_index is not None :
308
+ del messages [1 :last_user_message_index ]
309
+ # refresh system prompt when all messages are deleted
310
+ messages [0 ] = self ._generate_system_message (
311
+ exposed_entities , user_input
312
+ )
313
+
272
314
async def query (
273
315
self ,
274
316
user_input : conversation .ConversationInput ,
@@ -281,16 +323,27 @@ async def query(
281
323
max_tokens = self .entry .options .get (CONF_MAX_TOKENS , DEFAULT_MAX_TOKENS )
282
324
top_p = self .entry .options .get (CONF_TOP_P , DEFAULT_TOP_P )
283
325
temperature = self .entry .options .get (CONF_TEMPERATURE , DEFAULT_TEMPERATURE )
326
+ use_tools = self .entry .options .get (CONF_USE_TOOLS , DEFAULT_USE_TOOLS )
327
+ context_threshold = self .entry .options .get (
328
+ CONF_CONTEXT_THRESHOLD , DEFAULT_CONTEXT_THRESHOLD
329
+ )
284
330
functions = list (map (lambda s : s ["spec" ], self .get_functions ()))
285
331
function_call = "auto"
286
332
if n_requests == self .entry .options .get (
287
333
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
288
334
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION ,
289
335
):
290
336
function_call = "none"
337
+
338
+ tool_kwargs = {"functions" : functions , "function_call" : function_call }
339
+ if use_tools :
340
+ tool_kwargs = {
341
+ "tools" : [{"type" : "function" , "function" : func } for func in functions ],
342
+ "tool_choice" : function_call ,
343
+ }
344
+
291
345
if len (functions ) == 0 :
292
- functions = None
293
- function_call = None
346
+ tool_kwargs = {}
294
347
295
348
_LOGGER .info ("Prompt for %s: %s" , model , messages )
296
349
@@ -301,20 +354,28 @@ async def query(
301
354
top_p = top_p ,
302
355
temperature = temperature ,
303
356
user = user_input .conversation_id ,
304
- functions = functions ,
305
- function_call = function_call ,
357
+ ** tool_kwargs ,
306
358
)
307
359
308
360
_LOGGER .info ("Response %s" , response .model_dump (exclude_none = True ))
361
+
362
+ if response .usage .total_tokens > context_threshold :
363
+ await self .truncate_message_history (messages , exposed_entities , user_input )
364
+
309
365
choice : Choice = response .choices [0 ]
310
366
message = choice .message
367
+
311
368
if choice .finish_reason == "function_call" :
312
369
message = await self .execute_function_call (
313
370
user_input , messages , message , exposed_entities , n_requests + 1
314
371
)
372
+ if choice .finish_reason == "tool_calls" :
373
+ message = await self .execute_tool_calls (
374
+ user_input , messages , message , exposed_entities , n_requests + 1
375
+ )
315
376
return message
316
377
317
- def execute_function_call (
378
+ async def execute_function_call (
318
379
self ,
319
380
user_input : conversation .ConversationInput ,
320
381
messages ,
@@ -328,7 +389,7 @@ def execute_function_call(
328
389
None ,
329
390
)
330
391
if function is not None :
331
- return self .execute_function (
392
+ return await self .execute_function (
332
393
user_input ,
333
394
messages ,
334
395
message ,
@@ -366,3 +427,57 @@ async def execute_function(
366
427
}
367
428
)
368
429
return await self .query (user_input , messages , exposed_entities , n_requests )
430
+
431
+ async def execute_tool_calls (
432
+ self ,
433
+ user_input : conversation .ConversationInput ,
434
+ messages ,
435
+ message : ChatCompletionMessage ,
436
+ exposed_entities ,
437
+ n_requests ,
438
+ ):
439
+ messages .append (message .model_dump (exclude_none = True ))
440
+ for tool in message .tool_calls :
441
+ function_name = tool .function .name
442
+ function = next (
443
+ (s for s in self .get_functions () if s ["spec" ]["name" ] == function_name ),
444
+ None ,
445
+ )
446
+ if function is not None :
447
+ result = await self .execute_tool_function (
448
+ user_input ,
449
+ tool ,
450
+ exposed_entities ,
451
+ function ,
452
+ )
453
+
454
+ messages .append (
455
+ {
456
+ "tool_call_id" : tool .id ,
457
+ "role" : "tool" ,
458
+ "name" : function_name ,
459
+ "content" : str (result ),
460
+ }
461
+ )
462
+ else :
463
+ raise FunctionNotFound (function_name )
464
+ return await self .query (user_input , messages , exposed_entities , n_requests )
465
+
466
+ async def execute_tool_function (
467
+ self ,
468
+ user_input : conversation .ConversationInput ,
469
+ tool ,
470
+ exposed_entities ,
471
+ function ,
472
+ ):
473
+ function_executor = get_function_executor (function ["function" ]["type" ])
474
+
475
+ try :
476
+ arguments = json .loads (tool .function .arguments )
477
+ except json .decoder .JSONDecodeError as err :
478
+ raise ParseArgumentsFailed (tool .function .arguments ) from err
479
+
480
+ result = await function_executor .execute (
481
+ self .hass , function ["function" ], arguments , user_input , exposed_entities
482
+ )
483
+ return result
0 commit comments