Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CONF_AZURE_VERSION,
CONF_AZURE_BASE_URL,
CONF_AZURE_DEPLOYMENT,
CONF_ANTHROPIC_BASE_URL,
CONF_CUSTOM_OPENAI_ENDPOINT,
CONF_RETENTION_TIME,
CONF_MEMORY_PATHS,
Expand Down Expand Up @@ -108,6 +109,8 @@ async def async_setup_entry(hass, entry):
CONF_AZURE_BASE_URL: entry.data.get(CONF_AZURE_BASE_URL),
CONF_AZURE_DEPLOYMENT: entry.data.get(CONF_AZURE_DEPLOYMENT),
CONF_AZURE_VERSION: entry.data.get(CONF_AZURE_VERSION),
# Anthropic specific
CONF_ANTHROPIC_BASE_URL: entry.data.get(CONF_ANTHROPIC_BASE_URL),
# Custom OpenAI specific
CONF_CUSTOM_OPENAI_ENDPOINT: entry.data.get(CONF_CUSTOM_OPENAI_ENDPOINT),
# AWS specific
Expand Down
19 changes: 17 additions & 2 deletions custom_components/llmvision/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CONF_AZURE_VERSION,
CONF_AZURE_BASE_URL,
CONF_AZURE_DEPLOYMENT,
CONF_ANTHROPIC_BASE_URL,
CONF_CUSTOM_OPENAI_ENDPOINT,
CONF_RETENTION_TIME,
CONF_TIMELINE_LANGUAGE,
Expand Down Expand Up @@ -56,6 +57,7 @@
DEFAULT_OPENROUTER_MODEL,
ENDPOINT_OPENWEBUI,
ENDPOINT_AZURE,
ENDPOINT_ANTHROPIC,
ENDPOINT_OPENROUTER,
CONF_CONTEXT_WINDOW,
CONF_THINKING_BUDGET,
Expand Down Expand Up @@ -745,7 +747,11 @@ async def async_step_anthropic(self, user_input=None):
{
vol.Required(CONF_API_KEY): selector(
{"text": {"type": "password"}}
)
),
vol.Optional(
CONF_ANTHROPIC_BASE_URL,
default=ENDPOINT_ANTHROPIC,
): str,
}
),
{"collapsed": False},
Expand Down Expand Up @@ -798,7 +804,12 @@ async def async_step_anthropic(self, user_input=None):
self.init_info = self._get_reconfigure_entry().data
# Re-nest the flat config entry data into sections
suggested = {
"connection_section": {CONF_API_KEY: self.init_info.get(CONF_API_KEY)},
"connection_section": {
CONF_API_KEY: self.init_info.get(CONF_API_KEY),
CONF_ANTHROPIC_BASE_URL: self.init_info.get(
CONF_ANTHROPIC_BASE_URL, ENDPOINT_ANTHROPIC
),
},
"model_section": {
CONF_DEFAULT_MODEL: self.init_info.get(
CONF_DEFAULT_MODEL, DEFAULT_ANTHROPIC_MODEL
Expand All @@ -820,6 +831,10 @@ async def async_step_anthropic(self, user_input=None):
self.hass,
api_key=user_input[CONF_API_KEY],
model=user_input[CONF_DEFAULT_MODEL],
endpoint={
"base_url": user_input.get(CONF_ANTHROPIC_BASE_URL)
or ENDPOINT_ANTHROPIC
},
)
await anthropic.validate()
# add the mode to user_input
Expand Down
3 changes: 3 additions & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
CONF_AZURE_DEPLOYMENT = "azure_deployment"
CONF_AZURE_VERSION = "azure_version"

# Anthropic specific
CONF_ANTHROPIC_BASE_URL = "anthropic_base_url"

# AWS specific
CONF_AWS_ACCESS_KEY_ID = "aws_access_key_id"
CONF_AWS_SECRET_ACCESS_KEY = "aws_secret_access_key"
Expand Down
47 changes: 42 additions & 5 deletions custom_components/llmvision/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CONF_AZURE_BASE_URL,
CONF_AZURE_DEPLOYMENT,
CONF_AZURE_VERSION,
CONF_ANTHROPIC_BASE_URL,
CONF_CUSTOM_OPENAI_ENDPOINT,
CONF_AWS_ACCESS_KEY_ID,
CONF_AWS_SECRET_ACCESS_KEY,
Expand Down Expand Up @@ -980,8 +981,14 @@ def supports_structured_output(self) -> bool:

class Anthropic(Provider):

def __init__(self, hass: HomeAssistant, api_key: str, model: str):
super().__init__(hass, api_key, model)
def __init__(
self,
hass: HomeAssistant,
api_key: str,
model: str,
endpoint={"base_url": ENDPOINT_ANTHROPIC},
):
super().__init__(hass, api_key, model, endpoint=endpoint)

def supports_structured_output(self) -> bool:
"""Return True if provider supports structured output."""
Expand All @@ -994,9 +1001,33 @@ def _generate_headers(self) -> dict:
"anthropic-version": VERSION_ANTHROPIC,
}

def _get_request_url(self) -> str:
"""Resolve the messages endpoint, honoring a custom API base if configured.

Accepts a bare host (``https://proxy.example.com``), a versioned base
(``.../v1``) or a full messages endpoint (``.../v1/messages``) and
normalizes it to the Anthropic ``/v1/messages`` path.
"""
if isinstance(self.endpoint, dict):
url = self.endpoint.get("base_url") or ENDPOINT_ANTHROPIC
else:
url = self.endpoint or ENDPOINT_ANTHROPIC

if not isinstance(url, str) or not url.strip():
return ENDPOINT_ANTHROPIC

normalized_url = url.strip().rstrip("/")
if normalized_url.endswith("/messages"):
return normalized_url
if normalized_url.endswith("/v1"):
return f"{normalized_url}/messages"
return f"{normalized_url}/v1/messages"

async def _make_request(self, data: dict) -> str:
headers = self._generate_headers()
response = await self._post(url=ENDPOINT_ANTHROPIC, headers=headers, data=data)
response = await self._post(
url=self._get_request_url(), headers=headers, data=data
)

# Handle tool use response for structured output
if "content" in response and len(response["content"]) > 0:
Expand Down Expand Up @@ -1141,7 +1172,7 @@ async def validate(self) -> None | ServiceValidationError:
"temperature": 0.5,
}
await self._post(
url=f"https://api.anthropic.com/v1/messages", headers=header, data=payload
url=self._get_request_url(), headers=header, data=payload
)


