11"""Tests for the pi_hole component."""
22
3+ from collections .abc import Generator
4+ from contextlib import ExitStack , contextmanager
35from typing import Any
46from unittest .mock import AsyncMock , MagicMock , patch
57
8+ from aiohttp import DummyCookieJar
69from hole .exceptions import HoleConnectionError , HoleError
710
811from homeassistant .components .pi_hole .const import (
140143LOCATION = "location"
141144NAME = "Pi hole"
142145API_KEY = "apikey"
146+ APP_PASSWORD = "app_password"
147+ VALID_V6_PASSWORDS = ("newkey" , "apikey" , APP_PASSWORD )
143148API_VERSION = 6
144149SSL = False
145150VERIFY_SSL = True
@@ -206,6 +211,7 @@ def _create_mocked_hole(
206211 incorrect_app_password : bool = False ,
207212 wrong_host : bool = False ,
208213 ftl_error : bool = False ,
214+ require_cookie_free_app_password : bool = False ,
209215) -> MagicMock :
210216 """Return a mocked Hole API object with side effects based on constructor args."""
211217
@@ -221,17 +227,22 @@ async def authenticate_side_effect(*_args, **_kwargs):
221227 if wrong_host :
222228 raise HoleConnectionError ("Cannot authenticate with Pi-hole: err" )
223229 password = getattr (mocked_hole , "password" , None )
230+ cookie_jar = getattr (
231+ getattr (mocked_hole , "session" , None ), "cookie_jar" , None
232+ )
224233
225234 if (
226- raise_exception
227- or incorrect_app_password
228- or api_version == 5
229- or (api_version == 6 and password not in ["newkey" , "apikey" ])
235+ require_cookie_free_app_password
236+ and password == APP_PASSWORD
237+ and not isinstance (cookie_jar , DummyCookieJar )
238+ ):
239+ raise HoleError ("Authentication failed: Invalid password" )
240+
241+ if api_version == 6 and (
242+ incorrect_app_password or password not in VALID_V6_PASSWORDS
230243 ):
231- if api_version == 6 and (
232- incorrect_app_password or password not in ["newkey" , "apikey" ]
233- ):
234- raise HoleError ("Authentication failed: Invalid password" )
244+ raise HoleError ("Authentication failed: Invalid password" )
245+ if raise_exception or incorrect_app_password or api_version == 5 :
235246 raise HoleConnectionError
236247
237248 async def get_data_side_effect (* _args , ** _kwargs ):
@@ -244,10 +255,10 @@ async def get_data_side_effect(*_args, **_kwargs):
244255 raise_exception
245256 or incorrect_app_password
246257 or (api_version == 5 and (not api_token or api_token == "wrong_token" ))
247- or (api_version == 6 and password not in [ "newkey" , "apikey" ] )
258+ or (api_version == 6 and password not in VALID_V6_PASSWORDS )
248259 ):
249260 mocked_hole .data = [] if api_version == 5 else {}
250- elif password in [ "newkey" , "apikey" ] or api_token in [ "newkey" , "apikey" ] :
261+ elif password in VALID_V6_PASSWORDS or api_token in ( "newkey" , "apikey" ) :
251262 mocked_hole .data = ZERO_DATA_V6 if api_version == 6 else ZERO_DATA
252263
253264 async def ftl_side_effect ():
@@ -256,10 +267,8 @@ async def ftl_side_effect():
256267 mocked_hole .authenticate = AsyncMock (side_effect = authenticate_side_effect )
257268 mocked_hole .get_data = AsyncMock (side_effect = get_data_side_effect )
258269
259- if ftl_error :
260- # two unauthenticated instances are created in `determine_api_version` before aync_try_connect is called
261- if len (instances ) > 1 :
262- mocked_hole .get_data = AsyncMock (side_effect = ftl_side_effect )
270+ if ftl_error and instances :
271+ mocked_hole .get_data = AsyncMock (side_effect = ftl_side_effect )
263272 mocked_hole .get_versions = AsyncMock (return_value = None )
264273 mocked_hole .enable = AsyncMock ()
265274 mocked_hole .disable = AsyncMock ()
@@ -293,28 +302,45 @@ async def ftl_side_effect():
293302 return mocked_hole
294303
295304 # Return a factory function for patching
305+ make_mock .api_version = api_version
296306 make_mock .instances = instances
307+ make_mock .wrong_host = wrong_host
297308 return make_mock
298309
299310
300- def _patch_init_hole (mocked_hole ):
301- """Patch the Hole class in the main integration."""
311+ @contextmanager
312+ def _patch_hole (mocked_hole : MagicMock , patch_target : str ) -> Generator [MagicMock ]:
313+ """Patch the Hole class and API version detection."""
302314
303315 def side_effect (* args , ** kwargs ):
304316 return mocked_hole (** kwargs )
305317
306- return patch ("homeassistant.components.pi_hole.Hole" , side_effect = side_effect )
318+ async def is_v6_api_side_effect (* _args , ** _kwargs ) -> bool :
319+ if mocked_hole .wrong_host :
320+ raise HoleConnectionError ("Cannot fetch data from Pi-hole: err" )
321+ return mocked_hole .api_version == 6
322+
323+ with ExitStack () as stack :
324+ patched_hole = stack .enter_context (patch (patch_target , side_effect = side_effect ))
325+ stack .enter_context (
326+ patch (
327+ "homeassistant.components.pi_hole._async_is_v6_api" ,
328+ side_effect = is_v6_api_side_effect ,
329+ )
330+ )
331+ yield patched_hole
332+
333+
334+ def _patch_init_hole (mocked_hole ):
335+ """Patch the Hole class in the main integration."""
336+
337+ return _patch_hole (mocked_hole , "homeassistant.components.pi_hole.Hole" )
307338
308339
309340def _patch_config_flow_hole (mocked_hole ):
310341 """Patch the Hole class in the config flow."""
311342
312- def side_effect (* args , ** kwargs ):
313- return mocked_hole (** kwargs )
314-
315- return patch (
316- "homeassistant.components.pi_hole.config_flow.Hole" , side_effect = side_effect
317- )
343+ return _patch_hole (mocked_hole , "homeassistant.components.pi_hole.config_flow.Hole" )
318344
319345
320346def _patch_setup_hole ():
0 commit comments