Skip to content
6 changes: 5 additions & 1 deletion docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ urlpatterns = [
),
(
"health_check.contrib.redis.Redis",
{"client": RedisClient.from_url("redis://localhost:6379")},
{
"client_factory": lambda: RedisClient.from_url(
"redis://localhost:6379"
)
},
),
# AWS service status check
(
Expand Down
58 changes: 48 additions & 10 deletions health_check/contrib/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import dataclasses
import logging
import typing
import warnings

from redis import exceptions
from redis.asyncio import Redis as RedisClient
Expand All @@ -22,35 +24,69 @@ class Redis(HealthCheck):
including standard Redis, Sentinel, and Cluster clients.

Args:
client: A Redis client instance (Redis, Sentinel master, or Cluster).
Must be an async client from redis.asyncio.
client_factory: A callable that returns an instance of a Redis client.
client: Deprecated, use `client_factory` instead.

Examples:
Using a standard Redis client:
>>> from redis.asyncio import Redis as RedisClient
>>> Redis(client=RedisClient(host='localhost', port=6379))
>>> Redis(client_factory=lambda: RedisClient(host='localhost', port=6379))

Using from_url to create a client:
>>> from redis.asyncio import Redis as RedisClient
>>> Redis(client=RedisClient.from_url('redis://localhost:6379'))
>>> Redis(client_factory=lambda: RedisClient.from_url('redis://localhost:6379'))

Using a Cluster client:
>>> from redis.asyncio import RedisCluster
>>> Redis(client=RedisCluster(host='localhost', port=7000))
>>> Redis(client_factory=lambda: RedisCluster(host='localhost', port=7000))

Using a Sentinel client:
>>> from redis.asyncio import Sentinel
>>> sentinel = Sentinel([('localhost', 26379)])
>>> Redis(client=sentinel.master_for('mymaster'))
>>> Redis(client_factory=lambda: Sentinel([('localhost', 26379)]).master_for('mymaster'))

"""

client: RedisClient | RedisCluster = dataclasses.field(repr=False)
client: RedisClient | RedisCluster | None = dataclasses.field(
repr=False, default=None
)
client_factory: typing.Callable[[], RedisClient | RedisCluster] | None = (
dataclasses.field(repr=False, default=None)
)

def __post_init__(self):
# Validate that exactly one of client or client_factory is provided
if self.client is not None and self.client_factory is not None:
raise ValueError(
"Provide exactly one of `client` or `client_factory`, not both."
)
if self.client is None and self.client_factory is None:
raise ValueError(
"You must provide either `client` (deprecated) or `client_factory` "
"when instantiating `Redis`."
)

# Emit deprecation warning if using the old client parameter
if self.client is not None:
warnings.warn(
"The `client` argument is deprecated and will be removed in a future version. "
"Please use `client_factory` instead.",
DeprecationWarning,
stacklevel=2,
)

async def run(self):
# Create a new client for this health check request
if self.client_factory is not None:
client = self.client_factory()
should_close = True
else:
# Use the deprecated client parameter (user manages lifecycle)
client = self.client
should_close = False

logger.debug("Pinging Redis client...")
try:
await self.client.ping()
await client.ping()
except ConnectionRefusedError as e:
raise ServiceUnavailable(
"Unable to connect to Redis: Connection was refused."
Expand All @@ -64,4 +100,6 @@ async def run(self):
else:
logger.debug("Connection established. Redis is healthy.")
finally:
await self.client.aclose()
# Only close clients created by client_factory
if should_close:
await client.aclose()
112 changes: 99 additions & 13 deletions tests/contrib/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,135 @@ class TestRedis:

@pytest.mark.asyncio
async def test_redis__ok(self):
"""Ping Redis successfully when using client parameter."""
"""Ping Redis successfully when using client_factory parameter."""
mock_client = mock.AsyncMock()
mock_client.ping.return_value = True

check = RedisHealthCheck(client=mock_client)
check = RedisHealthCheck(client_factory=lambda: mock_client)
result = await check.get_result()
assert result.error is None
mock_client.ping.assert_called_once()
mock_client.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__connection_refused(self):
"""Raise ServiceUnavailable when connection is refused."""
mock_client = mock.AsyncMock()
mock_client.ping.side_effect = ConnectionRefusedError("refused")

check = RedisHealthCheck(client=mock_client)
check = RedisHealthCheck(client_factory=lambda: mock_client)
result = await check.get_result()
assert result.error is not None
assert isinstance(result.error, ServiceUnavailable)
mock_client.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__timeout(self):
"""Raise ServiceUnavailable when connection times out."""
mock_client = mock.AsyncMock()
mock_client.ping.side_effect = RedisTimeoutError("timeout")

check = RedisHealthCheck(client=mock_client)
check = RedisHealthCheck(client_factory=lambda: mock_client)
result = await check.get_result()
assert result.error is not None
assert isinstance(result.error, ServiceUnavailable)
mock_client.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__connection_error(self):
"""Raise ServiceUnavailable when connection fails."""
mock_client = mock.AsyncMock()
mock_client.ping.side_effect = RedisConnectionError("connection error")

check = RedisHealthCheck(client=mock_client)
check = RedisHealthCheck(client_factory=lambda: mock_client)
result = await check.get_result()
assert result.error is not None
assert isinstance(result.error, ServiceUnavailable)
mock_client.aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__client_deprecated(self):
"""Verify DeprecationWarning is raised when using client parameter."""
mock_client = mock.AsyncMock()
mock_client.ping.return_value = True

with pytest.warns(
DeprecationWarning, match="client.*deprecated.*client_factory"
):
check = RedisHealthCheck(client=mock_client)

result = await check.get_result()
assert result.error is None
mock_client.ping.assert_called_once()
# User-provided client should NOT be closed by the health check
mock_client.aclose.assert_not_called()

@pytest.mark.asyncio
async def test_redis__factory_called_for_each_result(self):
"""Verify client_factory is called per result and each client is closed."""
call_count = 0
created_clients = []

def factory():
nonlocal call_count, created_clients
call_count += 1
client = mock.AsyncMock()
client.ping.return_value = True
created_clients.append(client)
return client

check = RedisHealthCheck(client_factory=factory)
# Factory should not be called eagerly during initialization
assert call_count == 0, "Factory should not be called during initialization"

# Each request should use a newly created client
result1 = await check.get_result()
assert result1.error is None
assert call_count == 1, "Factory should be called once for first request"

result2 = await check.get_result()
assert result2.error is None
assert call_count == 2, "Factory should be called again for second request"

# Ensure a distinct client was created and closed for each result
assert len(created_clients) == 2
assert created_clients[0] is not created_clients[1], (
"Each request should create a distinct client"
)
created_clients[0].aclose.assert_called_once()
created_clients[1].aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__client_not_closed_when_user_provided(self):
"""Verify user-provided client is NOT closed by health check."""
mock_client = mock.AsyncMock()
mock_client.ping.return_value = True

with pytest.warns(DeprecationWarning):
check = RedisHealthCheck(client=mock_client)

result = await check.get_result()
assert result.error is None
mock_client.ping.assert_called_once()
# User is responsible for closing their own client
mock_client.aclose.assert_not_called()

@pytest.mark.asyncio
async def test_redis__validation_both_params(self):
"""Verify error when both client and client_factory are provided."""
mock_client = mock.AsyncMock()
with pytest.raises(
ValueError, match="Provide exactly one of `client` or `client_factory`"
):
RedisHealthCheck(client=mock_client, client_factory=lambda: mock_client)

@pytest.mark.asyncio
async def test_redis__validation_neither_param(self):
"""Verify error when neither client nor client_factory is provided."""
with pytest.raises(
ValueError, match="You must provide either `client`.*or `client_factory`"
):
RedisHealthCheck()

@pytest.mark.integration
@pytest.mark.asyncio
Expand All @@ -71,11 +159,9 @@ async def test_redis__real_connection(self):

from redis.asyncio import Redis as RedisClient

client = RedisClient.from_url(redis_url)
check = RedisHealthCheck(client=client)
check = RedisHealthCheck(client_factory=lambda: RedisClient.from_url(redis_url))
result = await check.get_result()
assert result.error is None
await client.aclose()

@pytest.mark.integration
@pytest.mark.asyncio
Expand All @@ -97,11 +183,11 @@ async def test_redis__real_sentinel(self):
host, port = node.strip().split(":")
sentinels.append((host, int(port)))

# Create Sentinel and get master client
sentinel = Sentinel(sentinels)
master = sentinel.master_for(service_name)
# Create factory that returns Sentinel master client
def factory():
sentinel = Sentinel(sentinels)
return sentinel.master_for(service_name)

# Use the unified Redis check with the master client
check = RedisHealthCheck(client=master)
check = RedisHealthCheck(client_factory=factory)
result = await check.get_result()
assert result.error is None