1919from homeassistant .util import ulid
2020from voluptuous_openapi import convert
2121
22+ from .providers import LLMMessage , create_provider
2223from .const import (
2324 AI_HUB_CHAT_URL ,
2425 CONF_CHAT_MODEL ,
2526 CONF_CHAT_URL ,
27+ CONF_LLM_PROVIDER ,
2628 CONF_CUSTOM_API_KEY ,
2729 CONF_MAX_HISTORY_MESSAGES ,
2830 CONF_MAX_TOKENS ,
@@ -76,6 +78,17 @@ def _get_request_ssl_setting(api_url: str, default_url: str) -> bool | None:
7678 return None
7779
7880
81+ def _get_provider_name (api_url : str , configured_provider : str | None = None ) -> str :
82+ """Select provider implementation from URL."""
83+ if configured_provider in {"openai_compatible" , "anthropic_compatible" }:
84+ return configured_provider
85+
86+ parsed = urlparse (api_url )
87+ if "anthropic" in parsed .path .lower () or parsed .netloc == "api.anthropic.com" :
88+ return "anthropic_compatible"
89+ return "openai_compatible"
90+
91+
7992class _AIHubEntityMixin :
8093 """Mixin class providing common initialization logic for AI Hub entities.
8194
@@ -283,15 +296,6 @@ async def _async_handle_chat_log(
283296 model_name = self .default_model
284297 _LOGGER .warning ("Model name was invalid, using default: %s" , model_name )
285298
286- request_params = {
287- "model" : model_name ,
288- "messages" : messages ,
289- "stream" : True ,
290- }
291-
292- if tools :
293- request_params ["tools" ] = tools
294-
295299 # Validate all message contents before sending
296300 for i , msg in enumerate (messages ):
297301 msg_content = msg .get ("content" )
@@ -311,6 +315,17 @@ async def _async_handle_chat_log(
311315 api_url = AI_HUB_CHAT_URL
312316 _LOGGER .warning ("API URL was invalid, using default: %s" , api_url )
313317
318+ provider_name = _get_provider_name (api_url , options .get (CONF_LLM_PROVIDER ))
319+
320+ request_params = {
321+ "model" : model_name ,
322+ "messages" : messages ,
323+ "stream" : True ,
324+ }
325+
326+ if tools :
327+ request_params ["tools" ] = tools
328+
314329 try :
315330 # Validate API key before making request
316331 if not self ._api_key :
@@ -322,11 +337,46 @@ async def _async_handle_chat_log(
322337 self ._api_key = str (self ._api_key )
323338
324339 _LOGGER .debug (
325- "API Request: model=%s, messages_count=%d" ,
340+ "API Request: provider=%s, model=%s, messages_count=%d" ,
341+ provider_name ,
326342 model_name ,
327343 len (messages )
328344 )
329345
346+ llm_messages = [
347+ LLMMessage (
348+ role = msg ["role" ],
349+ content = msg .get ("content" , "" ),
350+ tool_calls = msg .get ("tool_calls" ),
351+ tool_call_id = msg .get ("tool_call_id" ),
352+ )
353+ for msg in messages
354+ ]
355+
356+ provider = create_provider (
357+ provider_name ,
358+ {
359+ "api_key" : self ._api_key ,
360+ "model" : model_name ,
361+ "base_url" : api_url ,
362+ "temperature" : model_config .get ("temperature" , RECOMMENDED_TEMPERATURE ),
363+ "max_tokens" : model_config .get ("max_tokens" , RECOMMENDED_MAX_TOKENS ),
364+ },
365+ )
366+
367+ if provider_name == "anthropic_compatible" and provider is not None :
368+ response = await provider .complete (llm_messages , tools = tools )
369+ tool_calls = self ._convert_provider_tool_calls (response .tool_calls )
370+ assistant_content = conversation .AssistantContent (
371+ agent_id = self .entity_id ,
372+ content = response .content or None ,
373+ tool_calls = tool_calls or None ,
374+ native = response .raw_response ,
375+ )
376+ async for _ in chat_log .async_add_assistant_content (assistant_content ):
377+ pass
378+ return
379+
330380 # Call AI Hub API with streaming via HTTP
331381 headers = {
332382 "Authorization" : f"Bearer { self ._api_key } " ,
@@ -363,6 +413,37 @@ async def _async_handle_chat_log(
363413 _LOGGER .error ("Error calling AI Hub API: %s" , err )
364414 raise HomeAssistantError (ERROR_GETTING_RESPONSE ) from err
365415
416+ def _convert_provider_tool_calls (
417+ self ,
418+ tool_calls : list [dict [str , Any ]] | None ,
419+ ) -> list [llm .ToolInput ]:
420+ """Convert provider tool calls to Home Assistant ToolInput objects."""
421+ if not tool_calls :
422+ return []
423+
424+ converted : list [llm .ToolInput ] = []
425+ for tool_call in tool_calls :
426+ try :
427+ function_data = tool_call .get ("function" , {})
428+ arguments = function_data .get ("arguments" , {})
429+ if isinstance (arguments , str ):
430+ arguments = json .loads (arguments ) if arguments else {}
431+ if not isinstance (arguments , dict ):
432+ arguments = {"value" : arguments }
433+
434+ tool_id = tool_call .get ("id" ) or ulid .ulid_now ()
435+ converted .append (
436+ llm .ToolInput (
437+ id = tool_id ,
438+ tool_name = function_data .get ("name" , "tool" ),
439+ tool_args = arguments ,
440+ )
441+ )
442+ except Exception as err :
443+ _LOGGER .warning ("Failed to convert provider tool call: %s" , err )
444+
445+ return converted
446+
366447 async def _async_convert_chat_log_to_messages (
367448 self , chat_log : conversation .ChatLog
368449 ) -> list [dict [str , Any ]]:
0 commit comments