Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 84 additions & 30 deletions homeassistant/components/pi_hole/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""The pi_hole component."""

import asyncio
import logging
from typing import Any, Literal

import aiohttp
from aiohttp import DummyCookieJar
from hole import Hole, HoleV5, HoleV6
from hole.exceptions import HoleConnectionError, HoleError

Expand All @@ -13,17 +16,23 @@
CONF_LOCATION,
CONF_SSL,
CONF_VERIFY_SSL,
EVENT_HOMEASSISTANT_CLOSE,
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.aiohttp_client import (
async_create_clientsession,
async_get_clientsession,
)

from .const import CONF_STATISTICS_ONLY, DOMAIN
from .coordinator import PiHoleConfigEntry, PiHoleData, PiHoleUpdateCoordinator

_LOGGER = logging.getLogger(__name__)

DATA_V6_CLIENTSESSIONS = f"{DOMAIN}_v6_clientsessions"


PLATFORMS = [
Platform.BINARY_SENSOR,
Expand Down Expand Up @@ -110,7 +119,10 @@ def api_by_version(

if password is None:
password = entry.get(CONF_API_KEY, "")
session = async_get_clientsession(hass, entry[CONF_VERIFY_SSL])
if version == 6:
session = _async_get_v6_session(hass, entry[CONF_VERIFY_SSL])
else:
session = async_get_clientsession(hass, entry[CONF_VERIFY_SSL])
Comment on lines +122 to +125
hole_kwargs = {
"host": entry[CONF_HOST],
"session": session,
Expand All @@ -128,45 +140,87 @@ def api_by_version(
return Hole(**hole_kwargs)


@callback
def _async_get_v6_session(
hass: HomeAssistant, verify_ssl: bool
) -> aiohttp.ClientSession:
"""Get a session with an isolated cookie jar for the Pi-hole v6 API.

The session opts out of the auto-cleanup tied to the current config entry,
since the cache is shared across entries — otherwise the first entry to
unload would detach a session still in use by the others. Lifetime is
bound to Home Assistant shutdown instead.
"""
sessions: dict[bool, aiohttp.ClientSession] = hass.data.setdefault(
DATA_V6_CLIENTSESSIONS, {}
)
session = sessions.get(verify_ssl)
if session is None or session.closed:
session = async_create_clientsession(
hass, verify_ssl, auto_cleanup=False, cookie_jar=DummyCookieJar()
)
sessions[verify_ssl] = session

@callback
def _close(_event: Event) -> None:
session.detach()
sessions.pop(verify_ssl, None)

hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _close)
return session


async def _async_is_v6_api(hass: HomeAssistant, entry: dict[str, Any]) -> bool:
"""Check if the Pi-hole instance exposes the v6 API."""
protocol = "https" if entry.get(CONF_SSL) else "http"
session = _async_get_v6_session(hass, entry[CONF_VERIFY_SSL])
url = f"{protocol}://{entry[CONF_HOST]}/api/info/version"

async with asyncio.timeout(5):
async with session.get(url) as response:
try:
data: Any = await response.json()
except aiohttp.ContentTypeError, ValueError:
return False

if not isinstance(data, dict):
return False

if response.status == 200:
return isinstance(data.get("version"), dict)

if response.status == 401:
error = data.get("error")
return isinstance(error, dict) and error.get("key") == "unauthorized"

return False


async def determine_api_version(
hass: HomeAssistant, entry: dict[str, Any]
) -> Literal[5, 6]:
"""Determine the API version of the Pi-hole instance without requiring authentication.

