Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 43 additions & 32 deletions aiosqlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def set_exception(fut: asyncio.Future, e: BaseException) -> None:


_STOP_RUNNING_SENTINEL = object()
_TxQueue = SimpleQueue[tuple[Optional[asyncio.Future], Callable[[], Any]]]


def _connection_worker_thread(
tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]],
):
def _connection_worker_thread(tx: _TxQueue):
"""
Execute function calls on a separate thread.

Expand All @@ -57,21 +56,23 @@ def _connection_worker_thread(
# even after connection is closed (so we can finalize all
# futures)

tx_item = tx.get()
future, function = tx_item
future, function = tx.get()

try:
LOG.debug("executing %s", function)
result = function()

if future:
future.get_loop().call_soon_threadsafe(set_result, future, result)
LOG.debug("operation %s completed", function)
future.get_loop().call_soon_threadsafe(set_result, future, result)

if result is _STOP_RUNNING_SENTINEL:
break

except BaseException as e: # noqa B036
LOG.debug("returning exception %s", e)
future.get_loop().call_soon_threadsafe(set_exception, future, e)
if future:
future.get_loop().call_soon_threadsafe(set_exception, future, e)


class Connection:
Expand All @@ -84,7 +85,7 @@ def __init__(
self._running = True
self._connection: Optional[sqlite3.Connection] = None
self._connector = connector
self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
self._tx: _TxQueue = SimpleQueue()
self._iter_chunk_size = iter_chunk_size
self._thread = Thread(target=_connection_worker_thread, args=(self._tx,))

Expand All @@ -94,14 +95,40 @@ def __init__(
DeprecationWarning,
)

def _stop_running(self) -> asyncio.Future:
def __del__(self):
if self._connection is None:
return

warn(
(
f"{self!r} was deleted before being closed. "
"Please use 'async with' or '.close()' to close the connection properly."
),
ResourceWarning,
stacklevel=1,
)

# Don't try to be creative here, the event loop may have already been closed.
# Simply stop the worker thread, and let the underlying sqlite3 connection
# be finalized by its own __del__.
self.stop()

def stop(self) -> Optional[asyncio.Future]:
"""Stop the background thread. Prefer `async with` or `await close()`"""
self._running = False

function = partial(lambda: _STOP_RUNNING_SENTINEL)
future = asyncio.get_event_loop().create_future()
def close_and_stop():
if self._connection is not None:
self._connection.close()
self._connection = None
return _STOP_RUNNING_SENTINEL

self._tx.put_nowait((future, function))
try:
future = asyncio.get_event_loop().create_future()
except Exception:
future = None

self._tx.put_nowait((future, close_and_stop))
return future

@property
Expand Down Expand Up @@ -140,7 +167,7 @@ async def _connect(self) -> "Connection":
self._tx.put_nowait((future, self._connector))
self._connection = await future
except BaseException:
await self._stop_running()
self.stop()
self._connection = None
raise

Expand Down Expand Up @@ -181,8 +208,10 @@ async def close(self) -> None:
LOG.info("exception occurred while closing connection")
raise
finally:
await self._stop_running()
self._connection = None
future = self.stop()
if future:
await future

@contextmanager
async def execute(
Expand Down Expand Up @@ -410,24 +439,6 @@ async def backup(
sleep=sleep,
)

def __del__(self):
if self._connection is None:
return

warn(
(
f"{self!r} was deleted before being closed. "
"Please use 'async with' or '.close()' to close the connection properly."
),
ResourceWarning,
stacklevel=1,
)

# Don't try to be creative here, the event loop may have already been closed.
# Simply stop the worker thread, and let the underlying sqlite3 connection
# be finalized by its own __del__.
self._stop_running()


def connect(
database: Union[str, Path],
Expand Down
19 changes: 18 additions & 1 deletion aiosqlite/tests/smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _raise_cancelled_error(*_, **__):
...
# Terminate the thread here if the test fails to have a clear error.
if connection._running:
connection._stop_running()
connection.stop()
raise AssertionError("connection thread was not stopped")

async def test_iterdump(self):
Expand Down Expand Up @@ -518,3 +518,20 @@ async def test_emits_warning_when_left_open(self):
ResourceWarning, r".*was deleted before being closed.*"
):
del db

async def test_stop_without_close(self):
db = await aiosqlite.connect(":memory:")
await db.stop()

def test_stop_after_event_loop_closed(self):
db = None

async def inner():
nonlocal db
db = await aiosqlite.connect(":memory:")

loop = asyncio.new_event_loop()
loop.run_until_complete(inner())
loop.close()

db.stop()