Expand Down Expand Up @@ -2076,7 +2107,13 @@ def create(

if provider_name == "Anthropic":
return Anthropic(
hass, api_key=cast(str, config.get(CONF_API_KEY) or ""), model=model
hass,
api_key=cast(str, config.get(CONF_API_KEY) or ""),
model=model,
endpoint={
"base_url": config.get(CONF_ANTHROPIC_BASE_URL)
or ENDPOINT_ANTHROPIC
},
)

if provider_name == "Google":
Expand Down
6 changes: 4 additions & 2 deletions custom_components/llmvision/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@
"name": "Connection",
"description": "Anthropic authentication",
"data": {
"api_key": "API key"
"api_key": "API key",
"anthropic_base_url": "API base URL"
},
"data_description": {
"api_key": "Make sure to add credit to your Anthropic account."
"api_key": "Make sure to add credit to your Anthropic account.",
"anthropic_base_url": "Base URL of the Anthropic-compatible API. Leave as the default to use Anthropic directly, or point it at a proxy/gateway (e.g. https://your-proxy.example.com)."
}
},
"model_section": {
Expand Down
6 changes: 4 additions & 2 deletions custom_components/llmvision/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@
"name": "Connection",
"description": "Anthropic authentication",
"data": {
"api_key": "API key"
"api_key": "API key",
"anthropic_base_url": "API base URL"
},
"data_description": {
"api_key": "Make sure to add credit to your Anthropic account."
"api_key": "Make sure to add credit to your Anthropic account.",
"anthropic_base_url": "Base URL of the Anthropic-compatible API. Leave as the default to use Anthropic directly, or point it at a proxy/gateway (e.g. https://your-proxy.example.com)."
}
},
"model_section": {
Expand Down
42 changes: 41 additions & 1 deletion tests/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
CONF_TIMELINE_LANGUAGE,
CONF_TITLE_PROMPT,
CONF_TOP_P,
CONF_ANTHROPIC_BASE_URL,
DEFAULT_OPENAI_MODEL,
ENDPOINT_ANTHROPIC,
ENDPOINT_AZURE,
ENDPOINT_OPENROUTER,
ENDPOINT_OPENWEBUI,
Expand Down Expand Up @@ -390,6 +392,40 @@ async def test_openai_creates_entry_after_successful_validation(self, build_flow
)
provider_instance.validate.assert_awaited_once_with()

@pytest.mark.asyncio
async def test_anthropic_passes_custom_base_url(self, build_flow):
"""A custom Anthropic API base should be forwarded to the provider."""
flow = build_flow(init_info={CONF_PROVIDER: "Anthropic"})
user_input = {
"connection_section": {
CONF_API_KEY: "secret",
CONF_ANTHROPIC_BASE_URL: "https://proxy.example.com",
},
"model_section": {
CONF_DEFAULT_MODEL: "claude-3-7-sonnet-latest",
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
CONF_THINKING_BUDGET: 0,
},
}
provider_instance = Mock(validate=AsyncMock())

with patch(
"custom_components.llmvision.config_flow.Anthropic",
return_value=provider_instance,
) as anthropic_cls:
result = await flow.async_step_anthropic(user_input)

assert result["type"] == "create_entry"
assert result["data"][CONF_ANTHROPIC_BASE_URL] == "https://proxy.example.com"
anthropic_cls.assert_called_once_with(
flow.hass,
api_key="secret",
model="claude-3-7-sonnet-latest",
endpoint={"base_url": "https://proxy.example.com"},
)
provider_instance.validate.assert_awaited_once_with()

@pytest.mark.asyncio
async def test_openai_shows_error_when_validation_fails(self, build_flow):
"""Validation failures should return the handshake form error."""
Expand Down Expand Up @@ -590,7 +626,11 @@ async def test_localai_reconfigure_updates_existing_entry(self, build_flow):
"Anthropic Claude",
lambda flow: (
(flow.hass,),
{"api_key": "secret", "model": "claude-3-7-sonnet-latest"},
{
"api_key": "secret",
"model": "claude-3-7-sonnet-latest",
"endpoint": {"base_url": ENDPOINT_ANTHROPIC},
},
),
"empty_api_key",
),
Expand Down
84 changes: 83 additions & 1 deletion tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@
CONF_AZURE_DEPLOYMENT,
CONF_AZURE_VERSION,
CONF_CUSTOM_OPENAI_ENDPOINT,
CONF_ANTHROPIC_BASE_URL,
CONF_IP_ADDRESS,
CONF_HTTPS,
CONF_PORT,
CONF_THINK,
CONF_THINKING_BUDGET,
ENDPOINT_GROQ,
ENDPOINT_ANTHROPIC,
DEFAULT_OPENAI_MODEL,
DEFAULT_ANTHROPIC_MODEL,
DEFAULT_AZURE_MODEL,
Expand Down Expand Up @@ -780,6 +782,67 @@ def test_generate_headers(self, mock_hass):
assert headers["content-type"] == "application/json"
assert "anthropic-version" in headers