Neither API v5 or v6 provides an endpoint to check the version without authentication.
Version 6 provides other enddpoints that do not require authentication, so we can use those to determine the version
version 5 returns an empty list in response to unauthenticated requests.
Version 6 returns either version data or a distinct unauthorized error from
/api/info/version, so we can use that endpoint to determine the version.
Version 5 returns an empty list in response to unauthenticated requests.
Comment on lines 202 to +206
Because we are using endpoints that are not designed for this purpose, we should log liberally to help with debugging.
"""

holeV6 = api_by_version(hass, entry, 6, password="wrong_password")
try:
await holeV6.authenticate()
except HoleConnectionError as err:
_LOGGER.error(
"Unexpected error connecting to Pi-hole v6 API at %s: %s. Trying version 5 API",
holeV6.base_url,
err,
)
# Ideally python-hole would raise a specific exception for authentication failures
except HoleError as ex_v6:
if str(ex_v6) == "Authentication failed: Invalid password":
if await _async_is_v6_api(hass, entry):
_LOGGER.debug(
"Success connecting to Pi-hole at %s without auth, API version is : %s",
holeV6.base_url,
6,
"Response from v6 API without auth, Pi-hole API version 6 probably detected at %s",
entry[CONF_HOST],
)
return 6
_LOGGER.debug(
"Connection to %s failed: %s, trying API version 5", holeV6.base_url, ex_v6
)
else:
# It seems that occasionally the auth can succeed unexpectedly when there is a valid session
_LOGGER.warning(
"Authenticated with %s through v6 API, but succeeded with an incorrect password. This is a known bug",
holeV6.base_url,
except (TimeoutError, aiohttp.ClientError) as err:
_LOGGER.error(
"Unexpected error connecting to Pi-hole v6 API at %s: %s. Trying version 5 API",
entry[CONF_HOST],
err,
)
return 6

holeV5 = api_by_version(hass, entry, 5, password="wrong_token")
try:
await holeV5.get_data()
Expand All @@ -190,6 +244,6 @@ async def determine_api_version(
)
_LOGGER.debug(
"Could not determine pi-hole API version at: %s",
holeV6.base_url,
entry[CONF_HOST],
)
raise HoleError("Could not determine Pi-hole API version")
72 changes: 49 additions & 23 deletions tests/components/pi_hole/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Tests for the pi_hole component."""

from collections.abc import Generator
from contextlib import ExitStack, contextmanager
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

from aiohttp import DummyCookieJar
from hole.exceptions import HoleConnectionError, HoleError

from homeassistant.components.pi_hole.const import (
Expand Down Expand Up @@ -140,6 +143,8 @@
LOCATION = "location"
NAME = "Pi hole"
API_KEY = "apikey"
APP_PASSWORD = "app_password"
VALID_V6_PASSWORDS = ("newkey", "apikey", APP_PASSWORD)
API_VERSION = 6
SSL = False
VERIFY_SSL = True
Expand Down Expand Up @@ -206,6 +211,7 @@ def _create_mocked_hole(
incorrect_app_password: bool = False,
wrong_host: bool = False,
ftl_error: bool = False,
require_cookie_free_app_password: bool = False,
) -> MagicMock:
"""Return a mocked Hole API object with side effects based on constructor args."""

Expand All @@ -221,17 +227,22 @@ async def authenticate_side_effect(*_args, **_kwargs):
if wrong_host:
raise HoleConnectionError("Cannot authenticate with Pi-hole: err")
password = getattr(mocked_hole, "password", None)
cookie_jar = getattr(
getattr(mocked_hole, "session", None), "cookie_jar", None
)

if (
raise_exception
or incorrect_app_password
or api_version == 5
or (api_version == 6 and password not in ["newkey", "apikey"])
require_cookie_free_app_password
and password == APP_PASSWORD
and not isinstance(cookie_jar, DummyCookieJar)
):
raise HoleError("Authentication failed: Invalid password")

if api_version == 6 and (
incorrect_app_password or password not in VALID_V6_PASSWORDS
):
if api_version == 6 and (
incorrect_app_password or password not in ["newkey", "apikey"]
):
raise HoleError("Authentication failed: Invalid password")
raise HoleError("Authentication failed: Invalid password")
if raise_exception or incorrect_app_password or api_version == 5:
raise HoleConnectionError

async def get_data_side_effect(*_args, **_kwargs):
Expand All @@ -244,10 +255,10 @@ async def get_data_side_effect(*_args, **_kwargs):
raise_exception
or incorrect_app_password
or (api_version == 5 and (not api_token or api_token == "wrong_token"))
or (api_version == 6 and password not in ["newkey", "apikey"])
or (api_version == 6 and password not in VALID_V6_PASSWORDS)
):
mocked_hole.data = [] if api_version == 5 else {}
elif password in ["newkey", "apikey"] or api_token in ["newkey", "apikey"]:
elif password in VALID_V6_PASSWORDS or api_token in ("newkey", "apikey"):
mocked_hole.data = ZERO_DATA_V6 if api_version == 6 else ZERO_DATA

