3131
3232import sqlalchemy .engine .reflection
3333import sqlalchemy .types as sqltypes
34- from typing import Any , Dict , Optional , Union
34+ from typing import Any , Dict , Optional , Union , List
3535from sqlalchemy import util as sa_util
3636from sqlalchemy .engine import reflection
3737from sqlalchemy .sql import (
@@ -703,7 +703,6 @@ def __init__(self, key_type, value_type):
703703 super (MAP , self ).__init__ ()
704704
705705
706-
707706class DatabendDate (sqltypes .DATE ):
708707 __visit_name__ = "DATE"
709708
@@ -857,7 +856,6 @@ class DatabendGeography(GEOGRAPHY):
857856}
858857
859858
860-
861859# Column spec
862860colspecs = {
863861 sqltypes .Interval : DatabendInterval ,
@@ -872,6 +870,12 @@ class DatabendGeography(GEOGRAPHY):
872870class DatabendIdentifierPreparer (PGIdentifierPreparer ):
873871 reserved_words = {r .lower () for r in RESERVED_WORDS }
874872
873+ # overridden to exclude schema from sequence
874+ def format_sequence (
875+ self , sequence , use_schema : bool = True
876+ ) -> str :
877+ return super ().format_sequence (sequence , use_schema = False )
878+
875879
876880class DatabendCompiler (PGCompiler ):
877881 iscopyintotable : bool = False
@@ -1230,6 +1234,15 @@ def copy_into_table_results(self) -> list[dict]:
12301234 def copy_into_location_results (self ) -> dict :
12311235 return self ._copy_into_location_results
12321236
1237+ def fire_sequence (self , seq , type_ ):
1238+ return self ._execute_scalar (
1239+ (
1240+ "select nextval(%s)"
1241+ % self .identifier_preparer .format_sequence (seq )
1242+ ),
1243+ type_ ,
1244+ )
1245+
12331246
12341247class DatabendTypeCompiler (compiler .GenericTypeCompiler ):
12351248 def visit_ARRAY (self , type_ , ** kw ):
@@ -1280,7 +1293,6 @@ def visit_GEOGRAPHY(self, type_, **kw):
12801293 return "GEOGRAPHY"
12811294
12821295
1283-
12841296class DatabendDDLCompiler (compiler .DDLCompiler ):
12851297 def visit_primary_key_constraint (self , constraint , ** kw ):
12861298 return ""
@@ -1394,6 +1406,7 @@ class DatabendDialect(default.DefaultDialect):
13941406 supports_empty_insert = False
13951407 supports_is_distinct_from = True
13961408 supports_multivalues_insert = True
1409+ supports_sequences = True
13971410
13981411 supports_statement_cache = False
13991412 supports_server_side_cursors = True
@@ -1478,24 +1491,24 @@ def _get_default_schema_name(self, connection):
14781491 def get_schema_names (self , connection , ** kw ):
14791492 return [row [0 ] for row in connection .execute (text ("SHOW DATABASES" ))]
14801493
1481- def _get_table_columns (self , connection , table_name , schema ):
1482- if schema is None :
1483- schema = self .default_schema_name
1484- quote_table_name = self .identifier_preparer .quote_identifier (table_name )
1485- quote_schema = self .identifier_preparer .quote_identifier (schema )
1486-
1487- return connection .execute (
1488- text (f"DESC { quote_schema } .{ quote_table_name } " )
1489- ).fetchall ()
1490-
14911494 @reflection .cache
14921495 def has_table (self , connection , table_name , schema = None , ** kw ):
1496+ table_name_query = """
1497+ select case when exists(
1498+ select table_name
1499+ from information_schema.tables
1500+ where table_schema = :schema_name
1501+ and table_name = :table_name
1502+ ) then 1 else 0 end
1503+ """
1504+ query = text (table_name_query ).bindparams (
1505+ bindparam ("schema_name" , type_ = sqltypes .Unicode ),
1506+ bindparam ("table_name" , type_ = sqltypes .Unicode ),
1507+ )
14931508 if schema is None :
14941509 schema = self .default_schema_name
1495- quote_table_name = self .identifier_preparer .quote_identifier (table_name )
1496- quote_schema = self .identifier_preparer .quote_identifier (schema )
1497- query = f"""EXISTS TABLE { quote_schema } .{ quote_table_name } """
1498- r = connection .scalar (text (query ))
1510+
1511+ r = connection .scalar (query , dict (schema_name = schema , table_name = table_name ))
14991512 if r == 1 :
15001513 return True
15011514 return False
@@ -1537,21 +1550,26 @@ def get_columns(self, connection, table_name, schema=None, **kw):
15371550 def get_view_definition (self , connection , view_name , schema = None , ** kw ):
15381551 if schema is None :
15391552 schema = self .default_schema_name
1540- quote_schema = self .identifier_preparer .quote_identifier (schema )
1541- quote_view_name = self .identifier_preparer .quote_identifier (view_name )
1542- full_view_name = f"{ quote_schema } .{ quote_view_name } "
1543-
1544- # ToDo : perhaps can be removed if we get `SHOW CREATE VIEW`
1545- if view_name not in self .get_view_names (connection , schema ):
1546- raise NoSuchTableError (full_view_name )
1547-
1548- query = f"""SHOW CREATE TABLE { full_view_name } """
1549- try :
1550- view_def = connection .execute (text (query )).first ()
1551- return view_def [1 ]
1552- except DBAPIError as e :
1553- if "1025" in e .orig .message : # ToDo: The errors need parsing properly
1554- raise NoSuchTableError (full_view_name ) from e
1553+ query = text (
1554+ """
1555+ select view_query
1556+ from system.views
1557+ where name = :view_name
1558+ and database = :schema_name
1559+ """
1560+ ).bindparams (
1561+ bindparam ("view_name" , type_ = sqltypes .UnicodeText ),
1562+ bindparam ("schema_name" , type_ = sqltypes .Unicode ),
1563+ )
1564+ r = connection .scalar (
1565+ query , dict (view_name = view_name , schema_name = schema )
1566+ )
1567+ if not r :
1568+ raise NoSuchTableError (
1569+ f"{ self .identifier_preparer .quote_identifier (schema )} ."
1570+ f"{ self .identifier_preparer .quote_identifier (view_name )} "
1571+ )
1572+ return r
15551573
15561574 def _get_column_type (self , column_type ):
15571575 pattern = r"(?:Nullable)*(?:\()*(\w+)(?:\((.*?)\))?(?:\))*"
@@ -1621,7 +1639,6 @@ def get_temp_table_names(self, connection, schema=None, **kw):
16211639 result = connection .execute (query , dict (schema_name = schema ))
16221640 return [row [0 ] for row in result ]
16231641
1624-
16251642 @reflection .cache
16261643 def get_view_names (self , connection , schema = None , ** kw ):
16271644 view_name_query = """
@@ -1762,7 +1779,6 @@ def get_multi_table_comment(
17621779 schema = 'system' ,
17631780 ).alias ("a_tab_comments" )
17641781
1765-
17661782 has_filter_names , params = self ._prepare_filter_names (filter_names )
17671783 owner = schema or self .default_schema_name
17681784
@@ -1804,6 +1820,20 @@ def _check_unicode_description(self, connection):
18041820 # We decode everything as UTF-8
18051821 return True
18061822
1823+ @reflection .cache
1824+ def get_sequence_names (self , connection , schema : Optional [str ] = None , ** kw : Any ) -> List [str ]:
1825+ # N.B. sequences are not defined per schema/database
1826+ sequence_query = """
1827+ show sequences
1828+ """
1829+ query = text (sequence_query )
1830+ result = connection .execute (query )
1831+ return [row [0 ] for row in result ]
1832+
1833+ def has_sequence (self , connection , sequence_name : str , schema : Optional [str ] = None , ** kw : Any ) -> bool :
1834+ # N.B. sequences are not defined per schema/database
1835+ return sequence_name in self .get_sequence_names (connection , schema , ** kw )
1836+
18071837
18081838dialect = DatabendDialect
18091839
0 commit comments