Skip to content

Commit c780113

Browse files
committed
feat(dtype): support compiling dtypes to sql
1 parent d55a5ee commit c780113

File tree

11 files changed

+207
-40
lines changed

11 files changed

+207
-40
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import operator
88
import string
99
from functools import partial, reduce
10-
from typing import TYPE_CHECKING, Any, Callable, ClassVar
10+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload
1111

1212
import sqlglot as sg
1313
import sqlglot.expressions as sge
@@ -17,6 +17,7 @@
1717
import ibis.common.patterns as pats
1818
import ibis.expr.datatypes as dt
1919
import ibis.expr.operations as ops
20+
import ibis.expr.schema as sch
2021
from ibis.backends.sql.rewrites import (
2122
FirstValue,
2223
LastValue,
@@ -53,7 +54,6 @@ def AlterTable(*args, kind="TABLE", **kwargs):
5354
if TYPE_CHECKING:
5455
from collections.abc import Iterable, Mapping
5556

56-
import ibis.expr.schema as sch
5757
import ibis.expr.types as ir
5858
from ibis.backends.sql.datatypes import SqlglotType
5959

@@ -575,16 +575,52 @@ def _prepare_params(self, params):
575575
result[node] = value
576576
return result
577577

578+
@overload
578579
def to_sqlglot(
579580
self,
580-
expr: ir.Expr,
581+
x: dt.DataType,
582+
*,
583+
params: Mapping[ir.Expr, Any] | None = None,
584+
) -> sge.DataType: ...
585+
586+
@overload
587+
def to_sqlglot(
588+
self,
589+
x: sch.Schema,
590+
*,
591+
params: Mapping[ir.Expr, Any] | None = None,
592+
) -> list[sge.ColumnDef]: ...
593+
594+
@overload
595+
def to_sqlglot(
596+
self,
597+
x: ir.Expr,
581598
*,
582599
limit: str | None = None,
583600
params: Mapping[ir.Expr, Any] | None = None,
584-
):
601+
) -> sge.Expression: ...
602+
603+
def to_sqlglot(
604+
self,
605+
x: ir.Expr | dt.DataType | sch.Schema,
606+
*,
607+
limit: str | None = None,
608+
params: Mapping[ir.Expr, Any] | None = None,
609+
) -> sge.Expression:
610+
if isinstance(x, (dt.DataType, sch.Schema)):
611+
return x.to_sqlglot(self.dialect)
612+
return self._to_sqlglot_expr(x, limit=limit, params=params)
613+
614+
def _to_sqlglot_expr(
615+
self,
616+
x: ir.Expr,
617+
*,
618+
limit: str | None = None,
619+
params: Mapping[ir.Expr, Any] | None = None,
620+
) -> sge.Expression:
585621
import ibis
586622

587-
table_expr = expr.as_table()
623+
table_expr = x.as_table()
588624

589625
if limit == "default":
590626
limit = ibis.options.sql.default_limit

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ibis.common.exceptions as com
1515
import ibis.expr.datatypes as dt
1616
import ibis.expr.operations as ops
17+
import ibis.expr.schema as sch
1718
from ibis import util
1819
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
1920
from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator
@@ -206,7 +207,7 @@ class BigQueryCompiler(SQLGlotCompiler):
206207