async def ftl_side_effect():
Expand All @@ -256,10 +267,8 @@ async def ftl_side_effect():
mocked_hole.authenticate = AsyncMock(side_effect=authenticate_side_effect)
mocked_hole.get_data = AsyncMock(side_effect=get_data_side_effect)

if ftl_error:
# two unauthenticated instances are created in `determine_api_version` before aync_try_connect is called
if len(instances) > 1:
mocked_hole.get_data = AsyncMock(side_effect=ftl_side_effect)
if ftl_error and instances:
mocked_hole.get_data = AsyncMock(side_effect=ftl_side_effect)
mocked_hole.get_versions = AsyncMock(return_value=None)
mocked_hole.enable = AsyncMock()
mocked_hole.disable = AsyncMock()
Expand Down Expand Up @@ -293,28 +302,45 @@ async def ftl_side_effect():
return mocked_hole

# Return a factory function for patching
make_mock.api_version = api_version
make_mock.instances = instances
make_mock.wrong_host = wrong_host
return make_mock


def _patch_init_hole(mocked_hole):
"""Patch the Hole class in the main integration."""
@contextmanager
def _patch_hole(mocked_hole: MagicMock, patch_target: str) -> Generator[MagicMock]:
"""Patch the Hole class and API version detection."""

def side_effect(*args, **kwargs):
return mocked_hole(**kwargs)

return patch("homeassistant.components.pi_hole.Hole", side_effect=side_effect)
async def is_v6_api_side_effect(*_args, **_kwargs) -> bool:
if mocked_hole.wrong_host:
raise HoleConnectionError("Cannot fetch data from Pi-hole: err")
return mocked_hole.api_version == 6

with ExitStack() as stack:
patched_hole = stack.enter_context(patch(patch_target, side_effect=side_effect))
stack.enter_context(
patch(
"homeassistant.components.pi_hole._async_is_v6_api",
side_effect=is_v6_api_side_effect,
)
)
yield patched_hole


def _patch_init_hole(mocked_hole):
"""Patch the Hole class in the main integration."""

return _patch_hole(mocked_hole, "homeassistant.components.pi_hole.Hole")


def _patch_config_flow_hole(mocked_hole):
"""Patch the Hole class in the config flow."""

def side_effect(*args, **kwargs):
return mocked_hole(**kwargs)

return patch(
"homeassistant.components.pi_hole.config_flow.Hole", side_effect=side_effect
)
return _patch_hole(mocked_hole, "homeassistant.components.pi_hole.config_flow.Hole")


def _patch_setup_hole():
Expand Down
30 changes: 30 additions & 0 deletions tests/components/pi_hole/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.data_entry_flow import FlowResultType

from . import (
APP_PASSWORD,
CONFIG_DATA_DEFAULTS,
CONFIG_ENTRY_WITH_API_KEY,
CONFIG_FLOW_USER,
Expand Down Expand Up @@ -67,6 +68,35 @@ async def test_flow_user_with_api_key_v6(hass: HomeAssistant) -> None:
assert result["reason"] == "already_configured"


async def test_flow_user_with_app_password_v6(hass: HomeAssistant) -> None:
"""Test user initialized flow with a v6 app password."""
mocked_hole = _create_mocked_hole(
has_data=False, api_version=6, require_cookie_free_app_password=True
)
app_password_input = {**CONFIG_FLOW_USER, CONF_API_KEY: APP_PASSWORD}
with (
_patch_init_hole(mocked_hole),
_patch_config_flow_hole(mocked_hole),
_patch_setup_hole() as mock_setup,
):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
)

result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=app_password_input,
)

assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {
**CONFIG_ENTRY_WITH_API_KEY,
CONF_API_KEY: APP_PASSWORD,
}
mock_setup.assert_called_once()


async def test_flow_user_with_api_key_v5(hass: HomeAssistant) -> None:
"""Test user initialized flow with api key needed."""
mocked_hole = _create_mocked_hole(api_version=5)
Expand Down
Loading
Loading