Skip to content

Commit c77ea60

Browse files
authored
Merge pull request #135 from ErwonB/main
fix: teradata adapter + resultset
2 parents 49da2d0 + eef679f commit c77ea60

1 file changed

Lines changed: 50 additions & 30 deletions

File tree

  • sqlit/domains/connections/providers/teradata

sqlit/domains/connections/providers/teradata/adapter.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from __future__ import annotations
44

5+
import re
56
from typing import TYPE_CHECKING, Any
67

78
from sqlit.domains.connections.providers.adapters.base import (
89
ColumnInfo,
910
CursorBasedAdapter,
1011
IndexInfo,
11-
SequenceInfo,
1212
TableInfo,
1313
TriggerInfo,
1414
)
@@ -49,9 +49,27 @@ def supports_cross_database_queries(self) -> bool:
4949
def supports_stored_procedures(self) -> bool:
5050
return True
5151

52-
@property
53-
def supports_sequences(self) -> bool:
54-
return True
52+
_TERADATA_SELECT_KEYWORDS = frozenset(
53+
{"SELECT", "SEL", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "HELP"}
54+
)
55+
56+
_LOCKING_RE = re.compile(
57+
r"\bFOR\s+(?:ACCESS|READ|WRITE|EXCLUSIVE)(?:\s+NOWAIT)?\s+(\w+)",
58+
re.IGNORECASE,
59+
)
60+
61+
def classify_query(self, query: str) -> bool:
62+
"""Classify Teradata queries, handling LOCKING/LOCK prefix and SEL abbreviation."""
63+
query_upper = query.strip().upper()
64+
first_word = query_upper.split()[0] if query_upper else ""
65+
66+
# Strip LOCKING/LOCK request modifier to find the actual statement keyword
67+
if first_word in ("LOCKING", "LOCK"):
68+
match = self._LOCKING_RE.search(query_upper)
69+
if match:
70+
first_word = match.group(1)
71+
72+
return first_word in self._TERADATA_SELECT_KEYWORDS
5573

