Skip to content

Commit 68ae1fb

Browse files
authored
Basic Support for Sequences (#69)
* Basic Support for Sequences
* Fixups for latest SQLAlchemy
1 parent 752e170 commit 68ae1fb

File tree

2 files changed

+165
-44
lines changed

2 files changed

+165
-44
lines changed

databend_sqlalchemy/databend_dialect.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import sqlalchemy.engine.reflection
3333
import sqlalchemy.types as sqltypes
34-
from typing import Any, Dict, Optional, Union
34+
from typing import Any, Dict, Optional, Union, List
3535
from sqlalchemy import util as sa_util
3636
from sqlalchemy.engine import reflection
3737
from sqlalchemy.sql import (
@@ -703,7 +703,6 @@ def __init__(self, key_type, value_type):
703703
super(MAP, self).__init__()
704704

705705

706-
707706
class DatabendDate(sqltypes.DATE):
708707
__visit_name__ = "DATE"
709708

@@ -857,7 +856,6 @@ class DatabendGeography(GEOGRAPHY):
857856
}
858857

859858

860-
861859
# Column spec
862860
colspecs = {
863861
sqltypes.Interval: DatabendInterval,
@@ -872,6 +870,12 @@ class DatabendGeography(GEOGRAPHY):
872870
class DatabendIdentifierPreparer(PGIdentifierPreparer):
873871
reserved_words = {r.lower() for r in RESERVED_WORDS}
874872

873+
# overridden to exclude schema from sequence
874+
def format_sequence(self, sequence, use_schema: bool = True) -> str:
    """Return the formatted name of ``sequence`` without a schema prefix.

    ``use_schema`` is accepted so the signature stays compatible with the
    parent preparer, but its value is never consulted: the parent
    implementation is always invoked with ``use_schema=False``.

    :param sequence: the sequence object to format; passed through
        unchanged to the parent ``format_sequence``.
    :param use_schema: ignored (see above).
    :return: whatever the parent ``format_sequence`` returns for a
        schema-less rendering.
    """
    parent_format = super().format_sequence
    # Force the schema-less form regardless of what the caller asked for.
    return parent_format(sequence, use_schema=False)
878+
875879

876880
class DatabendCompiler(PGCompiler):
877881
iscopyintotable: bool = False
@@ -1230,6 +1234,15 @@ def copy_into_table_results(self) -> list[dict]:
12301234
def copy_into_location_results(self) -> dict:
12311235
return self._copy_into_location_results
12321236

1237+
def fire_sequence(self, seq, type_):
    """Build a ``select nextval(...)`` statement for ``seq`` and run it.

    The sequence name is rendered by
    ``self.identifier_preparer.format_sequence`` and interpolated
    directly into the statement text (it is not sent as a bound
    parameter).

    :param seq: the sequence to advance; handed as-is to the identifier
        preparer.
    :param type_: forwarded unchanged as the second argument of
        ``self._execute_scalar``.
    :return: whatever ``self._execute_scalar`` returns for the statement.
    """
    sequence_name = self.identifier_preparer.format_sequence(seq)
    statement = "select nextval(%s)" % sequence_name
    return self._execute_scalar(statement, type_)
1245+
12331246

12341247
class DatabendTypeCompiler(compiler.GenericTypeCompiler):
12351248
def visit_ARRAY(self, type_, **kw):
@@ -1280,7 +1293,6 @@ def visit_GEOGRAPHY(self, type_, **kw):
12801293
return "GEOGRAPHY"
12811294

12821295

1283-
12841296
class DatabendDDLCompiler(compiler.DDLCompiler):
12851297
def visit_primary_key_constraint(self, constraint, **kw):
12861298
return ""
@@ -1394,6 +1406,7 @@ class DatabendDialect(default.DefaultDialect):
13941406
supports_empty_insert = False
13951407
supports_is_distinct_from = True
13961408
supports_multivalues_insert = True
1409+
supports_sequences = True
13971410

13981411
supports_statement_cache = False
13991412
supports_server_side_cursors = True
@@ -1478,24 +1491,24 @@ def _get_default_schema_name(self, connection):
14781491
def get_schema_names(self, connection, **kw):
14791492
return [row[0] for row in connection.execute(text("SHOW DATABASES"))]
14801493

1481-
def _get_table_columns(self, connection, table_name, schema):
1482-
if schema is None:
1483-
schema = self.default_schema_name
1484-
quote_table_name = self.identifier_preparer.quote_identifier(table_name)
1485-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1486-
1487-
return connection.execute(
1488-
text(f"DESC {quote_schema}.{quote_table_name}")
1489-
).fetchall()
1490-
14911494
@reflection.cache
14921495
def has_table(self, connection, table_name, schema=None, **kw):
1496+
table_name_query = """
1497+
select case when exists(
1498+
select table_name
1499+
from information_schema.tables
1500+
where table_schema = :schema_name
1501+
and table_name = :table_name
1502+
) then 1 else 0 end
1503+
"""
1504+
query = text(table_name_query).bindparams(
1505+
bindparam("schema_name", type_=sqltypes.Unicode),
1506+
bindparam("table_name", type_=sqltypes.Unicode),
1507+
)
14931508
if schema is None:
14941509
schema = self.default_schema_name
1495-
quote_table_name = self.identifier_preparer.quote_identifier(table_name)
1496-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1497-
query = f"""EXISTS TABLE {quote_schema}.{quote_table_name}"""
1498-
r = connection.scalar(text(query))
1510+
1511+
r = connection.scalar(query, dict(schema_name=schema, table_name=table_name))
14991512
if r == 1:
15001513
return True
15011514
return False
@@ -1537,21 +1550,26 @@ def get_columns(self, connection, table_name, schema=None, **kw):
15371550
def get_view_definition(self, connection, view_name, schema=None, **kw):
15381551
if schema is None:
15391552
schema = self.default_schema_name
1540-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1541-
quote_view_name = self.identifier_preparer.quote_identifier(view_name)
1542-
full_view_name = f"{quote_schema}.{quote_view_name}"
1543-
1544-
# ToDo : perhaps can be removed if we get `SHOW CREATE VIEW`
1545-
if view_name not in self.get_view_names(connection, schema):
1546-
raise NoSuchTableError(full_view_name)
1547-
1548-
query = f"""SHOW CREATE TABLE {full_view_name}"""
1549-
try:
1550-
view_def = connection.execute(text(query)).first()
1551-
return view_def[1]
1552-
except DBAPIError as e:
1553-
if "1025" in e.orig.message: # ToDo: The errors need parsing properly
1554-
raise NoSuchTableError(full_view_name) from e
1553+
query = text(
1554+
"""
1555+
select view_query
1556+
from system.views
1557+
where name = :view_name
1558+
and database = :schema_name
1559+
"""
1560+
).bindparams(
1561+
bindparam("view_name", type_=sqltypes.UnicodeText),
1562+
bindparam("schema_name", type_=sqltypes.Unicode),
1563+
)
1564+
r = connection.scalar(
1565+
query, dict(view_name=view_name, schema_name=schema)
1566+
)
1567+
if not r:
1568+
raise NoSuchTableError(
1569+
f"{self.identifier_preparer.quote_identifier(schema)}."
1570+
f"{self.identifier_preparer.quote_identifier(view_name)}"
1571+
)
1572+
return r
15551573

15561574
def _get_column_type(self, column_type):
15571575
pattern = r"(?:Nullable)*(?:\()*(\w+)(?:\((.*?)\))?(?:\))*"
@@ -1621,7 +1639,6 @@ def get_temp_table_names(self, connection, schema=None, **kw):
16211639
result = connection.execute(query, dict(schema_name=schema))
16221640
return [row[0] for row in result]
16231641

1624-
16251642
@reflection.cache
16261643
def get_view_names(self, connection, schema=None, **kw):
16271644
view_name_query = """
@@ -1762,7 +1779,6 @@ def get_multi_table_comment(
17621779
schema='system',
17631780
).alias("a_tab_comments")
17641781

1765-
17661782
has_filter_names, params = self._prepare_filter_names(filter_names)
17671783
owner = schema or self.default_schema_name
17681784

@@ -1804,6 +1820,20 @@ def _check_unicode_description(self, connection):
18041820
# We decode everything as UTF-8
18051821
return True
18061822

1823+
@reflection.cache
1824+
def get_sequence_names(self, connection, schema: Optional[str] = None, **kw: Any) -> List[str]:
1825+
# N.B. sequences are not defined per schema/database
1826+
sequence_query = """
1827+
show sequences
1828+
"""
1829+
query = text(sequence_query)
1830+
result = connection.execute(query)
1831+
return [row[0] for row in result]
1832+
1833+
def has_sequence(
    self, connection, sequence_name: str, schema: Optional[str] = None, **kw: Any
) -> bool:
    """Report whether ``sequence_name`` is among the known sequence names.

    The lookup is delegated to ``self.get_sequence_names``, which receives
    ``connection``, ``schema`` (positionally) and any extra keyword
    arguments unchanged.

    :param connection: forwarded to ``get_sequence_names``.
    :param sequence_name: the name to look for (exact membership test).
    :param schema: forwarded to ``get_sequence_names``.
    :return: ``True`` if the name is present, ``False`` otherwise.
    """
    # N.B. sequences are not defined per schema/database
    known_names = self.get_sequence_names(connection, schema, **kw)
    return sequence_name in known_names
1836+
18071837

18081838
dialect = DatabendDialect
18091839

tests/test_sqlalchemy.py

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest
1414
from sqlalchemy.testing.suite import QuotedNameArgumentTest as _QuotedNameArgumentTest
1515
from sqlalchemy.testing.suite import JoinTest as _JoinTest
16+
from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest
1617

1718
from sqlalchemy.testing.suite import ServerSideCursorsTest as _ServerSideCursorsTest
1819

@@ -21,7 +22,7 @@
2122
from sqlalchemy.testing.suite import IntegerTest as _IntegerTest
2223

2324
from sqlalchemy import types as sql_types
24-
from sqlalchemy.testing import config
25+
from sqlalchemy.testing import config, skip_test
2526
from sqlalchemy import testing, Table, Column, Integer
2627
from sqlalchemy.testing import eq_, fixtures, assertions
2728

@@ -30,7 +31,8 @@
3031
from packaging import version
3132
import sqlalchemy
3233
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
33-
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
34+
if version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
35+
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
3436
from sqlalchemy.testing.suite import EnumTest as _EnumTest
3537
else:
3638
from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest
@@ -42,14 +44,36 @@ def test_get_indexes(self):
4244
pass
4345

4446
class ComponentReflectionTestExtra(_ComponentReflectionTestExtra):
45-
47+
@testing.skip("databend") #ToDo No length in Databend
4648
@testing.requires.table_reflection
4749
def test_varchar_reflection(self, connection, metadata):
4850
typ = self._type_round_trip(
4951
connection, metadata, sql_types.String(52)
5052
)[0]
5153
assert isinstance(typ, sql_types.String)
52-
# eq_(typ.length, 52) # No length in Databend
54+
eq_(typ.length, 52)
55+
56+
@testing.skip("databend")  # ToDo No length in Databend
@testing.requires.table_reflection
@testing.combinations(
    sql_types.String,
    sql_types.VARCHAR,
    sql_types.CHAR,
    (sql_types.NVARCHAR, testing.requires.nvarchar_types),
    (sql_types.NCHAR, testing.requires.nvarchar_types),
    argnames="type_",
)
def test_string_length_reflection(self, connection, metadata, type_):
    """Round-trip ``type_(52)`` and check the reflected class and length.

    The reflected type must be an instance of ``VARCHAR`` when ``type_``
    derives from ``VARCHAR``, of ``CHAR`` when it derives from ``CHAR``,
    and of ``String`` otherwise; its ``length`` must be the integer 52.
    """
    reflected = self._type_round_trip(connection, metadata, type_(52))[0]

    # VARCHAR is tested before CHAR; String is the fallback.
    expected_base = next(
        (
            base
            for base in (sql_types.VARCHAR, sql_types.CHAR)
            if issubclass(type_, base)
        ),
        sql_types.String,
    )
    assert isinstance(reflected, expected_base)

    eq_(reflected.length, 52)
    assert isinstance(reflected.length, int)
5377

5478

5579
class BooleanTest(_BooleanTest):
@@ -204,7 +228,7 @@ def test_get_indexes(self, name):
204228
class JoinTest(_JoinTest):
205229
__requires__ = ("foreign_keys",)
206230

207-
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
231+
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0') and version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
208232
class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
209233
__requires__ = ("foreign_keys",)
210234

@@ -586,9 +610,6 @@ def test_geometry_write_and_read(self, connection):
586610
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))
587611

588612

589-
590-
591-
592613
class GeographyTest(fixtures.TablesTest):
593614

594615
@classmethod
@@ -664,4 +685,74 @@ def test_geography_write_and_read(self, connection):
664685
result = connection.execute(
665686
select(geography_table.c.geography_data).where(geography_table.c.id == 7)
666687
).scalar()
667-
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))
688+
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))
689+
690+
691+
class HasSequenceTest(_HasSequenceTest):
692+
693+
# ToDo - overridden other_seq definition due to lack of sequence ddl support for nominvalue nomaxvalue
694+
@classmethod
695+
def define_tables(cls, metadata):
696+
normalize_sequence(config, Sequence("user_id_seq", metadata=metadata))
697+
normalize_sequence(
698+
config,
699+
Sequence(
700+
"other_seq",
701+
metadata=metadata,
702+
# nomaxvalue=True,
703+
# nominvalue=True,
704+
),
705+
)
706+
if testing.requires.schemas.enabled:
707+
#ToDo - omitted because Databend does not allow schema on sequence
708+
# normalize_sequence(
709+
# config,
710+
# Sequence(
711+
# "user_id_seq", schema=config.test_schema, metadata=metadata
712+
# ),
713+
# )
714+
normalize_sequence(
715+
config,
716+
Sequence(
717+
"schema_seq", schema=config.test_schema, metadata=metadata
718+
),
719+
)
720+
Table(
721+
"user_id_table",
722+
metadata,
723+
Column("id", Integer, primary_key=True),
724+
)
725+
726+
@testing.skip("databend") # ToDo - requires definition of sequences with schema
727+
def test_has_sequence_remote_not_in_default(self, connection):
728+
eq_(inspect(connection).has_sequence("schema_seq"), False)
729+
730+
@testing.skip("databend") # ToDo - requires definition of sequences with schema
731+
def test_get_sequence_names(self, connection):
732+
exp = {"other_seq", "user_id_seq"}
733+
734+
res = set(inspect(connection).get_sequence_names())
735+
is_true(res.intersection(exp) == exp)
736+
is_true("schema_seq" not in res)
737+
738+
@testing.skip("databend") # ToDo - requires definition of sequences with schema
739+
@testing.requires.schemas
740+
def test_get_sequence_names_no_sequence_schema(self, connection):
741+
eq_(
742+
inspect(connection).get_sequence_names(
743+
schema=config.test_schema_2
744+
),
745+
[],
746+
)
747+
748+
@testing.skip("databend") # ToDo - requires definition of sequences with schema
749+
@testing.requires.schemas
750+
def test_get_sequence_names_sequences_schema(self, connection):
751+
eq_(
752+
sorted(
753+
inspect(connection).get_sequence_names(
754+
schema=config.test_schema
755+
)
756+
),
757+
["schema_seq", "user_id_seq"],
758+
)

0 commit comments

Comments
 (0)