207208
def to_sqlglot(
208209
self,
209-
expr: ir.Expr,
210+
expr: ir.Expr | dt.DataType | sch.Schema,
210211
*,
211212
limit: str | None = None,
212213
params: Mapping[ir.Expr, Any] | None = None,
@@ -234,17 +235,18 @@ def to_sqlglot(
234235
Any
235236
The output of compilation. The type of this value depends on the
236237
backend.
237-
238238
"""
239-
sql = super().to_sqlglot(expr, limit=limit, params=params)
239+
if isinstance(expr, (dt.DataType, sch.Schema)):
240+
return expr.to_sqlglot(self.dialect)
241+
sgexpr = super().to_sqlglot(expr, limit=limit, params=params)
240242

241243
table_expr = expr.as_table()
242244

243245
memtable_names = frozenset(
244246
op.name for op in table_expr.op().find(ops.InMemoryTable)
245247
)
246248

247-
result = sql.transform(
249+
result = sgexpr.transform(
248250
_qualify_memtable,
249251
dataset=session_dataset_id,
250252
project=session_project,
@@ -257,8 +259,8 @@ def to_sqlglot(
257259
compile_func = getattr(
258260
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
259261
)
260-
if sql := compile_func(udf_node):
261-
sources.append(sql)
262+
if sgexpr := compile_func(udf_node):
263+
sources.append(sgexpr)
262264

263265
if not sources:
264266
return result

ibis/backends/sql/compilers/duckdb.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,20 @@ class DuckDBCompiler(SQLGlotCompiler):
107107
ops.RandomScalar: "random",
108108
}
109109

110-
def to_sqlglot(
110+
def _to_sqlglot_expr(
111111
self,
112112
expr: ir.Expr,
113113
*,
114114
limit: str | None = None,
115115
params: Mapping[ir.Expr, Any] | None = None,
116116
):
117-
sql = super().to_sqlglot(expr, limit=limit, params=params)
117+
sgexpr = super()._to_sqlglot_expr(expr, limit=limit, params=params)
118118

119119
table_expr = expr.as_table()
120120
geocols = table_expr.schema().geospatial
121121

122122
if not geocols:
123-
return sql
123+
return sgexpr
124124

125125
quoted = self.quoted
126126
return sg.select(
@@ -132,7 +132,7 @@ def to_sqlglot(
132132
for col in geocols
133133
]
134134
)
135-
).from_(sql.subquery())
135+
).from_(sgexpr.subquery())
136136

137137
def visit_StructColumn(self, op, *, names, values):
138138
return sge.Struct.from_arg_list(

ibis/backends/sql/compilers/mssql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
if TYPE_CHECKING:
3636
from collections.abc import Mapping
3737

38-
import ibis.expr.operations as ir
38+
from ibis.expr import types as ir
3939

4040
y = var("y")
4141
start = var("start")
@@ -156,14 +156,13 @@ def _minimize_spec(op, spec):
156156
return None
157157
return spec
158158

159-
def to_sqlglot(
159+
def _to_sqlglot_expr(
160160
self,
161161
expr: ir.Expr,
162162
*,
163163
limit: str | None = None,
164164
params: Mapping[ir.Expr, Any] | None = None,
165-
):
166-
"""Compile an Ibis expression to a sqlglot object."""
165+
) -> sge.Expression:
167166
import ibis
168167

169168
table_expr = expr.as_table()
@@ -175,7 +174,7 @@ def to_sqlglot(
175174

176175
if conversions:
177176
table_expr = table_expr.mutate(**conversions)
178-
return super().to_sqlglot(table_expr, limit=limit, params=params)
177+
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)
179178

180179
def visit_RandomScalar(self, op):
181180
# By default RAND() will generate the same value for all calls within a

ibis/backends/sql/compilers/postgres.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ class PostgresCompiler(SQLGlotCompiler):
121121
ops.RandomUUID: "gen_random_uuid",
122122
}
123123

124-
def to_sqlglot(
124+
def _to_sqlglot_expr(
125125
self,
126126
expr: ir.Expr,
127127
*,
128128
limit: str | None = None,
129129
params: Mapping[ir.Expr, Any] | None = None,
130-
):
130+
) -> sg.Expression:
131131
table_expr = expr.as_table()
132132
schema = table_expr.schema()
133133

@@ -140,7 +140,7 @@ def to_sqlglot(
140140

141141
if conversions:
142142
table_expr = table_expr.mutate(**conversions)
143-
return super().to_sqlglot(table_expr, limit=limit, params=params)
143+
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)
144144

145145
def _compile_python_udf(self, udf_node: ops.ScalarUDF):
146146
config = udf_node.__config__

ibis/backends/sql/compilers/risingwave.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class RisingWaveCompiler(PostgresCompiler):
5353

5454
del SIMPLE_OPS[ops.MapContains]
5555

56-
def to_sqlglot(
56+
def _to_sqlglot_expr(
5757
self,
5858
expr: ir.Expr,
5959
*,

ibis/backends/sql/datatypes.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,12 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
172172
@classmethod
173173
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
174174
"""Convert an Ibis dtype to an sqlglot dtype."""
175-
176175
if method := getattr(cls, f"_from_ibis_{dtype.name}", None):
177176
return method(dtype)
178177
else:
179-
return sge.DataType(this=_to_sqlglot_types[type(dtype)])
178+
return sge.DataType(
179+
this=_to_sqlglot_types[type(dtype)], nullable=dtype.nullable
180+
)
180181

181182
@classmethod
182183
def from_string(cls, text: str, nullable: bool | None = None) -> dt.DataType:
@@ -1360,7 +1361,19 @@ class AthenaType(SqlglotType):
13601361
dialect = "athena"
13611362

13621363

1363-
TYPE_MAPPERS: dict[str, SqlglotType] = {
1364+
_TYPE_MAPPERS: dict[str, type[SqlglotType]] = {
13641365
mapper.dialect: mapper
13651366
for mapper in set(get_subclasses(SqlglotType)) - {SqlglotType, BigQueryUDFType}
13661367
}
1368+
_TYPE_MAPPERS["pyspark"] = PySparkType
1369+
_TYPE_MAPPERS["druid"] = DruidType
1370+
1371+
1372+
def get_mapper(dialect: str | type[sg.Dialect] | None) -> type[SqlglotType]:
1373+
import ibis
1374+
1375+
if dialect is None:
1376+
dialect = ibis.get_backend().dialect
1377+
if not isinstance(dialect, str):
1378+
dialect = dialect.__name__.lower()
1379+
return _TYPE_MAPPERS[dialect]

ibis/backends/tests/test_sql.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import re
45

56
import pytest
@@ -197,6 +198,62 @@ def test_to_sql_default_backend(con, snapshot, monkeypatch):
197198
snapshot.assert_match(ibis.to_sql(expr), "to_sql.sql")
198199

199200

201+
@contextlib.contextmanager
202+
def with_default_backend(backend: str):
203+
original_backend = ibis.get_backend()
204+
try:
205+
ibis.set_backend(backend)
206+
yield
207+
finally:
208+
ibis.set_backend(original_backend)
209+
210+
211+
@pytest.mark.parametrize(
212+
"dialect",
213+
[
214+
# Just check these two to make sure that everything is plumbed through
215+
pytest.param("sqlite", marks=pytest.mark.xfail(reason="arrays not supported")),
216+
"duckdb",
217+
],
218+
)
219+
def test_to_sql_dtype_default_backend(dialect):
220+
dt = ibis.dtype("array<int64>")
221+
with with_default_backend(dialect):
222+
sql = ibis.to_sql(dt)
223+
assert "BIGINT[]" == sql
224+
225+
226+
STRING_DTYPES = {
227+
"athena": "VARCHAR",
228+
"bigquery": "STRING",
229+
"clickhouse": "Nullable(String)",
230+
"databricks": "STRING",
231+
"datafusion": "VARCHAR",
232+
"druid": "VARCHAR",
233+
"duckdb": "TEXT",
234+
"exasol": "VARCHAR(2000000)",
235+
"flink": "STRING",
236+
"impala": "STRING",
237+
"mssql": "VARCHAR(max)",
238+
"mysql": "TEXT",
239+
"oracle": "VARCHAR2(4000)",
240+
"postgres": "VARCHAR",
241+
"pyspark": "STRING",
242+
"risingwave": "VARCHAR",
243+
"snowflake": "VARCHAR",
244+
"sqlite": "TEXT",
245+
"trino": "VARCHAR",
246+
}
247+
248+
249+
@pytest.mark.parametrize("backend_name", _get_backends_to_test(discard=("polars",)))
250+
def test_to_sql_dtype(backend_name):
251+
dt = ibis.dtype("string")
252+
sql = ibis.to_sql(dt, dialect=backend_name)
253+
expected = STRING_DTYPES[backend_name]
254+
assert expected == sql
255+
256+
200257
@pytest.mark.notimpl(["polars"], raises=ValueError, reason="not a SQL backend")
201258
def test_many_subqueries(backend_name, snapshot):
202259
def query(t, group_cols):

ibis/expr/datatypes/core.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@ def dtype(
8383
) -> DataType: ...
8484

8585

86+
if TYPE_CHECKING:
87+
import numpy as np
88+
import polars as pl
89+
import pyarrow as pa
90+
import sqlglot as sg
91+
import sqlglot.expressions as sge
92+
from pandas.api.extensions import ExtensionDtype
93+
94+
95+
if TYPE_CHECKING:
96+
import sqlglot as sg
97+
import sqlglot.expressions as sge
98+
99+
86100
@lazy_singledispatch
87101
def dtype(value, nullable=True) -> DataType:
88102
"""Create a DataType object.
@@ -318,6 +332,34 @@ def from_polars(cls, polars_type: pl.DataType, nullable: bool = True) -> Self:
318332

319333
return PolarsType.to_ibis(polars_type, nullable=nullable)
320334

335+
def to_sqlglot(self, dialect: str | sg.Dialect | None = None) -> sge.DataType:
336+
"""Convert to the equivalent sqlglot.DataType.
337+
338+
Parameters
339+
----------
340+
dialect
341+
The SQL dialect to use.
342+
For example, some dialects convert an ibis string to TEXT,
343+
while others use VARCHAR.
344+
If not provided, the dialect from `ibis.get_backend()` is used.
345+
346+
Returns
347+
-------
348+
DataType
349+
The equivalent sqlglot.DataType.
350+
351+
Examples
352+
--------
353+
>>> import ibis
354+
>>> dt = ibis.dtype("!string")
355+
>>> dt.to_sqlglot(dialect="duckdb")
356+
DataType(this=Type.VARCHAR)
357+
"""
358+
from ibis.backends.sql.datatypes import get_mapper
359+
360+
type_mapper = get_mapper(dialect)
361+
return type_mapper.from_ibis(self)
362+
321363
def to_numpy(self) -> np.dtype:
322364
"""Return the equivalent numpy datatype."""
323365
from ibis.formats.numpy import NumpyType

0 commit comments

Comments
 (0)