@@ -19,47 +19,135 @@ class TestRedis:
1919
2020 @pytest .mark .asyncio
2121 async def test_redis__ok (self ):
22- """Ping Redis successfully when using client parameter."""
22+ """Ping Redis successfully when using client_factory parameter."""
2323 mock_client = mock .AsyncMock ()
2424 mock_client .ping .return_value = True
2525
26- check = RedisHealthCheck (client = mock_client )
26+ check = RedisHealthCheck (client_factory = lambda : mock_client )
2727 result = await check .get_result ()
2828 assert result .error is None
2929 mock_client .ping .assert_called_once ()
30+ mock_client .aclose .assert_called_once ()
3031
3132 @pytest .mark .asyncio
3233 async def test_redis__connection_refused (self ):
3334 """Raise ServiceUnavailable when connection is refused."""
3435 mock_client = mock .AsyncMock ()
3536 mock_client .ping .side_effect = ConnectionRefusedError ("refused" )
3637
37- check = RedisHealthCheck (client = mock_client )
38+ check = RedisHealthCheck (client_factory = lambda : mock_client )
3839 result = await check .get_result ()
3940 assert result .error is not None
4041 assert isinstance (result .error , ServiceUnavailable )
42+ mock_client .aclose .assert_called_once ()
4143
4244 @pytest .mark .asyncio
4345 async def test_redis__timeout (self ):
4446 """Raise ServiceUnavailable when connection times out."""
4547 mock_client = mock .AsyncMock ()
4648 mock_client .ping .side_effect = RedisTimeoutError ("timeout" )
4749
48- check = RedisHealthCheck (client = mock_client )
50+ check = RedisHealthCheck (client_factory = lambda : mock_client )
4951 result = await check .get_result ()
5052 assert result .error is not None
5153 assert isinstance (result .error , ServiceUnavailable )
54+ mock_client .aclose .assert_called_once ()
5255
5356 @pytest .mark .asyncio
5457 async def test_redis__connection_error (self ):
5558 """Raise ServiceUnavailable when connection fails."""
5659 mock_client = mock .AsyncMock ()
5760 mock_client .ping .side_effect = RedisConnectionError ("connection error" )
5861
59- check = RedisHealthCheck (client = mock_client )
62+ check = RedisHealthCheck (client_factory = lambda : mock_client )
6063 result = await check .get_result ()
6164 assert result .error is not None
6265 assert isinstance (result .error , ServiceUnavailable )
66+ mock_client .aclose .assert_called_once ()
67+
68+ @pytest .mark .asyncio
69+ async def test_redis__client_deprecated (self ):
70+ """Verify DeprecationWarning is raised when using client parameter."""
71+ mock_client = mock .AsyncMock ()
72+ mock_client .ping .return_value = True
73+
74+ with pytest .warns (
75+ DeprecationWarning , match = "client.*deprecated.*client_factory"
76+ ):
77+ check = RedisHealthCheck (client = mock_client )
78+
79+ result = await check .get_result ()
80+ assert result .error is None
81+ mock_client .ping .assert_called_once ()
82+ # User-provided client should NOT be closed by the health check
83+ mock_client .aclose .assert_not_called ()
84+
85+ @pytest .mark .asyncio
86+ async def test_redis__factory_called_for_each_result (self ):
87+ """Verify client_factory is called per result and each client is closed."""
88+ call_count = 0
89+ created_clients = []
90+
91+ def factory ():
92+ nonlocal call_count , created_clients
93+ call_count += 1
94+ client = mock .AsyncMock ()
95+ client .ping .return_value = True
96+ created_clients .append (client )
97+ return client
98+
99+ check = RedisHealthCheck (client_factory = factory )
100+ # Factory should not be called eagerly during initialization
101+ assert call_count == 0 , "Factory should not be called during initialization"
102+
103+ # Each request should use a newly created client
104+ result1 = await check .get_result ()
105+ assert result1 .error is None
106+ assert call_count == 1 , "Factory should be called once for first request"
107+
108+ result2 = await check .get_result ()
109+ assert result2 .error is None
110+ assert call_count == 2 , "Factory should be called again for second request"
111+
112+ # Ensure a distinct client was created and closed for each result
113+ assert len (created_clients ) == 2
114+ assert created_clients [0 ] is not created_clients [1 ], (
115+ "Each request should create a distinct client"
116+ )
117+ created_clients [0 ].aclose .assert_called_once ()
118+ created_clients [1 ].aclose .assert_called_once ()
119+
120+ @pytest .mark .asyncio
121+ async def test_redis__client_not_closed_when_user_provided (self ):
122+ """Verify user-provided client is NOT closed by health check."""
123+ mock_client = mock .AsyncMock ()
124+ mock_client .ping .return_value = True
125+
126+ with pytest .warns (DeprecationWarning ):
127+ check = RedisHealthCheck (client = mock_client )
128+
129+ result = await check .get_result ()
130+ assert result .error is None
131+ mock_client .ping .assert_called_once ()
132+ # User is responsible for closing their own client
133+ mock_client .aclose .assert_not_called ()
134+
135+ @pytest .mark .asyncio
136+ async def test_redis__validation_both_params (self ):
137+ """Verify error when both client and client_factory are provided."""
138+ mock_client = mock .AsyncMock ()
139+ with pytest .raises (
140+ ValueError , match = "Provide exactly one of `client` or `client_factory`"
141+ ):
142+ RedisHealthCheck (client = mock_client , client_factory = lambda : mock_client )
143+
144+ @pytest .mark .asyncio
145+ async def test_redis__validation_neither_param (self ):
146+ """Verify error when neither client nor client_factory is provided."""
147+ with pytest .raises (
148+ ValueError , match = "You must provide either `client`.*or `client_factory`"
149+ ):
150+ RedisHealthCheck ()
63151
64152 @pytest .mark .integration
65153 @pytest .mark .asyncio
@@ -71,11 +159,9 @@ async def test_redis__real_connection(self):
71159
72160 from redis .asyncio import Redis as RedisClient
73161
74- client = RedisClient .from_url (redis_url )
75- check = RedisHealthCheck (client = client )
162+ check = RedisHealthCheck (client_factory = lambda : RedisClient .from_url (redis_url ))
76163 result = await check .get_result ()
77164 assert result .error is None
78- await client .aclose ()
79165
80166 @pytest .mark .integration
81167 @pytest .mark .asyncio
@@ -97,11 +183,11 @@ async def test_redis__real_sentinel(self):
97183 host , port = node .strip ().split (":" )
98184 sentinels .append ((host , int (port )))
99185
100- # Create Sentinel and get master client
101- sentinel = Sentinel (sentinels )
102- master = sentinel .master_for (service_name )
186+ # Create factory that returns Sentinel master client
187+ def factory ():
188+ sentinel = Sentinel (sentinels )
189+ return sentinel .master_for (service_name )
103190
104- # Use the unified Redis check with the master client
105- check = RedisHealthCheck (client = master )
191+ check = RedisHealthCheck (client_factory = factory )
106192 result = await check .get_result ()
107193 assert result .error is None
0 commit comments