Skip to content

Commit 16abc2c

Browse files
committed
feat(dtype): support compiling dtypes to sql
1 parent 20bec13 commit 16abc2c

File tree

27 files changed

+115
-33
lines changed

27 files changed

+115
-33
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -577,14 +577,17 @@ def _prepare_params(self, params):
577577

578578
def to_sqlglot(
579579
self,
580-
expr: ir.Expr,
580+
x: ir.Expr | dt.DataType,
581581
*,
582582
limit: str | None = None,
583583
params: Mapping[ir.Expr, Any] | None = None,
584-
):
584+
) -> sge.Expression:
585585
import ibis
586586

587-
table_expr = expr.as_table()
587+
if isinstance(x, dt.DataType):
588+
return self.type_mapper.from_ibis(x)
589+
590+
table_expr = x.as_table()
588591

589592
if limit == "default":
590593
limit = ibis.options.sql.default_limit

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class BigQueryCompiler(SQLGlotCompiler):
206206

207207
def to_sqlglot(
208208
self,
209-
expr: ir.Expr,
209+
x: ir.Expr | dt.DataType,
210210
*,
211211
limit: str | None = None,
212212
params: Mapping[ir.Expr, Any] | None = None,
@@ -217,8 +217,8 @@ def to_sqlglot(
217217
218218
Parameters
219219
----------
220-
expr
221-
Ibis expression
220+
x
221+
Ibis expression or data type.
222222
limit
223223
For expressions yielding result sets; retrieve at most this number
224224
of values/rows. Overrides any limit already set on the expression.
@@ -236,9 +236,11 @@ def to_sqlglot(
236236
backend.
237237
238238
"""
239-
sql = super().to_sqlglot(expr, limit=limit, params=params)
239+
sql = super().to_sqlglot(x, limit=limit, params=params)
240+
if isinstance(x, dt.DataType):
241+
return sql
240242

241-
table_expr = expr.as_table()
243+
table_expr = x.as_table()
242244

243245
memtable_names = frozenset(
244246
op.name for op in table_expr.op().find(ops.InMemoryTable)

ibis/backends/sql/compilers/duckdb.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,16 @@ class DuckDBCompiler(SQLGlotCompiler):
109109

110110
def to_sqlglot(
111111
self,
112-
expr: ir.Expr,
112+
x: ir.Expr | dt.DataType,
113113
*,
114114
limit: str | None = None,
115115
params: Mapping[ir.Expr, Any] | None = None,
116-
):
117-
sql = super().to_sqlglot(expr, limit=limit, params=params)
116+
) -> sge.Expression:
117+
sql = super().to_sqlglot(x, limit=limit, params=params)
118+
if isinstance(x, dt.DataType):
119+
return sql
118120

119-
table_expr = expr.as_table()
121+
table_expr = x.as_table()
120122
geocols = table_expr.schema().geospatial
121123

122124
if not geocols:

ibis/backends/sql/compilers/mssql.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,17 @@ def _minimize_spec(op, spec):
158158

159159
def to_sqlglot(
160160
self,
161-
expr: ir.Expr,
161+
x: ir.Expr | dt.DataType,
162162
*,
163163
limit: str | None = None,
164164
params: Mapping[ir.Expr, Any] | None = None,
165-
):
166-
"""Compile an Ibis expression to a sqlglot object."""
165+
) -> sge.Expression:
167166
import ibis
168167

169-
table_expr = expr.as_table()
168+
if isinstance(x, dt.DataType):
169+
return super().to_sqlglot(x)
170+
171+
table_expr = x.as_table()
170172
conversions = {
171173
name: ibis.ifelse(table_expr[name], 1, 0).cast(dt.boolean)
172174
for name, typ in table_expr.schema().items()

ibis/backends/sql/compilers/postgres.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,15 @@ class PostgresCompiler(SQLGlotCompiler):
124124

125125
def to_sqlglot(
126126
self,
127-
expr: ir.Expr,
127+
x: ir.Expr | dt.DataType,
128128
*,
129129
limit: str | None = None,
130130
params: Mapping[ir.Expr, Any] | None = None,
131-
):
132-
table_expr = expr.as_table()
131+
) -> sg.Expression:
132+
if isinstance(x, dt.DataType):
133+
return super().to_sqlglot(x)
134+
135+
table_expr = x.as_table()
133136
geocols = table_expr.schema().geospatial
134137
conversions = {name: table_expr[name].as_ewkb() for name in geocols}
135138

ibis/backends/sql/datatypes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,14 @@ def _from_sqlglot_DECIMAL(
769769

770770
@classmethod
771771
def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType:
772-
nullable = " NOT NULL" if not dtype.nullable else ""
773-
return "VARCHAR2(4000)" + nullable
772+
return sge.DataType(
773+
this=typecode.VARCHAR,
774+
expressions=[
775+
sge.DataTypeParam(
776+
this=sge.convert(dtype.length if dtype.length is not None else 4000)
777+
)
778+
],
779+
)
774780

775781

776782
class SnowflakeType(SqlglotType):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Nullable(String)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR(2000000)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR(max)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR2(4000)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
STRING
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TEXT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
VARCHAR

ibis/backends/tests/test_sql.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,33 @@ def test_to_sql_default_backend(con, snapshot, monkeypatch):
197197
snapshot.assert_match(ibis.to_sql(expr), "to_sql.sql")
198198

199199

200+
@pytest.mark.parametrize(
    "dialect",
    [
        # Just check these two to make sure that everything is plumbed through
        pytest.param("sqlite", marks=pytest.mark.xfail(reason="arrays not supported")),
        "duckdb",
    ],
)
def test_to_sql_dtype_default_backend(dialect):
    """Check that `ibis.to_sql` on a dtype uses the currently set backend.

    The default backend is switched to `dialect` for the duration of the
    compilation and is always restored afterwards, so the global backend
    setting does not leak into other tests.
    """
    dtype = ibis.dtype("array<int64>")
    original_backend = ibis.get_backend()
    try:
        ibis.set_backend(dialect)
        sql = ibis.to_sql(dtype)
    finally:
        # Restore on success as well as on failure; restoring only when an
        # exception is raised would leave `dialect` as the global default
        # backend after a passing run.
        ibis.set_backend(original_backend)
    assert sql == "BIGINT[]"
218+
219+
220+
@pytest.mark.parametrize("backend_name", _get_backends_to_test(discard=("polars",)))
def test_to_sql_dtype(backend_name, snapshot):
    """Snapshot the SQL produced for the string dtype in the given dialect."""
    compiled = ibis.to_sql(ibis.dtype("string"), dialect=backend_name)
    snapshot.assert_match(compiled, "to_sql_dtype.sql")
225+
226+
200227
@pytest.mark.notimpl(["polars"], raises=ValueError, reason="not a SQL backend")
201228
def test_many_subqueries(backend_name, snapshot):
202229
def query(t, group_cols):

ibis/expr/sql.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import operator
55
from functools import singledispatch
6+
from typing import TYPE_CHECKING
67

78
import sqlglot as sg
89
import sqlglot.expressions as sge
@@ -18,6 +19,9 @@
1819
from ibis.backends.sql.datatypes import SqlglotType
1920
from ibis.util import experimental
2021

22+
if TYPE_CHECKING:
23+
from ibis.backends.sql.compilers.base import SQLGlotCompiler
24+
2125

2226
class Catalog(dict[str, sch.Schema]):
2327
"""A catalog of tables and their schemas."""
@@ -445,14 +449,18 @@ def _repr_pretty_(self, p, cycle) -> str:
445449

446450
@public
447451
def to_sql(
448-
expr: ir.Expr, dialect: str | None = None, pretty: bool = True, **kwargs
452+
x: ir.Expr | dt.DataType,
453+
dialect: str | None = None,
454+
*,
455+
pretty: bool = True,
456+
**kwargs,
449457
) -> SQLString:
450-
"""Return the formatted SQL string for an expression.
458+
"""Return the formatted SQL string for an expression or data type.
451459
452460
Parameters
453461
----------
454-
expr
455-
Ibis expression.
462+
x
463+
Ibis expression or data type.
456464
dialect
457465
SQL dialect to use for compilation.
458466
pretty
@@ -484,6 +492,8 @@ def to_sql(
484492
`t0`.`b`,
485493
`t0`.`a` + `t0`.`b` AS `c`
486494
FROM `t` AS `t0`
495+
>>> ibis.to_sql(ibis.dtype("array<int64>"), dialect="duckdb")
496+
'BIGINT[]'
487497
488498
See Also
489499
--------
@@ -495,23 +505,31 @@ def to_sql(
495505
# try to infer from a non-str expression or if not possible fallback to
496506
# the default pretty dialect for expressions
497507
if dialect is None:
498-
try:
499-
compiler_provider = expr._find_backend(use_default=True)
500-
except com.IbisError:
501-
# default to duckdb for SQL compilation because it supports the
502-
# widest array of ibis features for SQL backends
503-
compiler_provider = sc.duckdb
508+
if isinstance(x, dt.DataType):
509+
compiler_provider = ibis.get_backend()
510+
else:
511+
try:
512+
compiler_provider = x._find_backend(use_default=True)
513+
except com.IbisError:
514+
# default to duckdb for SQL compilation because it supports the
515+
# widest array of ibis features for SQL backends
516+
compiler_provider = sc.duckdb
504517
else:
505518
try:
506519
compiler_provider = getattr(sc, dialect)
507520
except AttributeError as e:
508521
raise ValueError(f"Unknown dialect {dialect}") from e
509522

523+
compiler: SQLGlotCompiler
510524
if (compiler := getattr(compiler_provider, "compiler", None)) is None:
511525
raise NotImplementedError(f"{compiler_provider} is not a SQL backend")
512526

513-
out = compiler.to_sqlglot(expr.unbind(), **kwargs)
514-
queries = out if isinstance(out, list) else [out]
515527
dialect = compiler.dialect
528+
if isinstance(x, dt.DataType):
529+
# kwargs are ignored. Perhaps we should raise an error here if they are passed?
530+
return compiler.to_sqlglot(x).sql(dialect=dialect, pretty=pretty)
531+
532+
out = compiler.to_sqlglot(x.unbind(), **kwargs)
533+
queries = out if isinstance(out, list) else [out]
516534
sql = ";\n".join(query.sql(dialect=dialect, pretty=pretty) for query in queries)
517535
return SQLString(sql)

0 commit comments

Comments
 (0)