Skip to content

Commit a06c48c

Browse files
committed
feat(dtype): support compiling dtypes to sql
1 parent a92c3cb commit a06c48c

File tree

29 files changed

+183
-43
lines changed

29 files changed

+183
-43
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import operator
88
import string
99
from functools import partial, reduce
10-
from typing import TYPE_CHECKING, Any, Callable, ClassVar
10+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload
1111

1212
import sqlglot as sg
1313
import sqlglot.expressions as sge
@@ -17,6 +17,7 @@
1717
import ibis.common.patterns as pats
1818
import ibis.expr.datatypes as dt
1919
import ibis.expr.operations as ops
20+
import ibis.expr.schema as sch
2021
from ibis.backends.sql.rewrites import (
2122
FirstValue,
2223
LastValue,
@@ -53,7 +54,6 @@ def AlterTable(*args, kind="TABLE", **kwargs):
5354
if TYPE_CHECKING:
5455
from collections.abc import Iterable, Mapping
5556

56-
import ibis.expr.schema as sch
5757
import ibis.expr.types as ir
5858
from ibis.backends.sql.datatypes import SqlglotType
5959

@@ -575,16 +575,52 @@ def _prepare_params(self, params):
575575
result[node] = value
576576
return result
577577

578+
@overload
578579
def to_sqlglot(
579580
self,
580-
expr: ir.Expr,
581+
x: dt.DataType,
582+
*,
583+
params: Mapping[ir.Expr, Any] | None = None,
584+
) -> sge.DataType: ...
585+
586+
@overload
587+
def to_sqlglot(
588+
self,
589+
x: sch.Schema,
590+
*,
591+
params: Mapping[ir.Expr, Any] | None = None,
592+
) -> list[sge.ColumnDef]: ...
593+
594+
@overload
595+
def to_sqlglot(
596+
self,
597+
x: ir.Expr,
581598
*,
582599
limit: str | None = None,
583600
params: Mapping[ir.Expr, Any] | None = None,
584-
):
601+
) -> sge.Expression: ...
602+
603+
def to_sqlglot(
    self,
    x: ir.Expr | dt.DataType | sch.Schema,
    *,
    limit: str | None = None,
    params: Mapping[ir.Expr, Any] | None = None,
) -> sge.Expression | list[sge.ColumnDef]:
    """Compile an ibis expression, datatype, or schema to sqlglot.

    Parameters
    ----------
    x
        The object to compile.
    limit
        Forwarded to `_to_sqlglot_expr` when `x` is an expression.
        Not used when `x` is a datatype or a schema.
    params
        Forwarded to `_to_sqlglot_expr` when `x` is an expression.
        Not used when `x` is a datatype or a schema.

    Returns
    -------
    sge.Expression | list[sge.ColumnDef]
        A `sge.DataType` for a datatype, a list of `sge.ColumnDef` for a
        schema, and a `sge.Expression` for an expression, matching the
        overloads declared for this method.
    """
    if isinstance(x, dt.DataType | sch.Schema):
        # Datatypes and schemas convert themselves; only the dialect of
        # this compiler is needed, so `limit` and `params` are not passed.
        return x.to_sqlglot(self.dialect)
    # Everything else is treated as an expression; backends customize this
    # path by overriding `_to_sqlglot_expr`.
    return self._to_sqlglot_expr(x, limit=limit, params=params)
613+
614+
def _to_sqlglot_expr(
615+
self,
616+
x: ir.Expr,
617+
*,
618+
limit: str | None = None,
619+
params: Mapping[ir.Expr, Any] | None = None,
620+
) -> sge.Expression:
585621
import ibis
586622

587-
table_expr = expr.as_table()
623+
table_expr = x.as_table()
588624

589625
if limit == "default":
590626
limit = ibis.options.sql.default_limit

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ibis.common.exceptions as com
1515
import ibis.expr.datatypes as dt
1616
import ibis.expr.operations as ops
17+
import ibis.expr.schema as sch
1718
from ibis import util
1819
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
1920
from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator
@@ -206,7 +207,7 @@ class BigQueryCompiler(SQLGlotCompiler):
206207

