Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions aiosqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,15 @@

"""asyncio bridge to the standard sqlite3 module"""

from sqlite3 import ( # pylint: disable=redefined-builtin
DatabaseError,
Error,
IntegrityError,
NotSupportedError,
OperationalError,
paramstyle,
ProgrammingError,
register_adapter,
register_converter,
Row,
sqlite_version,
sqlite_version_info,
Warning,
)
from sqlite3 import (DatabaseError, Error, # pylint: disable=redefined-builtin
IntegrityError, NotSupportedError, OperationalError,
ProgrammingError, Row, Warning, paramstyle,
register_adapter, register_converter, sqlite_version,
sqlite_version_info)

__author__ = "Amethyst Reese"
from .__version__ import __version__
from .core import connect, Connection, Cursor
from .core import Connection, Cursor, connect

__all__ = [
"__version__",
Expand Down
4 changes: 2 additions & 2 deletions aiosqlite/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Licensed under the MIT license


from collections.abc import Coroutine, Generator
from collections.abc import Callable, Coroutine, Generator
from contextlib import AbstractAsyncContextManager
from functools import wraps
from typing import Any, Callable, TypeVar
from typing import Any, TypeVar

from .cursor import Cursor

Expand Down
45 changes: 23 additions & 22 deletions aiosqlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
"""
Core implementation of aiosqlite proxies
"""
from __future__ import annotations

import asyncio
import logging
import sqlite3
from collections.abc import AsyncIterator, Generator, Iterable
from collections.abc import AsyncIterator, Callable, Generator, Iterable
from functools import partial
from pathlib import Path
from queue import Empty, Queue, SimpleQueue
from threading import Thread
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal
from warnings import warn

from .context import contextmanager
Expand All @@ -26,7 +27,7 @@
LOG = logging.getLogger("aiosqlite")


IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
IsolationLevel = Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE", None]


def set_result(fut: asyncio.Future, result: Any) -> None:
Expand All @@ -42,7 +43,7 @@ def set_exception(fut: asyncio.Future, e: BaseException) -> None:


_STOP_RUNNING_SENTINEL = object()
_TxQueue = SimpleQueue[tuple[Optional[asyncio.Future], Callable[[], Any]]]
_TxQueue = SimpleQueue[tuple[asyncio.Future | None, Callable[[], Any]]]


def _connection_worker_thread(tx: _TxQueue):
Expand Down Expand Up @@ -80,10 +81,10 @@ def __init__(
self,
connector: Callable[[], sqlite3.Connection],
iter_chunk_size: int,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
self._running = True
self._connection: Optional[sqlite3.Connection] = None
self._connection: sqlite3.Connection | None = None
self._connector = connector
self._tx: _TxQueue = SimpleQueue()
self._iter_chunk_size = iter_chunk_size
Expand Down Expand Up @@ -113,7 +114,7 @@ def __del__(self):
# be finalized by its own __del__.
self.stop()

def stop(self) -> Optional[asyncio.Future]:
def stop(self) -> asyncio.Future | None:
"""Stop the background thread. Prefer `async with` or `await close()`"""
self._running = False

Expand All @@ -138,7 +139,7 @@ def _conn(self) -> sqlite3.Connection:

return self._connection

def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
def _execute_insert(self, sql: str, parameters: Any) -> sqlite3.Row | None:
cursor = self._conn.execute(sql, parameters)
cursor.execute("SELECT last_insert_rowid()")
return cursor.fetchone()
Expand Down Expand Up @@ -215,7 +216,7 @@ async def close(self) -> None:

@contextmanager
async def execute(
self, sql: str, parameters: Optional[Iterable[Any]] = None
self, sql: str, parameters: Iterable[Any] | None = None
) -> Cursor:
"""Helper to create a cursor and execute the given query."""
if parameters is None:
Expand All @@ -225,16 +226,16 @@ async def execute(

@contextmanager
async def execute_insert(
self, sql: str, parameters: Optional[Iterable[Any]] = None
) -> Optional[sqlite3.Row]:
self, sql: str, parameters: Iterable[Any] | None = None
) -> sqlite3.Row | None:
"""Helper to insert and get the last_insert_rowid."""
if parameters is None:
parameters = []
return await self._execute(self._execute_insert, sql, parameters)

@contextmanager
async def execute_fetchall(
self, sql: str, parameters: Optional[Iterable[Any]] = None
self, sql: str, parameters: Iterable[Any] | None = None
) -> Iterable[sqlite3.Row]:
"""Helper to execute a query and return all the data."""
if parameters is None:
Expand Down Expand Up @@ -286,19 +287,19 @@ def in_transaction(self) -> bool:
return self._conn.in_transaction

@property
def isolation_level(self) -> Optional[str]:
def isolation_level(self) -> str | None:
return self._conn.isolation_level

@isolation_level.setter
def isolation_level(self, value: IsolationLevel) -> None:
self._conn.isolation_level = value

@property
def row_factory(self) -> Optional[type]:
def row_factory(self) -> type | None:
return self._conn.row_factory

@row_factory.setter
def row_factory(self, factory: Optional[type]) -> None:
def row_factory(self, factory: type | None) -> None:
self._conn.row_factory = factory

@property
Expand All @@ -320,15 +321,15 @@ async def load_extension(self, path: str):
await self._execute(self._conn.load_extension, path) # type: ignore

async def set_progress_handler(
self, handler: Callable[[], Optional[int]], n: int
self, handler: Callable[[], int | None], n: int
) -> None:
await self._execute(self._conn.set_progress_handler, handler, n)

async def set_trace_callback(self, handler: Callable) -> None:
await self._execute(self._conn.set_trace_callback, handler)

async def set_authorizer(
self, authorizer_callback: Optional[AuthorizerCallback]
self, authorizer_callback: AuthorizerCallback | None
) -> None:
"""
Set an authorizer callback to control database access.
Expand Down Expand Up @@ -399,7 +400,7 @@ def dumper():

while True:
try:
line: Optional[str] = dump_queue.get_nowait()
line: str | None = dump_queue.get_nowait()
if line is None:
break
yield line
Expand All @@ -415,10 +416,10 @@ def dumper():

async def backup(
self,
target: Union["Connection", sqlite3.Connection],
target: "Connection" | sqlite3.Connection,
*,
pages: int = 0,
progress: Optional[Callable[[int, int, int], None]] = None,
progress: Callable[[int, int, int], None] | None = None,
name: str = "main",
sleep: float = 0.250,
) -> None:
Expand All @@ -441,10 +442,10 @@ async def backup(


def connect(
database: Union[str, Path],
database: str | Path,
*,
iter_chunk_size=64,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop | None = None,
**kwargs: Any,
) -> Connection:
"""Create and return a connection proxy to the sqlite database."""
Expand Down
27 changes: 18 additions & 9 deletions aiosqlite/cursor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Copyright Amethyst Reese
# Licensed under the MIT license

from __future__ import annotations

import sqlite3
from collections.abc import AsyncIterator, Iterable
from typing import Any, Callable, Optional, TYPE_CHECKING
from collections.abc import AsyncIterator, Callable, Iterable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from types import TracebackType

from .core import Connection


Expand All @@ -32,7 +36,7 @@ async def _execute(self, fn, *args, **kwargs):
return await self._conn._execute(fn, *args, **kwargs)

async def execute(
self, sql: str, parameters: Optional[Iterable[Any]] = None
self, sql: str, parameters: Iterable[Any] | None = None
) -> "Cursor":
"""Execute the given query."""
if parameters is None:
Expand All @@ -52,11 +56,11 @@ async def executescript(self, sql_script: str) -> "Cursor":
await self._execute(self._cursor.executescript, sql_script)
return self

async def fetchone(self) -> Optional[sqlite3.Row]:
async def fetchone(self) -> sqlite3.Row | None:
"""Fetch a single row."""
return await self._execute(self._cursor.fetchone)

async def fetchmany(self, size: Optional[int] = None) -> Iterable[sqlite3.Row]:
async def fetchmany(self, size: int | None = None) -> Iterable[sqlite3.Row]:
"""Fetch up to `cursor.arraysize` number of rows."""
args: tuple[int, ...] = ()
if size is not None:
Expand All @@ -76,7 +80,7 @@ def rowcount(self) -> int:
return self._cursor.rowcount

@property
def lastrowid(self) -> Optional[int]:
def lastrowid(self) -> int | None:
return self._cursor.lastrowid

@property
Expand All @@ -92,11 +96,11 @@ def description(self) -> tuple[tuple[str, None, None, None, None, None, None], .
return self._cursor.description

@property
def row_factory(self) -> Optional[Callable[[sqlite3.Cursor, sqlite3.Row], object]]:
def row_factory(self) -> Callable[[sqlite3.Cursor, sqlite3.Row], object] | None:
return self._cursor.row_factory

@row_factory.setter
def row_factory(self, factory: Optional[type]) -> None:
def row_factory(self, factory: type | None) -> None:
self._cursor.row_factory = factory

@property
Expand All @@ -106,5 +110,10 @@ def connection(self) -> sqlite3.Connection:
async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: "TracebackType" | None
):
await self.close()
2 changes: 1 addition & 1 deletion aiosqlite/tests/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import string
import tempfile
import time

from unittest import IsolatedAsyncioTestCase as TestCase

import aiosqlite

from .smoke import setup_logger

TEST_DB = ":memory:"
Expand Down
1 change: 1 addition & 0 deletions aiosqlite/tests/smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unittest.mock import patch

import aiosqlite

from .helpers import setup_logger


Expand Down