Skip to content

Commit a536a5a

Browse files
committed
Add Pool.on_acquire hook with AcquireEvent
Adds an optional `on_acquire` callback to `Pool` / `create_pool`, mirroring the existing setup/init/reset style. The callback is invoked synchronously with an `AcquireEvent(wait_seconds, size, idle, max_size)` after every successful `Pool.acquire` dispatch. Lets applications detect pool saturation (long wait, idle == 0) without subclassing the pool or wrapping every callsite. Callback exceptions are logged and suppressed. No behavior change when unused.
1 parent db8ecc2 commit a536a5a

3 files changed

Lines changed: 112 additions & 6 deletions

File tree

asyncpg/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .connection import connect, Connection # NOQA
1010
from .exceptions import * # NOQA
11-
from .pool import create_pool, Pool # NOQA
11+
from .pool import create_pool, Pool, AcquireEvent # NOQA
1212
from .protocol import Record # NOQA
1313
from .types import * # NOQA
1414

@@ -19,6 +19,6 @@
1919

2020

2121
__all__: tuple[str, ...] = (
22-
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
22+
'connect', 'create_pool', 'Pool', 'Record', 'Connection', 'AcquireEvent'
2323
)
2424
__all__ += exceptions.__all__ # NOQA

asyncpg/pool.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import asyncio
1010
from collections.abc import Awaitable, Callable
11+
import dataclasses
1112
import functools
1213
import inspect
1314
import logging
@@ -25,6 +26,19 @@
2526
logger = logging.getLogger(__name__)
2627

2728

29+
@dataclasses.dataclass(frozen=True)
30+
class AcquireEvent:
31+
"""Emitted by :meth:`Pool.acquire` on every successful dispatch.
32+
33+
.. versionadded:: 0.32.0
34+
"""
35+
36+
wait_seconds: float
37+
size: int
38+
idle: int
39+
max_size: int
40+
41+
2842
class PoolConnectionProxyMeta(type):
2943

3044
def __new__(
@@ -342,7 +356,8 @@ class Pool:
342356
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
343357
'_holders', '_initialized', '_initializing', '_closing',
344358
'_closed', '_connection_class', '_record_class', '_generation',
345-
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
359+
'_setup', '_max_queries', '_max_inactive_connection_lifetime',
360+
'_on_acquire',
346361
)
347362

