Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added reconnect of connection and TaskTransactionContextManager #549

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ target/
tests/fixtures/my.cnf

.pytest_cache
/tests/test_reconnect.py
46 changes: 45 additions & 1 deletion aiomysql/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import warnings
import contextlib
import asyncio

from pymysql.err import (
Warning, Error, InterfaceError, DataError,
Expand All @@ -22,6 +23,14 @@
r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
re.IGNORECASE | re.DOTALL)

ERROR_CODES_FOR_RECONNECTING = [
1927, # ER_CONNECTION_KILLED
1184, # ER_NEW_ABORTING_CONNECTION
1152, # ER_ABORTING_CONNECTION,
2003, # Can't connect to MySQL server
2013, # Lost connection to MySQL server during query
]


class Cursor:
"""Cursor is used to interact with the database."""
Expand Down Expand Up @@ -236,11 +245,46 @@ async def execute(self, query, args=None):
if args is not None:
query = query % self._escape_args(args, conn)

await self._query(query)
try:
await self._query(query)

except asyncio.CancelledError:
raise

except Exception as main_error:
if not hasattr(main_error, 'args') or main_error.args[0] not in ERROR_CODES_FOR_RECONNECTING:
raise main_error

logger.error(main_error)
sleep_time_list = [3] * 20
sleep_time_list.insert(0, 1)
for attempt, sleep_time in enumerate(sleep_time_list):
try:
logger.warning('%s - Reconnecting to MySQL. Attempt %d of 21 for connection %s', conn._db, attempt + 1, id(conn))
await conn.ping()
logger.info('%s - Successfully reconnected to MySQL after error for connection %s', conn._db, id(conn))
await self._query(query)
break

except asyncio.CancelledError:
raise

except Exception as e:
if not hasattr(e, 'args') or e.args[0] not in ERROR_CODES_FOR_RECONNECTING:
break

logger.error(e)
await asyncio.sleep(sleep_time)

else:
logger.error('%s - Reconnecting to MySQL failed for connection %s', conn._db, id(conn))
raise main_error

self._executed = query
if self._echo:
logger.info(query)
logger.info("%r", args)

return self._rowcount

async def executemany(self, query, args):
Expand Down
1 change: 1 addition & 0 deletions aiomysql/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

# Name the logger after the package.
logger = logging.getLogger(__package__)
logger.setLevel(logging.WARNING)
100 changes: 76 additions & 24 deletions aiomysql/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@

import asyncio
import collections
import sys
import warnings

from pymysql import OperationalError

from .log import logger
from .connection import connect
from .utils import (_PoolContextManager, _PoolConnectionContextManager,
_PoolAcquireContextManager)
_PoolAcquireContextManager, TaskTransactionContextManager)


def create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1,
Expand Down Expand Up @@ -50,6 +54,7 @@ def __init__(self, minsize, maxsize, echo, pool_recycle, loop, **kwargs):
self._closed = False
self._echo = echo
self._recycle = pool_recycle
self._db = kwargs.get('db')

@property
def echo(self):
Expand All @@ -71,6 +76,10 @@ def size(self):
def freesize(self):
return len(self._free)

@property
def db_name(self):
return self._db

async def clear(self):
"""Close all free connections in pool."""
async with self._cond:
Expand Down Expand Up @@ -131,9 +140,32 @@ async def wait_closed(self):

def acquire(self):
"""Acquire free connection from the pool."""
if sys.version_info < (3, 7):
o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.Task.current_task())

else:
o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.current_task())

if o_transaction_context_manager:
return o_transaction_context_manager

coro = self._acquire()
return _PoolAcquireContextManager(coro, self)

def acquire_with_transaction(self):
"""Acquire free connection from the pool for transaction"""
if sys.version_info < (3, 7):
o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.Task.current_task())

else:
o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.current_task())

if o_transaction_context_manager:
return o_transaction_context_manager

coro = self._acquire()
return TaskTransactionContextManager(coro, self)

async def _acquire(self):
if self._closing:
raise RuntimeError("Cannot acquire connection after closing pool")
Expand All @@ -142,11 +174,12 @@ async def _acquire(self):
await self._fill_free_pool(True)
if self._free:
conn = self._free.popleft()
assert not conn.closed, conn
assert conn not in self._used, (conn, self._used)
# assert not conn.closed, conn
# assert conn not in self._used, (conn, self._used)
self._used.add(conn)
return conn
else:
logger.debug('%s - All connections (%d) are busy. Waiting for release connection', self._db, self.freesize)
await self._cond.wait()

async def _fill_free_pool(self, override_min):
Expand All @@ -156,6 +189,7 @@ async def _fill_free_pool(self, override_min):
while n < free_size:
conn = self._free[-1]
if conn._reader.at_eof() or conn._reader.exception():
logger.debug('%s - Connection (%d) is removed from pool because of at_eof or exception', self._db, id(conn))
self._free.pop()
conn.close()

Expand All @@ -167,38 +201,56 @@ async def _fill_free_pool(self, override_min):
self._free.pop()
conn.close()

elif (self._recycle > -1 and
self._loop.time() - conn.last_usage > self._recycle):
elif self._recycle > -1 and self._loop.time() - conn.last_usage > self._recycle:
logger.debug('%s - Connection (%d) is removed from pool because of recycle time %d', self._db, id(conn), self._recycle)
self._free.pop()
conn.close()

else:
self._free.rotate()

n += 1

while self.size < self.minsize:
self._acquiring += 1
try:
conn = await connect(echo=self._echo, loop=self._loop,
**self._conn_kwargs)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
finally:
self._acquiring -= 1
await self.__create_new_connection()

if self._free:
return

if override_min and (not self.maxsize or self.size < self.maxsize):
self._acquiring += 1
await self.__create_new_connection()

async def __create_new_connection(self):
logger.debug('%s - Try to create new connection', self._db)
self._acquiring += 1
try:
try:
conn = await connect(echo=self._echo, loop=self._loop,
**self._conn_kwargs)
# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()
finally:
self._acquiring -= 1
conn = await connect(echo=self._echo, loop=self._loop, **self._conn_kwargs)

except OperationalError as error:
logger.error(error)
sleep_time_list = [3] * 20
for attempt, sleep_time in enumerate(sleep_time_list):
try:
logger.warning('%s - Connect to MySQL failed. Attempt %d of 20', self._db, attempt + 1)
conn = await connect(echo=self._echo, loop=self._loop, **self._conn_kwargs)
logger.info('%s - Successfully connect to MySQL after error', self._db)
break

except OperationalError as e:
logger.error(e)
await asyncio.sleep(sleep_time)

else:
logger.error('%s - Connect to MySQL failed', self._db)
raise error

# raise exception if pool is closing
self._free.append(conn)
self._cond.notify()

finally:
self._acquiring -= 1

async def _wakeup(self):
async with self._cond:
Expand All @@ -213,10 +265,10 @@ def release(self, conn):
fut.set_result(None)

if conn in self._terminated:
assert conn.closed, conn
# assert conn.closed, conn
self._terminated.remove(conn)
return fut
assert conn in self._used, (conn, self._used)
# assert conn in self._used, (conn, self._used)
self._used.remove(conn)
if not conn.closed:
in_trans = conn.get_transaction_status()
Expand Down
Loading