Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions health_check/contrib/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,37 @@ class Redis(HealthCheck):
dataclasses.field(repr=False, default=None)
)

async def run(self):
if self.client_factory:
client = self.client_factory()
else:
def __post_init__(self):
# Validate that exactly one of client or client_factory is provided
if self.client is not None and self.client_factory is not None:
raise ValueError(
"Provide exactly one of `client` or `client_factory`, not both."
)
if self.client is None and self.client_factory is None:
raise ValueError(
"You must provide either `client` (deprecated) or `client_factory` "
"when instantiating `Redis`."
)

# Emit deprecation warning if using the old client parameter
if self.client is not None:
warnings.warn(
"The `client` argument is deprecated and will be removed in a future version. "
"Please use `client_factory` instead.",
DeprecationWarning,
stacklevel=2,
)

async def run(self):
# Create a new client for this health check request
if self.client_factory is not None:
client = self.client_factory()
should_close = True
else:
# Use the deprecated client parameter (user manages lifecycle)
client = self.client
should_close = False

logger.debug("Pinging Redis client...")
try:
await client.ping()
Expand All @@ -80,4 +100,6 @@ async def run(self):
else:
logger.debug("Connection established. Redis is healthy.")
finally:
await client.aclose()
# Only close clients created by client_factory
if should_close:
await client.aclose()
54 changes: 41 additions & 13 deletions tests/contrib/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,37 +79,47 @@ async def test_redis__client_deprecated(self):
result = await check.get_result()
assert result.error is None
mock_client.ping.assert_called_once()
mock_client.aclose.assert_called_once()
# User-provided client should NOT be closed by the health check
mock_client.aclose.assert_not_called()

@pytest.mark.asyncio
async def test_redis__factory_called_once_in_init(self):
"""Verify client_factory is called once during initialization."""
async def test_redis__factory_called_for_each_result(self):
"""Verify client_factory is called per result and each client is closed."""
call_count = 0
created_clients = []

def factory():
nonlocal call_count
nonlocal call_count, created_clients
call_count += 1
client = mock.AsyncMock()
client.ping.return_value = True
created_clients.append(client)
return client

check = RedisHealthCheck(client_factory=factory)
assert call_count == 1, "Factory should be called once during initialization"
# Factory should not be called eagerly during initialization
assert call_count == 0, "Factory should not be called during initialization"

# Multiple requests reuse the same client
# Each request should use a newly created client
result1 = await check.get_result()
assert result1.error is None
assert call_count == 1, (
"Factory should not be called again for subsequent requests"
)
assert call_count == 1, "Factory should be called once for first request"

result2 = await check.get_result()
assert result2.error is None
assert call_count == 1, "Factory should still not be called again"
assert call_count == 2, "Factory should be called again for second request"

# Ensure a distinct client was created and closed for each result
assert len(created_clients) == 2
assert created_clients[0] is not created_clients[1], (
"Each request should create a distinct client"
)
created_clients[0].aclose.assert_called_once()
created_clients[1].aclose.assert_called_once()

@pytest.mark.asyncio
async def test_redis__client_always_closed(self):
"""Verify client is always closed after health check."""
async def test_redis__client_not_closed_when_user_provided(self):
"""Verify user-provided client is NOT closed by health check."""
mock_client = mock.AsyncMock()
mock_client.ping.return_value = True

Expand All @@ -119,7 +129,25 @@ async def test_redis__client_always_closed(self):
result = await check.get_result()
assert result.error is None
mock_client.ping.assert_called_once()
mock_client.aclose.assert_called_once()
# User is responsible for closing their own client
mock_client.aclose.assert_not_called()

@pytest.mark.asyncio
async def test_redis__validation_both_params(self):
"""Verify error when both client and client_factory are provided."""
mock_client = mock.AsyncMock()
with pytest.raises(
ValueError, match="Provide exactly one of `client` or `client_factory`"
):
RedisHealthCheck(client=mock_client, client_factory=lambda: mock_client)

@pytest.mark.asyncio
async def test_redis__validation_neither_param(self):
"""Verify error when neither client nor client_factory is provided."""
with pytest.raises(
ValueError, match="You must provide either `client`.*or `client_factory`"
):
RedisHealthCheck()

@pytest.mark.integration
@pytest.mark.asyncio
Expand Down