207208
def to_sqlglot(
208209
self,
209-
expr: ir.Expr,
210+
expr: ir.Expr | dt.DataType | sch.Schema,
210211
*,
211212
limit: str | None = None,
212213
params: Mapping[ir.Expr, Any] | None = None,
@@ -234,17 +235,18 @@ def to_sqlglot(
234235
Any
235236
The output of compilation. The type of this value depends on the
236237
backend.
237-
238238
"""
239-
sql = super().to_sqlglot(expr, limit=limit, params=params)
239+
if isinstance(expr, dt.DataType | sch.Schema):
240+
return expr.to_sqlglot(self.dialect)
241+
sgexpr = super().to_sqlglot(expr, limit=limit, params=params)
240242

241243
table_expr = expr.as_table()
242244

243245
memtable_names = frozenset(
244246
op.name for op in table_expr.op().find(ops.InMemoryTable)
245247
)
246248

247-
result = sql.transform(
249+
result = sgexpr.transform(
248250
_qualify_memtable,
249251
dataset=session_dataset_id,
250252
project=session_project,
@@ -257,8 +259,8 @@ def to_sqlglot(
257259
compile_func = getattr(
258260
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
259261
)
260-
if sql := compile_func(udf_node):
261-
sources.append(sql)
262+
if sgexpr := compile_func(udf_node):
263+
sources.append(sgexpr)
262264

263265
if not sources:
264266
return result

ibis/backends/sql/compilers/duckdb.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,20 @@ class DuckDBCompiler(SQLGlotCompiler):
107107
ops.RandomScalar: "random",
108108
}
109109

110-
def to_sqlglot(
110+
def _to_sqlglot_expr(
111111
self,
112112
expr: ir.Expr,
113113
*,
114114
limit: str | None = None,
115115
params: Mapping[ir.Expr, Any] | None = None,
116116
):
117-
sql = super().to_sqlglot(expr, limit=limit, params=params)
117+
sgexpr = super()._to_sqlglot_expr(expr, limit=limit, params=params)
118118

119119
table_expr = expr.as_table()
120120
geocols = table_expr.schema().geospatial
121121

122122
if not geocols:
123-
return sql
123+
return sgexpr
124124

125125
quoted = self.quoted
126126
return sg.select(
@@ -132,7 +132,7 @@ def to_sqlglot(
132132
for col in geocols
133133
]
134134
)
135-
).from_(sql.subquery())
135+
).from_(sgexpr.subquery())
136136

137137
def visit_StructColumn(self, op, *, names, values):
138138
return sge.Struct.from_arg_list(

ibis/backends/sql/compilers/mssql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
if TYPE_CHECKING:
3636
from collections.abc import Mapping
3737

38-
import ibis.expr.operations as ir
38+
from ibis.expr import types as ir
3939

4040
y = var("y")
4141
start = var("start")
@@ -156,14 +156,13 @@ def _minimize_spec(op, spec):
156156
return None
157157
return spec
158158

159-
def to_sqlglot(
159+
def _to_sqlglot_expr(
160160
self,
161161
expr: ir.Expr,
162162
*,
163163
limit: str | None = None,
164164
params: Mapping[ir.Expr, Any] | None = None,
165-
):
166-
"""Compile an Ibis expression to a sqlglot object."""
165+
) -> sge.Expression:
167166
import ibis
168167

169168
table_expr = expr.as_table()
@@ -175,7 +174,7 @@ def to_sqlglot(
175174

176175
if conversions:
177176
table_expr = table_expr.mutate(**conversions)
178-
return super().to_sqlglot(table_expr, limit=limit, params=params)
177+
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)
179178

180179
def visit_RandomScalar(self, op):
181180
# By default RAND() will generate the same value for all calls within a

ibis/backends/sql/compilers/postgres.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,20 @@ class PostgresCompiler(SQLGlotCompiler):
122122
ops.RandomUUID: "gen_random_uuid",
123123
}
124124

125-
def to_sqlglot(
125+
def _to_sqlglot_expr(
126126
self,
127127
expr: ir.Expr,
128128
*,
129129
limit: str | None = None,
130130
params: Mapping[ir.Expr, Any] | None = None,
131-
):
131+
) -> sg.Expression:
132132
table_expr = expr.as_table()
133133
geocols = table_expr.schema().geospatial
134134
conversions = {name: table_expr[name].as_ewkb() for name in geocols}
135135

136136
if conversions:
137137
table_expr = table_expr.mutate(**conversions)
138-
return super().to_sqlglot(table_expr, limit=limit, params=params)
138+
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)
139139

140140
def _compile_python_udf(self, udf_node: ops.ScalarUDF):
141141
config = udf_node.__config__

ibis/backends/sql/datatypes.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,12 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
172172
@classmethod
173173
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
174174
"""Convert an Ibis dtype to an sqlglot dtype."""
175-
176175
if method := getattr(cls, f"_from_ibis_{dtype.name}", None):
177176
return method(dtype)
178177
else:
179-
return sge.DataType(this=_to_sqlglot_types[type(dtype)])
178+
return sge.DataType(
179+
this=_to_sqlglot_types[type(dtype)], nullable=dtype.nullable
180+
)
180181

181182
@classmethod
182183
def from_string(cls, text: str, nullable: bool | None = None) -> dt.DataType:
@@ -1346,7 +1347,15 @@ class AthenaType(SqlglotType):
13461347
dialect = "athena"
13471348

13481349

1349-
TYPE_MAPPERS = {
1350+
_TYPE_MAPPERS = {
13501351
mapper.dialect: mapper
13511352
for mapper in set(get_subclasses(SqlglotType)) - {SqlglotType, BigQueryUDFType}
13521353
}
1354+
_TYPE_MAPPERS["pyspark"] = PySparkType
1355+
_TYPE_MAPPERS["druid"] = DruidType
1356+
1357+
1358+
def get_mapper(dialect: str | sg.Dialect | type[sg.Dialect]) -> type[SqlglotType]:
    """Return the type mapper registered for a SQL dialect.

    Parameters
    ----------
    dialect
        Either the registry name of the dialect (for example ``"duckdb"``),
        a dialect class, or an instance of a dialect class. For classes and
        instances the lowercased class name is used as the registry name.

    Returns
    -------
    type[SqlglotType]
        The mapper stored in `_TYPE_MAPPERS` under the resolved name.

    Raises
    ------
    KeyError
        If nothing is registered under the resolved name.
    """
    if not isinstance(dialect, str):
        # Instances have no `__name__`, so resolve them to their class first;
        # classes are used as-is.
        dialect_cls = dialect if isinstance(dialect, type) else type(dialect)
        dialect = dialect_cls.__name__.lower()
    return _TYPE_MAPPERS[dialect]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Nullable(String)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR(2000000)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR(max)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR2(4000)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR

ibis/backends/tests/test_sql.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,33 @@ def test_to_sql_default_backend(con, snapshot, monkeypatch):
197197
snapshot.assert_match(ibis.to_sql(expr), "to_sql.sql")
198198

199199

200+
@pytest.mark.parametrize(
    "dialect",
    [
        # Just check these two to make sure that everything is plumbed through
        pytest.param("sqlite", marks=pytest.mark.xfail(reason="arrays not supported")),
        "duckdb",
    ],
)
def test_to_sql_dtype_default_backend(dialect):
    """Check that `ibis.to_sql` on a dtype uses the default backend's dialect."""
    dtype = ibis.dtype("array<int64>")
    original_backend = ibis.get_backend()
    try:
        ibis.set_backend(dialect)
        sql = ibis.to_sql(dtype)
    finally:
        # Restore the default backend on success as well as on failure so
        # this test does not leak global state into later tests.
        ibis.set_backend(original_backend)
    assert sql == "BIGINT[]"
218+
219+
220+
@pytest.mark.parametrize("backend_name", _get_backends_to_test(discard=("polars",)))
def test_to_sql_dtype(backend_name, snapshot):
    """Snapshot the SQL that each backend's dialect produces for an ibis string."""
    rendered = ibis.to_sql(ibis.dtype("string"), dialect=backend_name)
    snapshot.assert_match(rendered, "to_sql_dtype.sql")
225+
226+
200227
@pytest.mark.notimpl(["polars"], raises=ValueError, reason="not a SQL backend")
201228
def test_many_subqueries(backend_name, snapshot):
202229
def query(t, group_cols):

ibis/expr/datatypes/core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Iterable, Iterator, Mapping, Sequence
99
from numbers import Integral, Real
1010
from typing import (
11+
TYPE_CHECKING,
1112
Annotated,
1213
Any,
1314
Generic,
@@ -29,6 +30,10 @@
2930
from ibis.common.patterns import Between, Coercible, CoercionError
3031
from ibis.common.temporal import IntervalUnit, TimestampUnit
3132

33+
if TYPE_CHECKING:
34+
import sqlglot as sg
35+
import sqlglot.expressions as sge
36+
3237

3338
@lazy_singledispatch
3439
def dtype(value: Any, nullable: bool = True) -> DataType:
@@ -263,6 +268,33 @@ def from_polars(cls, polars_type, nullable=True) -> Self:
263268

264269
return PolarsType.to_ibis(polars_type, nullable=nullable)
265270

271+
def to_sqlglot(self, dialect: str | sg.Dialect) -> sge.DataType:
    """Return the sqlglot datatype that corresponds to this ibis datatype.

    Parameters
    ----------
    dialect
        The SQL dialect whose type mapper performs the conversion.
        The result is dialect-specific: for example, an ibis string becomes
        TEXT in some dialects and VARCHAR in others.

    Returns
    -------
    DataType
        The equivalent sqlglot.DataType.

    Examples
    --------
    >>> import ibis
    >>> dt = ibis.dtype("!string")
    >>> dt.to_sqlglot(dialect="duckdb")
    DataType(this=Type.VARCHAR)
    """
    # Imported inside the method, as in the original implementation.
    from ibis.backends.sql.datatypes import get_mapper

    return get_mapper(dialect).from_ibis(self)
297+
266298
def to_numpy(self):
267299
"""Return the equivalent numpy datatype."""
268300
from ibis.formats.numpy import NumpyType

0 commit comments

Comments
 (0)