Skip to content

Commit 8ac45fe

Browse files
authored
Merge branch 'master' into feature/driver-info
2 parents f2df255 + 12c9a38 commit 8ac45fe

File tree

8 files changed

+410
-16
lines changed

8 files changed

+410
-16
lines changed

redis/asyncio/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,11 @@ def __init__(
249249
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
250250
ssl_ca_certs: Optional[str] = None,
251251
ssl_ca_data: Optional[str] = None,
252+
ssl_ca_path: Optional[str] = None,
252253
ssl_check_hostname: bool = True,
253254
ssl_min_version: Optional[TLSVersion] = None,
254255
ssl_ciphers: Optional[str] = None,
256+
ssl_password: Optional[str] = None,
255257
max_connections: Optional[int] = None,
256258
single_connection_client: bool = False,
257259
health_check_interval: int = 0,
@@ -371,9 +373,11 @@ def __init__(
371373
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
372374
"ssl_ca_certs": ssl_ca_certs,
373375
"ssl_ca_data": ssl_ca_data,
376+
"ssl_ca_path": ssl_ca_path,
374377
"ssl_check_hostname": ssl_check_hostname,
375378
"ssl_min_version": ssl_min_version,
376379
"ssl_ciphers": ssl_ciphers,
380+
"ssl_password": ssl_password,
377381
}
378382
)
379383
# This arg only used if no pool is passed in

redis/asyncio/connection.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -850,9 +850,11 @@ def __init__(
850850
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
851851
ssl_ca_certs: Optional[str] = None,
852852
ssl_ca_data: Optional[str] = None,
853+
ssl_ca_path: Optional[str] = None,
853854
ssl_check_hostname: bool = True,
854855
ssl_min_version: Optional[TLSVersion] = None,
855856
ssl_ciphers: Optional[str] = None,
857+
ssl_password: Optional[str] = None,
856858
**kwargs,
857859
):
858860
if not SSL_AVAILABLE:
@@ -866,9 +868,11 @@ def __init__(
866868
exclude_verify_flags=ssl_exclude_verify_flags,
867869
ca_certs=ssl_ca_certs,
868870
ca_data=ssl_ca_data,
871+
ca_path=ssl_ca_path,
869872
check_hostname=ssl_check_hostname,
870873
min_version=ssl_min_version,
871874
ciphers=ssl_ciphers,
875+
password=ssl_password,
872876
)
873877
super().__init__(**kwargs)
874878

@@ -923,10 +927,12 @@ class RedisSSLContext:
923927
"exclude_verify_flags",
924928
"ca_certs",
925929
"ca_data",
930+
"ca_path",
926931
"context",
927932
"check_hostname",
928933
"min_version",
929934
"ciphers",
935+
"password",
930936
)
931937