348363
def __init__(self, *connect_args,
@@ -357,6 +372,8 @@ def __init__(self, *connect_args,
357372
loop,
358373
connection_class,
359374
record_class,
375+
on_acquire: Optional[
376+
Callable[[AcquireEvent], None]] = None,
360377
**connect_kwargs):
361378

362379
if len(connect_args) > 1:
@@ -399,6 +416,8 @@ def __init__(self, *connect_args,
399416
'record_class is expected to be a subclass of '
400417
'asyncpg.Record, got {!r}'.format(record_class))
401418

419+
self._on_acquire = on_acquire
420+
402421
self._minsize = min_size
403422
self._maxsize = max_size
404423

@@ -892,11 +911,29 @@ async def _acquire_impl():
892911
raise exceptions.InterfaceError('pool is closing')
893912
self._check_init()
894913

914+
cb = self._on_acquire
915+
if cb is None:
916+
if timeout is None:
917+
return await _acquire_impl()
918+
return await compat.wait_for(_acquire_impl(), timeout=timeout)
919+
920+
started = time.monotonic()
895921
if timeout is None:
896-
return await _acquire_impl()
922+
proxy = await _acquire_impl()
897923
else:
898-
return await compat.wait_for(
899-
_acquire_impl(), timeout=timeout)
924+
proxy = await compat.wait_for(_acquire_impl(), timeout=timeout)
925+
event = AcquireEvent(
926+
wait_seconds=time.monotonic() - started,
927+
size=self.get_size(),
928+
idle=self.get_idle_size(),
929+
max_size=self._maxsize,
930+
)
931+
try:
932+
cb(event)
933+
except Exception:
934+
logger.exception(
935+
'asyncpg on_acquire callback raised; suppressing')
936+
return proxy
900937

901938
async def release(self, connection, *, timeout=None):
902939
"""Release a database connection back to the pool.
@@ -1084,6 +1121,8 @@ def create_pool(dsn=None, *,
10841121
loop=None,
10851122
connection_class=connection.Connection,
10861123
record_class=protocol.Record,
1124+
on_acquire: Optional[
1125+
Callable[[AcquireEvent], None]] = None,
10871126
**connect_kwargs):
10881127
r"""Create a connection pool.
10891128
@@ -1230,6 +1269,16 @@ def create_pool(dsn=None, *,
12301269
12311270
.. versionchanged:: 0.30.0
12321271
Added the *connect* and *reset* parameters.
1272+
1273+
:param on_acquire:
1274+
Synchronous callback invoked with an :class:`AcquireEvent` after
1275+
every successful :meth:`Pool.acquire` dispatch. ``wait_seconds``
1276+
is wall-clock time spent inside :meth:`Pool.acquire` (queue wait
1277+
plus any reconnect or ``setup`` callback). Exceptions are
1278+
logged and suppressed.
1279+
1280+
.. versionchanged:: 0.32.0
1281+
Added the *on_acquire* parameter.
12331282
"""
12341283
return Pool(
12351284
dsn,
@@ -1244,5 +1293,6 @@ def create_pool(dsn=None, *,
12441293
init=init,
12451294
reset=reset,
12461295
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
1296+
on_acquire=on_acquire,
12471297
**connect_kwargs,
12481298
)

tests/test_pool.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,62 @@ async def worker():
10041004
conn = await pool.acquire(timeout=0.1)
10051005
await pool.release(conn)
10061006

1007+
async def test_pool_on_acquire_reports_saturation_wait(self):
1008+
events = []
1009+
pool = await self.create_pool(
1010+
database='postgres',
1011+
min_size=1,
1012+
max_size=1,
1013+
on_acquire=events.append,
1014+
)
1015+
try:
1016+
holder_acquired = asyncio.Event()
1017+
release_holder = asyncio.Event()
1018+
1019+
async def holder():
1020+
async with pool.acquire():
1021+
holder_acquired.set()
1022+
await release_holder.wait()
1023+
1024+
async def waiter():
1025+
await holder_acquired.wait()
1026+
async with pool.acquire() as con:
1027+
await con.fetchval('SELECT 1')
1028+
1029+
holder_task = self.loop.create_task(holder())
1030+
waiter_task = self.loop.create_task(waiter())
1031+
await holder_acquired.wait()
1032+
await asyncio.sleep(0.15)
1033+
release_holder.set()
1034+
await asyncio.gather(holder_task, waiter_task)
1035+
finally:
1036+
await pool.close()
1037+
1038+
self.assertEqual(len(events), 2)
1039+
for ev in events:
1040+
self.assertEqual(ev.max_size, 1)
1041+
self.assertGreaterEqual(ev.wait_seconds, 0)
1042+
self.assertGreaterEqual(
1043+
max(ev.wait_seconds for ev in events), 0.1)
1044+
1045+
async def test_pool_on_acquire_not_fired_on_timeout(self):
1046+
events = []
1047+
pool = await self.create_pool(
1048+
database='postgres',
1049+
min_size=1,
1050+
max_size=1,
1051+
on_acquire=events.append,
1052+
)
1053+
try:
1054+
async with pool.acquire():
1055+
with self.assertRaises(asyncio.TimeoutError):
1056+
await pool.acquire(timeout=0.1)
1057+
finally:
1058+
await pool.close()
1059+
1060+
# one event for the outer successful acquire, none for the timeout
1061+
self.assertEqual(len(events), 1)
1062+
10071063

10081064
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
10091065
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):

0 commit comments

Comments
 (0)