diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index 5bb9b7acc5b43f..d8a743bab4c961 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -1,8 +1,11 @@ """The pi_hole component.""" +import asyncio import logging from typing import Any, Literal +import aiohttp +from aiohttp import DummyCookieJar from hole import Hole, HoleV5, HoleV6 from hole.exceptions import HoleConnectionError, HoleError @@ -13,17 +16,23 @@ CONF_LOCATION, CONF_SSL, CONF_VERIFY_SSL, + EVENT_HOMEASSISTANT_CLOSE, Platform, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import entity_registry as er -from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.aiohttp_client import ( + async_create_clientsession, + async_get_clientsession, +) from .const import CONF_STATISTICS_ONLY, DOMAIN from .coordinator import PiHoleConfigEntry, PiHoleData, PiHoleUpdateCoordinator _LOGGER = logging.getLogger(__name__) +DATA_V6_CLIENTSESSIONS = f"{DOMAIN}_v6_clientsessions" + PLATFORMS = [ Platform.BINARY_SENSOR, @@ -110,7 +119,10 @@ def api_by_version( if password is None: password = entry.get(CONF_API_KEY, "") - session = async_get_clientsession(hass, entry[CONF_VERIFY_SSL]) + if version == 6: + session = _async_get_v6_session(hass, entry[CONF_VERIFY_SSL]) + else: + session = async_get_clientsession(hass, entry[CONF_VERIFY_SSL]) hole_kwargs = { "host": entry[CONF_HOST], "session": session, @@ -128,45 +140,87 @@ def api_by_version( return Hole(**hole_kwargs) +@callback +def _async_get_v6_session( + hass: HomeAssistant, verify_ssl: bool +) -> aiohttp.ClientSession: + """Get a session with an isolated cookie jar for the Pi-hole v6 API. + + The session opts out of the auto-cleanup tied to the current config entry, + since the cache is shared across entries — otherwise the first entry to + unload would detach a session still in use by the others. Lifetime is + bound to Home Assistant shutdown instead. + """ + sessions: dict[bool, aiohttp.ClientSession] = hass.data.setdefault( + DATA_V6_CLIENTSESSIONS, {} + ) + session = sessions.get(verify_ssl) + if session is None or session.closed: + session = async_create_clientsession( + hass, verify_ssl, auto_cleanup=False, cookie_jar=DummyCookieJar() + ) + sessions[verify_ssl] = session + + @callback + def _close(_event: Event) -> None: + session.detach() + sessions.pop(verify_ssl, None) + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, _close) + return session + + +async def _async_is_v6_api(hass: HomeAssistant, entry: dict[str, Any]) -> bool: + """Check if the Pi-hole instance exposes the v6 API.""" + protocol = "https" if entry.get(CONF_SSL) else "http" + session = _async_get_v6_session(hass, entry[CONF_VERIFY_SSL]) + url = f"{protocol}://{entry[CONF_HOST]}/api/info/version" + + async with asyncio.timeout(5): + async with session.get(url) as response: + try: + data: Any = await response.json() + except aiohttp.ContentTypeError, ValueError: + return False + + if not isinstance(data, dict): + return False + + if response.status == 200: + return isinstance(data.get("version"), dict) + + if response.status == 401: + error = data.get("error") + return isinstance(error, dict) and error.get("key") == "unauthorized" + + return False + + async def determine_api_version( hass: HomeAssistant, entry: dict[str, Any] ) -> Literal[5, 6]: """Determine the API version of the Pi-hole instance without requiring authentication. - Neither API v5 or v6 provides an endpoint to check the version without authentication. - Version 6 provides other enddpoints that do not require authentication, so we can use those to determine the version - version 5 returns an empty list in response to unauthenticated requests. + Version 6 returns either version data or a distinct unauthorized error from + /api/info/version, so we can use that endpoint to determine the version. + Version 5 returns an empty list in response to unauthenticated requests. Because we are using endpoints that are not designed for this purpose, we should log liberally to help with debugging. """ - holeV6 = api_by_version(hass, entry, 6, password="wrong_password") try: - await holeV6.authenticate() - except HoleConnectionError as err: - _LOGGER.error( - "Unexpected error connecting to Pi-hole v6 API at %s: %s. Trying version 5 API", - holeV6.base_url, - err, - ) - # Ideally python-hole would raise a specific exception for authentication failures - except HoleError as ex_v6: - if str(ex_v6) == "Authentication failed: Invalid password": + if await _async_is_v6_api(hass, entry): _LOGGER.debug( - "Success connecting to Pi-hole at %s without auth, API version is : %s", - holeV6.base_url, - 6, + "Response from v6 API without auth, Pi-hole API version 6 probably detected at %s", + entry[CONF_HOST], ) return 6 - _LOGGER.debug( - "Connection to %s failed: %s, trying API version 5", holeV6.base_url, ex_v6 - ) - else: - # It seems that occasionally the auth can succeed unexpectedly when there is a valid session - _LOGGER.warning( - "Authenticated with %s through v6 API, but succeeded with an incorrect password. This is a known bug", - holeV6.base_url, + except (TimeoutError, aiohttp.ClientError) as err: + _LOGGER.error( + "Unexpected error connecting to Pi-hole v6 API at %s: %s. Trying version 5 API", + entry[CONF_HOST], + err, ) - return 6 + holeV5 = api_by_version(hass, entry, 5, password="wrong_token") try: await holeV5.get_data() @@ -190,6 +244,6 @@ async def determine_api_version( ) _LOGGER.debug( "Could not determine pi-hole API version at: %s", - holeV6.base_url, + entry[CONF_HOST], ) raise HoleError("Could not determine Pi-hole API version") diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index c2edb51e066413..67192baf2ed5e5 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -1,8 +1,11 @@ """Tests for the pi_hole component.""" +from collections.abc import Generator +from contextlib import ExitStack, contextmanager from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +from aiohttp import DummyCookieJar from hole.exceptions import HoleConnectionError, HoleError from homeassistant.components.pi_hole.const import ( @@ -140,6 +143,8 @@ LOCATION = "location" NAME = "Pi hole" API_KEY = "apikey" +APP_PASSWORD = "app_password" +VALID_V6_PASSWORDS = ("newkey", "apikey", APP_PASSWORD) API_VERSION = 6 SSL = False VERIFY_SSL = True @@ -206,6 +211,7 @@ def _create_mocked_hole( incorrect_app_password: bool = False, wrong_host: bool = False, ftl_error: bool = False, + require_cookie_free_app_password: bool = False, ) -> MagicMock: """Return a mocked Hole API object with side effects based on constructor args.""" @@ -221,17 +227,22 @@ async def authenticate_side_effect(*_args, **_kwargs): if wrong_host: raise HoleConnectionError("Cannot authenticate with Pi-hole: err") password = getattr(mocked_hole, "password", None) + cookie_jar = getattr( + getattr(mocked_hole, "session", None), "cookie_jar", None + ) if ( - raise_exception - or incorrect_app_password - or api_version == 5 - or (api_version == 6 and password not in ["newkey", "apikey"]) + require_cookie_free_app_password + and password == APP_PASSWORD + and not isinstance(cookie_jar, DummyCookieJar) + ): + raise HoleError("Authentication failed: Invalid password") + + if api_version == 6 and ( + incorrect_app_password or password not in VALID_V6_PASSWORDS ): - if api_version == 6 and ( - incorrect_app_password or password not in ["newkey", "apikey"] - ): - raise HoleError("Authentication failed: Invalid password") + raise HoleError("Authentication failed: Invalid password") + if raise_exception or incorrect_app_password or api_version == 5: raise HoleConnectionError async def get_data_side_effect(*_args, **_kwargs): @@ -244,10 +255,10 @@ async def get_data_side_effect(*_args, **_kwargs): raise_exception or incorrect_app_password or (api_version == 5 and (not api_token or api_token == "wrong_token")) - or (api_version == 6 and password not in ["newkey", "apikey"]) + or (api_version == 6 and password not in VALID_V6_PASSWORDS) ): mocked_hole.data = [] if api_version == 5 else {} - elif password in ["newkey", "apikey"] or api_token in ["newkey", "apikey"]: + elif password in VALID_V6_PASSWORDS or api_token in ("newkey", "apikey"): mocked_hole.data = ZERO_DATA_V6 if api_version == 6 else ZERO_DATA async def ftl_side_effect(): @@ -256,10 +267,8 @@ async def ftl_side_effect(): mocked_hole.authenticate = AsyncMock(side_effect=authenticate_side_effect) mocked_hole.get_data = AsyncMock(side_effect=get_data_side_effect) - if ftl_error: - # two unauthenticated instances are created in `determine_api_version` before aync_try_connect is called - if len(instances) > 1: - mocked_hole.get_data = AsyncMock(side_effect=ftl_side_effect) + if ftl_error and instances: + mocked_hole.get_data = AsyncMock(side_effect=ftl_side_effect) mocked_hole.get_versions = AsyncMock(return_value=None) mocked_hole.enable = AsyncMock() mocked_hole.disable = AsyncMock() @@ -293,28 +302,45 @@ async def ftl_side_effect(): return mocked_hole # Return a factory function for patching + make_mock.api_version = api_version make_mock.instances = instances + make_mock.wrong_host = wrong_host return make_mock -def _patch_init_hole(mocked_hole): - """Patch the Hole class in the main integration.""" +@contextmanager +def _patch_hole(mocked_hole: MagicMock, patch_target: str) -> Generator[MagicMock]: + """Patch the Hole class and API version detection.""" def side_effect(*args, **kwargs): return mocked_hole(**kwargs) - return patch("homeassistant.components.pi_hole.Hole", side_effect=side_effect) + async def is_v6_api_side_effect(*_args, **_kwargs) -> bool: + if mocked_hole.wrong_host: + raise HoleConnectionError("Cannot fetch data from Pi-hole: err") + return mocked_hole.api_version == 6 + + with ExitStack() as stack: + patched_hole = stack.enter_context(patch(patch_target, side_effect=side_effect)) + stack.enter_context( + patch( + "homeassistant.components.pi_hole._async_is_v6_api", + side_effect=is_v6_api_side_effect, + ) + ) + yield patched_hole + + +def _patch_init_hole(mocked_hole): + """Patch the Hole class in the main integration.""" + + return _patch_hole(mocked_hole, "homeassistant.components.pi_hole.Hole") def _patch_config_flow_hole(mocked_hole): """Patch the Hole class in the config flow.""" - def side_effect(*args, **kwargs): - return mocked_hole(**kwargs) - - return patch( - "homeassistant.components.pi_hole.config_flow.Hole", side_effect=side_effect - ) + return _patch_hole(mocked_hole, "homeassistant.components.pi_hole.config_flow.Hole") def _patch_setup_hole(): diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py index e79f65b406e199..a554ad57023cbc 100644 --- a/tests/components/pi_hole/test_config_flow.py +++ b/tests/components/pi_hole/test_config_flow.py @@ -8,6 +8,7 @@ from homeassistant.data_entry_flow import FlowResultType from . import ( + APP_PASSWORD, CONFIG_DATA_DEFAULTS, CONFIG_ENTRY_WITH_API_KEY, CONFIG_FLOW_USER, @@ -67,6 +68,35 @@ async def test_flow_user_with_api_key_v6(hass: HomeAssistant) -> None: assert result["reason"] == "already_configured" +async def test_flow_user_with_app_password_v6(hass: HomeAssistant) -> None: + """Test user initialized flow with a v6 app password.""" + mocked_hole = _create_mocked_hole( + has_data=False, api_version=6, require_cookie_free_app_password=True + ) + app_password_input = {**CONFIG_FLOW_USER, CONF_API_KEY: APP_PASSWORD} + with ( + _patch_init_hole(mocked_hole), + _patch_config_flow_hole(mocked_hole), + _patch_setup_hole() as mock_setup, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + ) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=app_password_input, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == { + **CONFIG_ENTRY_WITH_API_KEY, + CONF_API_KEY: APP_PASSWORD, + } + mock_setup.assert_called_once() + + async def test_flow_user_with_api_key_v5(hass: HomeAssistant) -> None: """Test user initialized flow with api key needed.""" mocked_hole = _create_mocked_hole(api_version=5) diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index e45e9b2997e896..c0478f213d0f38 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -35,6 +35,56 @@ ) from tests.common import MockConfigEntry +from tests.test_util.aiohttp import AiohttpClientMocker + + +async def test_determine_api_version_v6( + hass: HomeAssistant, aioclient_mock: AiohttpClientMocker +) -> None: + """Test detecting a Pi-hole v6 API without authentication.""" + aioclient_mock.get( + "http://1.2.3.4:80/api/info/version", + status=401, + json={"error": {"key": "unauthorized", "message": "Unauthorized"}}, + ) + + assert await pi_hole.determine_api_version(hass, CONFIG_DATA_DEFAULTS) == 6 + + +async def test_determine_api_version_v6_without_authentication( + hass: HomeAssistant, aioclient_mock: AiohttpClientMocker +) -> None: + """Test detecting a Pi-hole v6 API with authentication disabled.""" + aioclient_mock.get( + "http://1.2.3.4:80/api/info/version", + status=200, + json={"version": {"core": {"local": {"version": "v6.0.0"}}}}, + ) + + assert await pi_hole.determine_api_version(hass, CONFIG_DATA_DEFAULTS) == 6 + + +@pytest.mark.parametrize( + ("status", "response_json"), + [ + (200, []), + (500, {"error": {"key": "internal_error"}}), + ], +) +async def test_is_v6_api_with_unexpected_response( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + status: int, + response_json: object, +) -> None: + """Test ignoring unexpected responses from the Pi-hole v6 version API.""" + aioclient_mock.get( + "http://1.2.3.4:80/api/info/version", + status=status, + json=response_json, + ) + + assert not await pi_hole._async_is_v6_api(hass, CONFIG_DATA_DEFAULTS) @pytest.mark.parametrize(