932938
def __init__(
@@ -938,9 +944,11 @@ def __init__(
938944
exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
939945
ca_certs: Optional[str] = None,
940946
ca_data: Optional[str] = None,
947+
ca_path: Optional[str] = None,
941948
check_hostname: bool = False,
942949
min_version: Optional[TLSVersion] = None,
943950
ciphers: Optional[str] = None,
951+
password: Optional[str] = None,
944952
):
945953
if not SSL_AVAILABLE:
946954
raise RedisError("Python wasn't built with SSL support")
@@ -965,11 +973,13 @@ def __init__(
965973
self.exclude_verify_flags = exclude_verify_flags
966974
self.ca_certs = ca_certs
967975
self.ca_data = ca_data
976+
self.ca_path = ca_path
968977
self.check_hostname = (
969978
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
970979
)
971980
self.min_version = min_version
972981
self.ciphers = ciphers
982+
self.password = password
973983
self.context: Optional[SSLContext] = None
974984

975985
def get(self) -> SSLContext:
@@ -983,10 +993,16 @@ def get(self) -> SSLContext:
983993
if self.exclude_verify_flags:
984994
for flag in self.exclude_verify_flags:
985995
context.verify_flags &= ~flag
986-
if self.certfile and self.keyfile:
987-
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
988-
if self.ca_certs or self.ca_data:
989-
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
996+
if self.certfile or self.keyfile:
997+
context.load_cert_chain(
998+
certfile=self.certfile,
999+
keyfile=self.keyfile,
1000+
password=self.password,
1001+
)
1002+
if self.ca_certs or self.ca_data or self.ca_path:
1003+
context.load_verify_locations(
1004+
cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
1005+
)
9901006
if self.min_version is not None:
9911007
context.minimum_version = self.min_version
9921008
if self.ciphers is not None:
@@ -1239,16 +1255,17 @@ def can_get_connection(self) -> bool:
12391255
version="5.3.0",
12401256
)
12411257
async def get_connection(self, command_name=None, *keys, **options):
1258+
"""Get a connected connection from the pool"""
12421259
async with self._lock:
1243-
"""Get a connected connection from the pool"""
12441260
connection = self.get_available_connection()
1245-
try:
1246-
await self.ensure_connection(connection)
1247-
except BaseException:
1248-
await self.release(connection)
1249-
raise
12501261

1251-
return connection
1262+
# We now perform the connection check outside of the lock.
1263+
try:
1264+
await self.ensure_connection(connection)
1265+
return connection
1266+
except BaseException:
1267+
await self.release(connection)
1268+
raise
12521269

12531270
def get_available_connection(self):
12541271
"""Get a connection from the pool, without making sure it is connected"""

redis/client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,10 +1035,22 @@ def is_health_check_response(self, response) -> bool:
10351035
If there are no subscriptions redis responds to PING command with a
10361036
bulk response, instead of a multi-bulk with "pong" and the response.
10371037
"""
1038-
return response in [
1039-
self.health_check_response, # If there was a subscription
1040-
self.health_check_response_b, # If there wasn't
1041-
]
1038+
if self.encoder.decode_responses:
1039+
return (
1040+
response
1041+
in [
1042+
self.health_check_response, # If there is a subscription
1043+
self.HEALTH_CHECK_MESSAGE, # If there are no subscriptions and decode_responses=True
1044+
]
1045+
)
1046+
else:
1047+
return (
1048+
response
1049+
in [
1050+
self.health_check_response, # If there is a subscription
1051+
self.health_check_response_b, # If there isn't a subscription and decode_responses=False
1052+
]
1053+
)
10421054

10431055
def check_health(self) -> None:
10441056
conn = self.connection

redis/cluster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def parse_cluster_myshardid(resp, **options):
185185
"ssl",
186186
"ssl_ca_certs",
187187
"ssl_ca_data",
188+
"ssl_ca_path",
188189
"ssl_certfile",
189190
"ssl_cert_reqs",
190191
"ssl_include_verify_flags",
@@ -2207,7 +2208,8 @@ def _sharded_message_generator(self):
22072208

22082209
def _pubsubs_generator(self):
22092210
while True:
2210-
yield from self.node_pubsub_mapping.values()
2211+
current_nodes = list(self.node_pubsub_mapping.values())
2212+
yield from current_nodes
22112213

22122214
def get_sharded_message(
22132215
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None

tests/test_asyncio/test_connection_pool.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,81 @@ async def test_pool_disconnect(self, master_host):
222222
await pool.disconnect(inuse_connections=False)
223223
assert conn.is_connected
224224

225+
async def test_lock_not_held_during_connection_establishment(self):
226+
"""
227+
Test that the connection pool lock is not held during the
228+
ensure_connection call, which involves socket connection and handshake.
229+
This is important for performance under high load.
230+
"""
231+
lock_states = []
232+
233+
class SlowConnectConnection(DummyConnection):
234+
"""Connection that simulates slow connection establishment"""
235+
236+
async def connect(self):
237+
# Check if the pool's lock is held during connection
238+
# We access the pool through the outer scope
239+
lock_states.append(pool._lock.locked())
240+
# Simulate slow connection
241+
await asyncio.sleep(0.01)
242+
self._connected = True
243+
244+
async with self.get_pool(connection_class=SlowConnectConnection) as pool:
245+
# Get a connection - this should call connect() outside the lock
246+
connection = await pool.get_connection()
247+
248+
# Verify the lock was NOT held during connect
249+
assert len(lock_states) > 0, "connect() should have been called"
250+
assert lock_states[0] is False, (
251+
"Lock should not be held during connection establishment"
252+
)
253+
254+
await pool.release(connection)
255+
256+
async def test_concurrent_connection_acquisition_performance(self):
257+
"""
258+
Test that multiple concurrent connection acquisitions don't block
259+
each other during connection establishment.
260+
"""
261+
connection_delay = 0.05
262+
num_connections = 3
263+
264+
class SlowConnectConnection(DummyConnection):
265+
"""Connection that simulates slow connection establishment"""
266+
267+
async def connect(self):
268+
# Simulate slow connection (e.g., network latency, TLS handshake)
269+
await asyncio.sleep(connection_delay)
270+
self._connected = True
271+
272+
async with self.get_pool(
273+
connection_class=SlowConnectConnection, max_connections=10
274+
) as pool:
275+
# Start acquiring multiple connections concurrently
276+
start_time = asyncio.get_running_loop().time()
277+
278+
# Try to get connections concurrently
279+
connections = await asyncio.gather(
280+
*[pool.get_connection() for _ in range(num_connections)]
281+
)
282+
283+
elapsed_time = asyncio.get_running_loop().time() - start_time
284+
285+
# With proper lock handling, these should complete mostly in parallel
286+
# If the lock was held during connect(), it would take num_connections * connection_delay
287+
# With lock only during pop, it should take ~connection_delay (connections in parallel)
288+
# We allow 2.5x overhead for system variance
289+
max_allowed_time = connection_delay * 2.5
290+
assert elapsed_time < max_allowed_time, (
291+
f"Concurrent connections took {elapsed_time:.3f}s, "
292+
f"expected < {max_allowed_time:.3f}s. "
293+
f"This suggests lock was held during connection establishment."
294+
)
295+
296+
# Clean up
297+
for conn in connections:
298+
await pool.release(conn)
299+
225300

226301
class TestBlockingConnectionPool:
227302
@asynccontextmanager

tests/test_asyncio/test_pubsub.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,63 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis):
671671
await p.aclose()
672672

673673

674+
@pytest.mark.onlynoncluster
675+
class TestPubSubHealthCheckResponse:
676+
"""Tests for health check response validation with different decode_responses settings"""
677+
678+
async def test_health_check_response_decode_false_list_format(self, r: redis.Redis):
679+
"""Test health_check_response includes list format with decode_responses=False"""
680+
p = r.pubsub()
681+
# List format: [b"pong", b"redis-py-health-check"]
682+
assert [b"pong", b"redis-py-health-check"] in p.health_check_response
683+
await p.aclose()
684+
685+
async def test_health_check_response_decode_false_bytes_format(
686+
self, r: redis.Redis
687+
):
688+
"""Test health_check_response includes bytes format with decode_responses=False"""
689+
p = r.pubsub()
690+
# Bytes format: b"redis-py-health-check"
691+
assert b"redis-py-health-check" in p.health_check_response
692+
await p.aclose()
693+
694+
async def test_health_check_response_decode_true_list_format(self, create_redis):
695+
"""Test health_check_response includes list format with decode_responses=True"""
696+
r = await create_redis(decode_responses=True)
697+
p = r.pubsub()
698+
# List format: ["pong", "redis-py-health-check"]
699+
assert ["pong", "redis-py-health-check"] in p.health_check_response
700+
await p.aclose()
701+
await r.aclose()
702+
703+
async def test_health_check_response_decode_true_string_format(self, create_redis):
704+
"""Test health_check_response includes string format with decode_responses=True"""
705+
r = await create_redis(decode_responses=True)
706+
p = r.pubsub()
707+
# String format: "redis-py-health-check" (THE FIX!)
708+
assert "redis-py-health-check" in p.health_check_response
709+
await p.aclose()
710+
await r.aclose()
711+
712+
async def test_health_check_response_decode_false_excludes_string(
713+
self, r: redis.Redis
714+
):
715+
"""Test health_check_response excludes string format with decode_responses=False"""
716+
p = r.pubsub()
717+
# String format should NOT be in the list when decode_responses=False
718+
assert "redis-py-health-check" not in p.health_check_response
719+
await p.aclose()
720+
721+
async def test_health_check_response_decode_true_excludes_bytes(self, create_redis):
722+
"""Test health_check_response excludes bytes format with decode_responses=True"""
723+
r = await create_redis(decode_responses=True)
724+
p = r.pubsub()
725+
# Bytes format should NOT be in the list when decode_responses=True
726+
assert b"redis-py-health-check" not in p.health_check_response
727+
await p.aclose()
728+
await r.aclose()
729+
730+
674731
@pytest.mark.onlynoncluster
675732
class TestPubSubConnectionKilled:
676733
@skip_if_server_version_lt("3.0.0")

tests/test_asyncio/test_ssl.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,55 @@ def capture_context_create_default():
141141

142142
finally:
143143
await r.aclose()
144+
145+
async def test_ssl_ca_path_parameter(self, request):
146+
"""Test that ssl_ca_path parameter is properly passed to SSLConnection"""
147+
ssl_url = request.config.option.redis_ssl_url
148+
parsed_url = urlparse(ssl_url)
149+
150+
# Test with a mock ca_path directory
151+
test_ca_path = "/tmp/test_ca_certs"
152+
153+
r = redis.Redis(
154+
host=parsed_url.hostname,
155+
port=parsed_url.port,
156+
ssl=True,
157+
ssl_cert_reqs="none",
158+
ssl_ca_path=test_ca_path,
159+
)
160+
161+
try:
162+
# Get the connection to verify ssl_ca_path is passed through
163+
conn = r.connection_pool.make_connection()
164+
assert isinstance(conn, redis.SSLConnection)
165+
166+
# Verify the ca_path is stored in the SSL context
167+
assert conn.ssl_context.ca_path == test_ca_path
168+
finally:
169+
await r.aclose()
170+
171+
async def test_ssl_password_parameter(self, request):
172+
"""Test that ssl_password parameter is properly passed to SSLConnection"""
173+
ssl_url = request.config.option.redis_ssl_url
174+
parsed_url = urlparse(ssl_url)
175+
176+
# Test with a mock password for encrypted private key
177+
test_password = "test_key_password"
178+
179+
r = redis.Redis(
180+
host=parsed_url.hostname,
181+
port=parsed_url.port,
182+
ssl=True,
183+
ssl_cert_reqs="none",
184+
ssl_password=test_password,
185+
)
186+
187+
try:
188+
# Get the connection to verify ssl_password is passed through
189+
conn = r.connection_pool.make_connection()
190+
assert isinstance(conn, redis.SSLConnection)
191+
192+
# Verify the password is stored in the SSL context
193+
assert conn.ssl_context.password == test_password
194+
finally:
195+
await r.aclose()

0 commit comments

Comments
 (0)