Skip to content

Commit 61bc0d7

Browse files
authored
Merge branch 'master' into ps_fix_runtime_error_for_sharded_pubsub
2 parents d842821 + 9205321 commit 61bc0d7

File tree

2 files changed

+83
-7
lines changed

2 files changed

+83
-7
lines changed

redis/asyncio/connection.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,16 +1215,17 @@ def can_get_connection(self) -> bool:
12151215
version="5.3.0",
12161216
)
12171217
async def get_connection(self, command_name=None, *keys, **options):
1218+
"""Get a connected connection from the pool"""
12181219
async with self._lock:
1219-
"""Get a connected connection from the pool"""
12201220
connection = self.get_available_connection()
1221-
try:
1222-
await self.ensure_connection(connection)
1223-
except BaseException:
1224-
await self.release(connection)
1225-
raise
12261221

1227-
return connection
1222+
# We now perform the connection check outside of the lock.
1223+
try:
1224+
await self.ensure_connection(connection)
1225+
return connection
1226+
except BaseException:
1227+
await self.release(connection)
1228+
raise
12281229

12291230
def get_available_connection(self):
12301231
"""Get a connection from the pool, without making sure it is connected"""

tests/test_asyncio/test_connection_pool.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,81 @@ async def test_pool_disconnect(self, master_host):
222222
await pool.disconnect(inuse_connections=False)
223223
assert conn.is_connected
224224

225+
async def test_lock_not_held_during_connection_establishment(self):
226+
"""
227+
Test that the connection pool lock is not held during the
228+
ensure_connection call, which involves socket connection and handshake.
229+
This is important for performance under high load.
230+
"""
231+
lock_states = []
232+
233+
class SlowConnectConnection(DummyConnection):
234+
"""Connection that simulates slow connection establishment"""
235+
236+
async def connect(self):
237+
# Check if the pool's lock is held during connection
238+
# We access the pool through the outer scope
239+
lock_states.append(pool._lock.locked())
240+
# Simulate slow connection
241+
await asyncio.sleep(0.01)
242+
self._connected = True
243+
244+
async with self.get_pool(connection_class=SlowConnectConnection) as pool:
245+
# Get a connection - this should call connect() outside the lock
246+
connection = await pool.get_connection()
247+
248+
# Verify the lock was NOT held during connect
249+
assert len(lock_states) > 0, "connect() should have been called"
250+
assert lock_states[0] is False, (
251+
"Lock should not be held during connection establishment"
252+
)
253+
254+
await pool.release(connection)
255+
256+
async def test_concurrent_connection_acquisition_performance(self):
257+
"""
258+
Test that multiple concurrent connection acquisitions don't block
259+
each other during connection establishment.
260+
"""
261+
connection_delay = 0.05
262+
num_connections = 3
263+
264+
class SlowConnectConnection(DummyConnection):
265+
"""Connection that simulates slow connection establishment"""
266+
267+
async def connect(self):
268+
# Simulate slow connection (e.g., network latency, TLS handshake)
269+
await asyncio.sleep(connection_delay)
270+
self._connected = True
271+
272+
async with self.get_pool(
273+
connection_class=SlowConnectConnection, max_connections=10
274+
) as pool:
275+
# Start acquiring multiple connections concurrently
276+
start_time = asyncio.get_running_loop().time()
277+
278+
# Try to get connections concurrently
279+
connections = await asyncio.gather(
280+
*[pool.get_connection() for _ in range(num_connections)]
281+
)
282+
283+
elapsed_time = asyncio.get_running_loop().time() - start_time
284+
285+
# With proper lock handling, these should complete mostly in parallel
286+
# If the lock was held during connect(), it would take num_connections * connection_delay
287+
# With lock only during pop, it should take ~connection_delay (connections in parallel)
288+
# We allow 2.5x overhead for system variance
289+
max_allowed_time = connection_delay * 2.5
290+
assert elapsed_time < max_allowed_time, (
291+
f"Concurrent connections took {elapsed_time:.3f}s, "
292+
f"expected < {max_allowed_time:.3f}s. "
293+
f"This suggests lock was held during connection establishment."
294+
)
295+
296+
# Clean up
297+
for conn in connections:
298+
await pool.release(conn)
299+
225300

226301
class TestBlockingConnectionPool:
227302
@asynccontextmanager

0 commit comments

Comments
 (0)