Skip to content

Commit f05878f

Browse files
committed
feat: async stuff!
1 parent c57ed22 commit f05878f

File tree

6 files changed

+249
-121
lines changed

6 files changed

+249
-121
lines changed

libsql_sqlalchemy/__init__.py

+6-113
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,10 @@
1-
import os
2-
import urllib.parse
3-
4-
from sqlalchemy import util
51
from sqlalchemy.dialects import registry as _registry
6-
from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
7-
8-
from libsql_experimental import Connection
2+
from .aiolibsql import SQLiteDialect_aiolibsql
3+
from .libsql import SQLiteDialect_libsql
94

105
__version__ = "0.1.0-pre"
116

12-
_registry.register("sqlite.libsql", "libsql_sqlalchemy", "SQLiteDialect_libsql")
13-
14-
15-
def _build_connection_url(url, query, secure):
16-
# sorting of keys is for unit test support
17-
query_str = urllib.parse.urlencode(sorted(query.items()))
18-
19-
if not url.host:
20-
if query_str:
21-
return f"{url.database}?{query_str}"
22-
return url.database
23-
elif secure: # yes, pop to remove
24-
scheme = "https"
25-
else:
26-
scheme = "http"
27-
28-
if url.username and url.password:
29-
netloc = f"{url.username}:{url.password}@{url.host}"
30-
elif url.username:
31-
netloc = f"{url.username}@{url.host}"
32-
else:
33-
netloc = url.host
34-
35-
if url.port:
36-
netloc += f":{url.port}"
37-
38-
return urllib.parse.urlunsplit(
39-
(
40-
scheme,
41-
netloc,
42-
url.database or "",
43-
query_str,
44-
"", # fragment
45-
)
46-
)
47-
48-
49-
class SQLiteDialect_libsql(SQLiteDialect_pysqlite):
50-
driver = "libsql"
51-
# need to be set explicitly
52-
supports_statement_cache = SQLiteDialect_pysqlite.supports_statement_cache
53-
54-
@classmethod
55-
def import_dbapi(cls):
56-
import libsql_experimental as libsql
57-
58-
return libsql
59-
60-
def on_connect(self):
61-
import libsql_experimental as libsql
62-
63-
sqlite3_connect = super().on_connect()
64-
65-
def connect(conn):
66-
# LibSQL: there is no support for create_function()
67-
if isinstance(conn, Connection):
68-
return
69-
return sqlite3_connect(conn)
70-
71-
return connect
72-
73-
def create_connect_args(self, url):
74-
pysqlite_args = (
75-
("uri", bool),
76-
("timeout", float),
77-
("isolation_level", str),
78-
("detect_types", int),
79-
("check_same_thread", bool),
80-
("cached_statements", int),
81-
("secure", bool), # LibSQL extra, selects between ws and wss
82-
)
83-
opts = url.query
84-
libsql_opts = {}
85-
for key, type_ in pysqlite_args:
86-
util.coerce_kw_type(opts, key, type_, dest=libsql_opts)
87-
88-
if url.host:
89-
libsql_opts["uri"] = True
90-
91-
if libsql_opts.get("uri", False):
92-
uri_opts = dict(opts)
93-
# here, we are actually separating the parameters that go to
94-
# sqlite3/pysqlite vs. those that go the SQLite URI. What if
95-
# two names conflict? again, this seems to be not the case right
96-
# now, and in the case that new names are added to
97-
# either side which overlap, again the sqlite3/pysqlite parameters
98-
# can be passed through connect_args instead of in the URL.
99-
# If SQLite native URIs add a parameter like "timeout" that
100-
# we already have listed here for the python driver, then we need
101-
# to adjust for that here.
102-
for key, type_ in pysqlite_args:
103-
uri_opts.pop(key, None)
104-
105-
secure = libsql_opts.pop("secure", False)
106-
connect_url = _build_connection_url(url, uri_opts, secure)
107-
else:
108-
connect_url = url.database or ":memory:"
109-
if connect_url != ":memory:":
110-
connect_url = os.path.abspath(connect_url)
111-
112-
libsql_opts.setdefault("check_same_thread", not self._is_url_file_db(url))
113-
114-
return ([connect_url], libsql_opts)
115-
116-
117-
dialect = SQLiteDialect_libsql
7+
_registry.register("sqlite.libsql", "libsql_sqlalchemy.libsql", "SQLiteDialect_libsql")
8+
_registry.register(
9+
"sqlite.aiolibsql", "libsql_sqlalchemy.aiolibsql", "SQLiteDialect_aiolibsql"
10+
)