def test_default_endpoint(self, mock_hass):
"""Test Anthropic defaults to the official endpoint."""
with patch("custom_components.llmvision.providers.async_get_clientsession"):
anthropic = Anthropic(mock_hass, "test_api_key", "claude-3")

assert anthropic._get_request_url() == ENDPOINT_ANTHROPIC

@pytest.mark.parametrize(
("base_url", "expected"),
[
# Full messages endpoint is used as-is
(
"https://proxy.example.com/v1/messages",
"https://proxy.example.com/v1/messages",
),
# Trailing slash is normalized away
(
"https://proxy.example.com/v1/messages/",
"https://proxy.example.com/v1/messages",
),
# Versioned base gets /messages appended
(
"https://proxy.example.com/v1",
"https://proxy.example.com/v1/messages",
),
# Bare host gets the full /v1/messages path appended
(
"https://proxy.example.com",
"https://proxy.example.com/v1/messages",
),
(
"https://proxy.example.com/anthropic",
"https://proxy.example.com/anthropic/v1/messages",
),
],
)
def test_get_request_url_custom_base(self, mock_hass, base_url, expected):
"""Custom API bases are normalized to a messages endpoint."""
with patch("custom_components.llmvision.providers.async_get_clientsession"):
anthropic = Anthropic(
mock_hass,
"test_api_key",
"claude-3",
endpoint={"base_url": base_url},
)

assert anthropic._get_request_url() == expected

@pytest.mark.parametrize("base_url", ["", " ", None])
def test_get_request_url_falls_back_to_default(self, mock_hass, base_url):
"""Empty/blank custom bases fall back to the official endpoint."""
with patch("custom_components.llmvision.providers.async_get_clientsession"):
anthropic = Anthropic(
mock_hass,
"test_api_key",
"claude-3",
endpoint={"base_url": base_url},
)

assert anthropic._get_request_url() == ENDPOINT_ANTHROPIC


class TestGoogle:
"""Test Google provider class."""
Expand Down Expand Up @@ -1040,7 +1103,7 @@ def test_create_azure(self, mock_hass):
assert isinstance(provider, AzureOpenAI)

def test_create_anthropic(self, mock_hass):
"""Test ProviderFactory creates Anthropic provider."""
"""Test ProviderFactory creates Anthropic provider with default endpoint."""
config = {CONF_API_KEY: "test_key"}

with patch("custom_components.llmvision.providers.async_get_clientsession"):
Expand All @@ -1049,6 +1112,25 @@ def test_create_anthropic(self, mock_hass):
)

assert isinstance(provider, Anthropic)
assert provider._get_request_url() == ENDPOINT_ANTHROPIC

def test_create_anthropic_custom_base_url(self, mock_hass):
"""Test ProviderFactory passes a custom Anthropic base URL."""
config = {
CONF_API_KEY: "test_key",
CONF_ANTHROPIC_BASE_URL: "https://proxy.example.com",
}

with patch("custom_components.llmvision.providers.async_get_clientsession"):
provider = ProviderFactory.create(
mock_hass, "Anthropic", config, "claude-3"
)

assert isinstance(provider, Anthropic)
assert (
provider._get_request_url()
== "https://proxy.example.com/v1/messages"
)

def test_create_google(self, mock_hass):
"""Test ProviderFactory creates Google provider."""
Expand Down