5674
def apply_database_override(self, config: ConnectionConfig, database: str) -> ConnectionConfig:
5775
"""Apply a default database for unqualified queries."""
@@ -91,8 +109,9 @@ def connect(self, config: ConnectionConfig) -> Any:
91109
def get_databases(self, conn: Any) -> list[str]:
92110
cursor = conn.cursor()
93111
cursor.execute(
112+
"lock row for access "
94113
"SELECT DatabaseName FROM DBC.DatabasesV "
95-
"WHERE DatabaseKind IN ('D', 'U') "
114+
"WHERE dbkind IN ('D', 'U') "
96115
"ORDER BY DatabaseName"
97116
)
98117
return [row[0] for row in cursor.fetchall()]
@@ -101,13 +120,15 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]:
101120
cursor = conn.cursor()
102121
if database:
103122
cursor.execute(
123+
"lock row for access "
104124
"SELECT DatabaseName, TableName FROM DBC.TablesV "
105125
"WHERE TableKind = 'T' AND DatabaseName = ? "
106126
"ORDER BY TableName",
107127
(database,),
108128
)
109129
else:
110130
cursor.execute(
131+
"lock row for access "
111132
"SELECT DatabaseName, TableName FROM DBC.TablesV "
112133
"WHERE TableKind = 'T' "
113134
"ORDER BY DatabaseName, TableName"
@@ -118,13 +139,15 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]:
118139
cursor = conn.cursor()
119140
if database:
120141
cursor.execute(
142+
"lock row for access "
121143
"SELECT DatabaseName, TableName FROM DBC.TablesV "
122144
"WHERE TableKind = 'V' AND DatabaseName = ? "
123145
"ORDER BY TableName",
124146
(database,),
125147
)
126148
else:
127149
cursor.execute(
150+
"lock row for access "
128151
"SELECT DatabaseName, TableName FROM DBC.TablesV "
129152
"WHERE TableKind = 'V' "
130153
"ORDER BY DatabaseName, TableName"
@@ -142,21 +165,21 @@ def get_columns(
142165
pk_columns: set[str] = set()
143166
try:
144167
cursor.execute(
145-
"SELECT ic.ColumnName "
146-
"FROM DBC.IndexConstraintsV c "
147-
"JOIN DBC.IndexColumnsV ic "
148-
" ON c.DatabaseName = ic.DatabaseName "
149-
" AND c.TableName = ic.TableName "
150-
" AND c.IndexNumber = ic.IndexNumber "
151-
"WHERE c.ConstraintType = 'P' "
152-
"AND c.DatabaseName = ? AND c.TableName = ?",
168+
"lock row for access "
169+
"select "
170+
"COLUMNNAME "
171+
"from DBC.INDICESV "
172+
"where DATABASENAME = ? "
173+
"and TABLENAME = ? "
174+
"and INDEXTYPE = 'P' ",
153175
(schema_name, table),
154176
)
155177
pk_columns = {row[0] for row in cursor.fetchall()}
156178
except Exception:
157179
pk_columns = set()
158180

159181
cursor.execute(
182+
"lock row for access "
160183
"SELECT ColumnName, ColumnType FROM DBC.ColumnsV "
161184
"WHERE DatabaseName = ? AND TableName = ? "
162185
"ORDER BY ColumnId",
@@ -171,13 +194,15 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]:
171194
cursor = conn.cursor()
172195
if database:
173196
cursor.execute(
197+
"lock row for access "
174198
"SELECT TableName FROM DBC.TablesV "
175199
"WHERE TableKind = 'P' AND DatabaseName = ? "
176200
"ORDER BY TableName",
177201
(database,),
178202
)
179203
else:
180204
cursor.execute(
205+
"lock row for access "
181206
"SELECT TableName FROM DBC.TablesV "
182207
"WHERE TableKind = 'P' "
183208
"ORDER BY TableName"
@@ -188,13 +213,15 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]
188213
cursor = conn.cursor()
189214
if database:
190215
cursor.execute(
216+
"lock row for access "
191217
"SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV "
192218
"WHERE DatabaseName = ? "
193219
"ORDER BY TableName, IndexName",
194220
(database,),
195221
)
196222
else:
197223
cursor.execute(
224+
"lock row for access "
198225
"SELECT IndexName, TableName, UniqueFlag FROM DBC.IndicesV "
199226
"ORDER BY DatabaseName, TableName, IndexName"
200227
)
@@ -207,33 +234,26 @@ def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerIn
207234
cursor = conn.cursor()
208235
if database:
209236
cursor.execute(
237+
"lock row for access "
210238
"SELECT TriggerName, TableName FROM DBC.TriggersV "
211239
"WHERE DatabaseName = ? "
212240
"ORDER BY TableName, TriggerName",
213241
(database,),
214242
)
215243
else:
216244
cursor.execute(
245+
"lock row for access "
217246
"SELECT TriggerName, TableName FROM DBC.TriggersV "
218247
"ORDER BY DatabaseName, TableName, TriggerName"
219248
)
220249
return [TriggerInfo(name=row[0], table_name=row[1]) for row in cursor.fetchall()]
221250

222-
def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]:
223-
cursor = conn.cursor()
224-
if database:
225-
cursor.execute(
226-
"SELECT SequenceName FROM DBC.SequencesV "
227-
"WHERE DatabaseName = ? "
228-
"ORDER BY SequenceName",
229-
(database,),
230-
)
231-
else:
232-
cursor.execute(
233-
"SELECT SequenceName FROM DBC.SequencesV "
234-
"ORDER BY DatabaseName, SequenceName"
235-
)
236-
return [SequenceInfo(name=row[0]) for row in cursor.fetchall()]
251+
def get_sequences(self, conn: Any, database: str | None = None) -> list[str]:
252+
"""Teradata does not support standalone sequences.
253+
254+
Auto-increment behaviour is provided by IDENTITY columns instead.
255+
"""
256+
return []
237257

238258
def quote_identifier(self, name: str) -> str:
239259
escaped = name.replace('"', '""')
@@ -242,5 +262,5 @@ def quote_identifier(self, name: str) -> str:
242262
def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str:
243263
schema_name = schema or database
244264
if schema_name:
245-
return f'SELECT TOP {limit} * FROM "{schema_name}"."{table}"'
246-
return f'SELECT TOP {limit} * FROM "{table}"'
265+
return f'lock row for access select top {limit} * from "{schema_name}"."{table}"'
266+
return f'lock row for access select top {limit} * from "{table}"'

0 commit comments

Comments
 (0)