Skip to content

Commit f1f5120

Browse files
authored
Fix #643 -- Add client_fractory to Redis check (#651)
Redis clients may leak connections between concurrent requests. To prevent this we need to instanciate a client per request.
1 parent 7d5015a commit f1f5120

File tree

3 files changed

+152
-24
lines changed

3 files changed

+152
-24
lines changed

docs/install.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ urlpatterns = [
5151
),
5252
(
5353
"health_check.contrib.redis.Redis",
54-
{"client": RedisClient.from_url("redis://localhost:6379")},
54+
{
55+
"client_factory": lambda: RedisClient.from_url(
56+
"redis://localhost:6379"
57+
)
58+
},
5559
),
5660
# AWS service status check
5761
(

health_check/contrib/redis.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import dataclasses
44
import logging
5+
import typing
6+
import warnings
57

68
from redis import exceptions
79
from redis.asyncio import Redis as RedisClient
@@ -22,35 +24,69 @@ class Redis(HealthCheck):
2224
including standard Redis, Sentinel, and Cluster clients.
2325
2426
Args:
25-
client: A Redis client instance (Redis, Sentinel master, or Cluster).
26-
Must be an async client from redis.asyncio.
27+
client_factory: A callable that returns an instance of a Redis client.
28+
client: Deprecated, use `client_factory` instead.
2729
2830
Examples:
2931
Using a standard Redis client:
3032
>>> from redis.asyncio import Redis as RedisClient
31-
>>> Redis(client=RedisClient(host='localhost', port=6379))
33+
>>> Redis(client_factory=lambda: RedisClient(host='localhost', port=6379))
3234
3335
Using from_url to create a client:
3436
>>> from redis.asyncio import Redis as RedisClient
35-
>>> Redis(client=RedisClient.from_url('redis://localhost:6379'))
37+
>>> Redis(client_factory=lambda: RedisClient.from_url('redis://localhost:6379'))
3638
3739
Using a Cluster client:
3840
>>> from redis.asyncio import RedisCluster
39-
>>> Redis(client=RedisCluster(host='localhost', port=7000))
41+
>>> Redis(client_factory=lambda: RedisCluster(host='localhost', port=7000))
4042
4143
Using a Sentinel client:
4244
>>> from redis.asyncio import Sentinel
43-
>>> sentinel = Sentinel([('localhost', 26379)])
44-
>>> Redis(client=sentinel.master_for('mymaster'))
45+
>>> Redis(client_factory=lambda: Sentinel([('localhost', 26379)]).master_for('mymaster'))
4546
4647
"""
4748

48-
client: RedisClient | RedisCluster = dataclasses.field(repr=False)
49+
client: RedisClient | RedisCluster | None = dataclasses.field(
50+
repr=False, default=None
51+
)
52+
client_factory: typing.Callable[[], RedisClient | RedisCluster] | None = (
53+
dataclasses.field(repr=False, default=None)
54+
)
55+
56+
def __post_init__(self):
57+
# Validate that exactly one of client or client_factory is provided
58+
if self.client is not None and self.client_factory is not None:
59+
raise ValueError(
60+
"Provide exactly one of `client` or `client_factory`, not both."
61+
)
62+
if self.client is None and self.client_factory is None:
63+
raise ValueError(
64+
"You must provide either `client` (deprecated) or `client_factory` "
65+
"when instantiating `Redis`."
66+
)
67+
68+
# Emit deprecation warning if using the old client parameter
69+
if self.client is not None:
70+
warnings.warn(
71+
"The `client` argument is deprecated and will be removed in a future version. "
72+
"Please use `client_factory` instead.",
73+
DeprecationWarning,
74+
stacklevel=2,
75+
)
4976

5077
async def run(self):
78+
# Create a new client for this health check request
79+
if self.client_factory is not None:
80+
client = self.client_factory()
81+
should_close = True
82+
else:
83+
# Use the deprecated client parameter (user manages lifecycle)
84+
client = self.client
85+
should_close = False
86+
5187
logger.debug("Pinging Redis client...")
5288
try:
53-
await self.client.ping()
89+
await client.ping()
5490
except ConnectionRefusedError as e:
5591
raise ServiceUnavailable(
5692
"Unable to connect to Redis: Connection was refused."
@@ -64,4 +100,6 @@ async def run(self):
64100
else:
65101
logger.debug("Connection established. Redis is healthy.")
66102
finally:
67-
await self.client.aclose()
103+
# Only close clients created by client_factory
104+
if should_close:
105+
await client.aclose()

tests/contrib/test_redis.py

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,47 +19,135 @@ class TestRedis:
1919

2020
@pytest.mark.asyncio
2121
async def test_redis__ok(self):
22-
"""Ping Redis successfully when using client parameter."""
22+
"""Ping Redis successfully when using client_factory parameter."""
2323
mock_client = mock.AsyncMock()
2424
mock_client.ping.return_value = True
2525

26-
check = RedisHealthCheck(client=mock_client)
26+
check = RedisHealthCheck(client_factory=lambda: mock_client)
2727
result = await check.get_result()
2828
assert result.error is None
2929
mock_client.ping.assert_called_once()
30+
mock_client.aclose.assert_called_once()
3031

3132
@pytest.mark.asyncio
3233
async def test_redis__connection_refused(self):
3334
"""Raise ServiceUnavailable when connection is refused."""
3435
mock_client = mock.AsyncMock()
3536
mock_client.ping.side_effect = ConnectionRefusedError("refused")
3637

37-
check = RedisHealthCheck(client=mock_client)
38+
check = RedisHealthCheck(client_factory=lambda: mock_client)
3839
result = await check.get_result()
3940
assert result.error is not None
4041
assert isinstance(result.error, ServiceUnavailable)
42+
mock_client.aclose.assert_called_once()
4143

4244
@pytest.mark.asyncio
4345
async def test_redis__timeout(self):
4446
"""Raise ServiceUnavailable when connection times out."""
4547
mock_client = mock.AsyncMock()
4648
mock_client.ping.side_effect = RedisTimeoutError("timeout")
4749

48-
check = RedisHealthCheck(client=mock_client)
50+
check = RedisHealthCheck(client_factory=lambda: mock_client)
4951
result = await check.get_result()
5052
assert result.error is not None
5153
assert isinstance(result.error, ServiceUnavailable)
54+
mock_client.aclose.assert_called_once()
5255

5356
@pytest.mark.asyncio
5457
async def test_redis__connection_error(self):
5558
"""Raise ServiceUnavailable when connection fails."""
5659
mock_client = mock.AsyncMock()
5760
mock_client.ping.side_effect = RedisConnectionError("connection error")
5861

59-
check = RedisHealthCheck(client=mock_client)
62+
check = RedisHealthCheck(client_factory=lambda: mock_client)
6063
result = await check.get_result()
6164
assert result.error is not None
6265
assert isinstance(result.error, ServiceUnavailable)
66+
mock_client.aclose.assert_called_once()
67+
68+
@pytest.mark.asyncio
69+
async def test_redis__client_deprecated(self):
70+
"""Verify DeprecationWarning is raised when using client parameter."""
71+
mock_client = mock.AsyncMock()
72+
mock_client.ping.return_value = True
73+
74+
with pytest.warns(
75+
DeprecationWarning, match="client.*deprecated.*client_factory"
76+
):
77+
check = RedisHealthCheck(client=mock_client)
78+
79+
result = await check.get_result()
80+
assert result.error is None
81+
mock_client.ping.assert_called_once()
82+
# User-provided client should NOT be closed by the health check
83+
mock_client.aclose.assert_not_called()
84+
85+
@pytest.mark.asyncio
86+
async def test_redis__factory_called_for_each_result(self):
87+
"""Verify client_factory is called per result and each client is closed."""
88+
call_count = 0
89+
created_clients = []
90+
91+
def factory():
92+
nonlocal call_count, created_clients
93+
call_count += 1
94+
client = mock.AsyncMock()
95+
client.ping.return_value = True
96+
created_clients.append(client)
97+
return client
98+
99+
check = RedisHealthCheck(client_factory=factory)
100+
# Factory should not be called eagerly during initialization
101+
assert call_count == 0, "Factory should not be called during initialization"
102+
103+
# Each request should use a newly created client
104+
result1 = await check.get_result()
105+
assert result1.error is None
106+
assert call_count == 1, "Factory should be called once for first request"
107+
108+
result2 = await check.get_result()
109+
assert result2.error is None
110+
assert call_count == 2, "Factory should be called again for second request"
111+
112+
# Ensure a distinct client was created and closed for each result
113+
assert len(created_clients) == 2
114+
assert created_clients[0] is not created_clients[1], (
115+
"Each request should create a distinct client"
116+
)
117+
created_clients[0].aclose.assert_called_once()
118+
created_clients[1].aclose.assert_called_once()
119+
120+
@pytest.mark.asyncio
121+
async def test_redis__client_not_closed_when_user_provided(self):
122+
"""Verify user-provided client is NOT closed by health check."""
123+
mock_client = mock.AsyncMock()
124+
mock_client.ping.return_value = True
125+
126+
with pytest.warns(DeprecationWarning):
127+
check = RedisHealthCheck(client=mock_client)
128+
129+
result = await check.get_result()
130+
assert result.error is None
131+
mock_client.ping.assert_called_once()
132+
# User is responsible for closing their own client
133+
mock_client.aclose.assert_not_called()
134+
135+
@pytest.mark.asyncio
136+
async def test_redis__validation_both_params(self):
137+
"""Verify error when both client and client_factory are provided."""
138+
mock_client = mock.AsyncMock()
139+
with pytest.raises(
140+
ValueError, match="Provide exactly one of `client` or `client_factory`"
141+
):
142+
RedisHealthCheck(client=mock_client, client_factory=lambda: mock_client)
143+
144+
@pytest.mark.asyncio
145+
async def test_redis__validation_neither_param(self):
146+
"""Verify error when neither client nor client_factory is provided."""
147+
with pytest.raises(
148+
ValueError, match="You must provide either `client`.*or `client_factory`"
149+
):
150+
RedisHealthCheck()
63151

64152
@pytest.mark.integration
65153
@pytest.mark.asyncio
@@ -71,11 +159,9 @@ async def test_redis__real_connection(self):
71159

72160
from redis.asyncio import Redis as RedisClient
73161

74-
client = RedisClient.from_url(redis_url)
75-
check = RedisHealthCheck(client=client)
162+
check = RedisHealthCheck(client_factory=lambda: RedisClient.from_url(redis_url))
76163
result = await check.get_result()
77164
assert result.error is None
78-
await client.aclose()
79165

80166
@pytest.mark.integration
81167
@pytest.mark.asyncio
@@ -97,11 +183,11 @@ async def test_redis__real_sentinel(self):
97183
host, port = node.strip().split(":")
98184
sentinels.append((host, int(port)))
99185

100-
# Create Sentinel and get master client
101-
sentinel = Sentinel(sentinels)
102-
master = sentinel.master_for(service_name)
186+
# Create factory that returns Sentinel master client
187+
def factory():
188+
sentinel = Sentinel(sentinels)
189+
return sentinel.master_for(service_name)
103190

104-
# Use the unified Redis check with the master client
105-
check = RedisHealthCheck(client=master)
191+
check = RedisHealthCheck(client_factory=factory)
106192
result = await check.get_result()
107193
assert result.error is None

0 commit comments

Comments
 (0)