libsql_sqlalchemy/aiolibsql.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from libsql_sqlalchemy.libsql import SQLiteDialect_libsql
2+
3+
4+
class SQLiteDialect_aiolibsql(SQLiteDialect_libsql):
5+
driver = "libsql"
6+
supports_statement_cache = SQLiteDialect_libsql.supports_statement_cache
7+
is_async = True
8+
9+
10+
dialect = SQLiteDialect_aiolibsql

libsql_sqlalchemy/libsql.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import os
2+
import urllib.parse
3+
4+
from sqlalchemy import util
5+
from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
6+
from libsql_experimental import Connection
7+
8+
9+
def _build_connection_url(url, query, secure):
10+
# sorting of keys is for unit test support
11+
query_str = urllib.parse.urlencode(sorted(query.items()))
12+
13+
if not url.host:
14+
if query_str:
15+
return f"{url.database}?{query_str}"
16+
return url.database
17+
elif secure: # yes, pop to remove
18+
scheme = "https"
19+
else:
20+
scheme = "http"
21+
22+
if url.username and url.password:
23+
netloc = f"{url.username}:{url.password}@{url.host}"
24+
elif url.username:
25+
netloc = f"{url.username}@{url.host}"
26+
else:
27+
netloc = url.host
28+
29+
if url.port:
30+
netloc += f":{url.port}"
31+
32+
return urllib.parse.urlunsplit(
33+
(
34+
scheme,
35+
netloc,
36+
url.database or "",
37+
query_str,
38+
"", # fragment
39+
)
40+
)
41+
42+
43+
class SQLiteDialect_libsql(SQLiteDialect_pysqlite):
44+
driver = "libsql"
45+
# need to be set explicitly
46+
supports_statement_cache = SQLiteDialect_pysqlite.supports_statement_cache
47+
48+
@classmethod
49+
def import_dbapi(cls):
50+
import libsql_experimental as libsql
51+
52+
return libsql
53+
54+
def on_connect(self):
55+
import libsql_experimental as libsql
56+
57+
sqlite3_connect = super().on_connect()
58+
59+
def connect(conn):
60+
# LibSQL: there is no support for create_function()
61+
if isinstance(conn, Connection):
62+
return
63+
return sqlite3_connect(conn)
64+
65+
return connect
66+
67+
def create_connect_args(self, url):
68+
pysqlite_args = (
69+
("uri", bool),
70+
("timeout", float),
71+
("isolation_level", str),
72+
("detect_types", int),
73+
("check_same_thread", bool),
74+
("cached_statements", int),
75+
("secure", bool), # LibSQL extra, selects between ws and wss
76+
)
77+
opts = url.query
78+
libsql_opts = {}
79+
for key, type_ in pysqlite_args:
80+
util.coerce_kw_type(opts, key, type_, dest=libsql_opts)
81+
82+
if url.host:
83+
libsql_opts["uri"] = True
84+
85+
if libsql_opts.get("uri", False):
86+
uri_opts = dict(opts)
87+
# here, we are actually separating the parameters that go to
88+
# sqlite3/pysqlite vs. those that go the SQLite URI. What if
89+
# two names conflict? again, this seems to be not the case right
90+
# now, and in the case that new names are added to
91+
# either side which overlap, again the sqlite3/pysqlite parameters
92+
# can be passed through connect_args instead of in the URL.
93+
# If SQLite native URIs add a parameter like "timeout" that
94+
# we already have listed here for the python driver, then we need
95+
# to adjust for that here.
96+
for key, type_ in pysqlite_args:
97+
uri_opts.pop(key, None)
98+
99+
secure = libsql_opts.pop("secure", False)
100+
connect_url = _build_connection_url(url, uri_opts, secure)
101+
else:
102+
connect_url = url.database or ":memory:"
103+
if connect_url != ":memory:":
104+
connect_url = os.path.abspath(connect_url)
105+
106+
libsql_opts.setdefault("check_same_thread", not self._is_url_file_db(url))
107+
108+
return ([connect_url], libsql_opts)
109+
110+
111+
dialect = SQLiteDialect_libsql

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ classifiers = [
2020
dependencies = [
2121
"libsql-experimental>=0.0.47",
2222
"sqlalchemy>=2.0.0",
23+
"greenlet>=3.0.3",
2324
]
2425

2526
[project.optional-dependencies]
2627
dev = [
2728
"pytest>=8.3.5",
2829
"pytest-cov>=4.1.0",
30+
"pytest-asyncio>=0.23.5",
2931
]
3032

3133
[project.urls]
@@ -35,4 +37,5 @@ Repository = "https://github.com/tursodatabase/libsql-sqlalchemy.git"
3537
Issues = "https://github.com/tursodatabase/libsql-sqlalchemy/issues"
3638

3739
[project.entry-points."sqlalchemy.dialects"]
38-
"sqlite.libsql" = "libsql_sqlalchemy:SQLiteDialect_libsql"
40+
"sqlite.libsql" = "libsql_sqlalchemy:SQLiteDialect_libsql"
41+
"sqlite.aiolibsql" = "libsql_sqlalchemy:SQLiteDialect_aiolibsql"

tests/test_async.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pytest
2+
import pytest_asyncio
3+
import asyncio
4+
from sqlalchemy import Column, Integer, String, text, select
5+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
6+
from sqlalchemy.orm import declarative_base
7+
from sqlalchemy.pool import AsyncAdaptedQueuePool
8+
9+
Base = declarative_base()
10+
11+
12+
class User(Base):
13+
__tablename__ = "users"
14+
15+
id = Column(Integer, primary_key=True)
16+
name = Column(String)
17+
email = Column(String)
18+
19+
20+
@pytest_asyncio.fixture
21+
async def engine():
22+
engine = create_async_engine(
23+
"sqlite+aiolibsql://",
24+
poolclass=AsyncAdaptedQueuePool,
25+
)
26+
yield engine
27+
await engine.dispose()
28+
29+
30+
@pytest_asyncio.fixture
31+
async def session(engine):
32+
async with engine.begin() as conn:
33+
await conn.run_sync(Base.metadata.create_all)
34+
35+
async_session = async_sessionmaker(
36+
engine, class_=AsyncSession, expire_on_commit=False
37+
)
38+
async with async_session() as session:
39+
yield session
40+
41+
async with engine.begin() as conn:
42+
await conn.run_sync(Base.metadata.drop_all)
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_connection(session: AsyncSession):
47+
result = await session.execute(text("SELECT 1"))
48+
assert result.scalar() == 1
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_create_table(session):
53+
result = await session.execute(
54+
text("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
55+
)
56+
assert result.scalar() == "users"
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_insert_and_query(session):
61+
user = User(name="Test User", email="[email protected]")
62+
session.add(user)
63+
await session.commit()
64+
65+
stmt = select(User)
66+
result = await session.execute(stmt)
67+
queried_user = result.scalars().first()
68+
assert queried_user.name == "Test User"
69+
assert queried_user.email == "[email protected]"
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_update(session):
74+
user = User(name="Test User", email="[email protected]")
75+
session.add(user)
76+
await session.commit()
77+
78+
user.name = "Updated User"
79+
await session.commit()
80+
81+
stmt = select(User)
82+
result = await session.execute(stmt)
83+
updated_user = result.scalars().first()
84+
assert updated_user.name == "Updated User"
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_delete(session):
89+
user = User(name="Test User", email="[email protected]")
90+
session.add(user)
91+
await session.commit()
92+
93+
await session.delete(user)
94+
await session.commit()
95+
96+
stmt = select(User)
97+
result = await session.execute(stmt)
98+
assert result.scalars().first() is None

0 commit comments

Comments
 (0)