Skip to content

feat(dtype): support compiling dtypes to sql #11100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import operator
import string
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload

import sqlglot as sg
import sqlglot.expressions as sge
Expand All @@ -17,6 +17,7 @@
import ibis.common.patterns as pats
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
Expand Down Expand Up @@ -53,7 +54,6 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping

import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.sql.datatypes import SqlglotType

Expand Down Expand Up @@ -575,16 +575,52 @@
result[node] = value
return result

@overload
def to_sqlglot(
    self,
    x: dt.DataType,
    *,
    params: Mapping[ir.Expr, Any] | None = None,
) -> sge.DataType: ...

@overload
def to_sqlglot(
    self,
    x: sch.Schema,
    *,
    params: Mapping[ir.Expr, Any] | None = None,
) -> list[sge.ColumnDef]: ...

@overload
def to_sqlglot(
    self,
    x: ir.Expr,
    *,
    limit: str | None = None,
    params: Mapping[ir.Expr, Any] | None = None,
) -> sge.Expression: ...

def to_sqlglot(
    self,
    x: ir.Expr | dt.DataType | sch.Schema,
    *,
    limit: str | None = None,
    params: Mapping[ir.Expr, Any] | None = None,
) -> sge.Expression | sge.DataType | list[sge.ColumnDef]:
    """Compile an ibis expression, dtype, or schema to a sqlglot object.

    Parameters
    ----------
    x
        The ibis expression, datatype, or schema to compile.
    limit
        Row-limit to apply; only meaningful when ``x`` is an expression.
    params
        Mapping of scalar parameter expressions to values; only
        meaningful when ``x`` is an expression.

    Returns
    -------
    sge.Expression | sge.DataType | list[sge.ColumnDef]
        A sqlglot ``DataType`` for a dtype, column definitions for a
        schema, or a sqlglot expression otherwise.
    """
    # DataType and Schema know how to render themselves for a dialect;
    # expressions go through the full expression-compilation pipeline.
    if isinstance(x, (dt.DataType, sch.Schema)):
        return x.to_sqlglot(self.dialect)
    return self._to_sqlglot_expr(x, limit=limit, params=params)

def _to_sqlglot_expr(
self,
x: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
) -> sge.Expression:
import ibis

table_expr = expr.as_table()
table_expr = x.as_table()

if limit == "default":
limit = ibis.options.sql.default_limit
Expand Down
14 changes: 8 additions & 6 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator
Expand Down Expand Up @@ -206,7 +207,7 @@

def to_sqlglot(
self,
expr: ir.Expr,
expr: ir.Expr | dt.DataType | sch.Schema,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
Expand Down Expand Up @@ -234,17 +235,18 @@
Any
The output of compilation. The type of this value depends on the
backend.

"""
sql = super().to_sqlglot(expr, limit=limit, params=params)
if isinstance(expr, (dt.DataType, sch.Schema)):
return expr.to_sqlglot(self.dialect)

Check warning on line 240 in ibis/backends/sql/compilers/bigquery/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery/__init__.py#L240

Added line #L240 was not covered by tests
sgexpr = super().to_sqlglot(expr, limit=limit, params=params)

table_expr = expr.as_table()

memtable_names = frozenset(
op.name for op in table_expr.op().find(ops.InMemoryTable)
)

result = sql.transform(
result = sgexpr.transform(
_qualify_memtable,
dataset=session_dataset_id,
project=session_project,
Expand All @@ -257,8 +259,8 @@
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
sources.append(sql)
if sgexpr := compile_func(udf_node):
sources.append(sgexpr)

if not sources:
return result
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.RandomScalar: "random",
}

def to_sqlglot(
def _to_sqlglot_expr(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
sql = super().to_sqlglot(expr, limit=limit, params=params)
sgexpr = super()._to_sqlglot_expr(expr, limit=limit, params=params)

table_expr = expr.as_table()
geocols = table_expr.schema().geospatial

if not geocols:
return sql
return sgexpr

quoted = self.quoted
return sg.select(
Expand All @@ -132,7 +132,7 @@ def to_sqlglot(
for col in geocols
]
)
).from_(sql.subquery())
).from_(sgexpr.subquery())

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
9 changes: 4 additions & 5 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping

import ibis.expr.operations as ir
from ibis.expr import types as ir

y = var("y")
start = var("start")
Expand Down Expand Up @@ -156,14 +156,13 @@ def _minimize_spec(op, spec):
return None
return spec

def to_sqlglot(
def _to_sqlglot_expr(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
"""Compile an Ibis expression to a sqlglot object."""
) -> sge.Expression:
import ibis

table_expr = expr.as_table()
Expand All @@ -175,7 +174,7 @@ def to_sqlglot(

if conversions:
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)

def visit_RandomScalar(self, op):
# By default RAND() will generate the same value for all calls within a
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ class PostgresCompiler(SQLGlotCompiler):
ops.RandomUUID: "gen_random_uuid",
}

def to_sqlglot(
def _to_sqlglot_expr(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
) -> sg.Expression:
table_expr = expr.as_table()
schema = table_expr.schema()

Expand All @@ -140,7 +140,7 @@ def to_sqlglot(

if conversions:
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)
return super()._to_sqlglot_expr(table_expr, limit=limit, params=params)

def _compile_python_udf(self, udf_node: ops.ScalarUDF):
config = udf_node.__config__
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RisingWaveCompiler(PostgresCompiler):

del SIMPLE_OPS[ops.MapContains]

def to_sqlglot(
def _to_sqlglot_expr(
self,
expr: ir.Expr,
*,
Expand All @@ -79,7 +79,9 @@ def to_sqlglot(

if conversions:
table_expr = table_expr.mutate(**conversions)
return SQLGlotCompiler.to_sqlglot(self, table_expr, limit=limit, params=params)
return SQLGlotCompiler._to_sqlglot_expr(
self, table_expr, limit=limit, params=params
)

def visit_DateNow(self, op):
return self.cast(sge.CurrentTimestamp(), dt.date)
Expand Down
19 changes: 16 additions & 3 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,12 @@
@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
    """Convert an Ibis dtype to an sqlglot dtype.

    Dialect-specific conversions are implemented as ``_from_ibis_<name>``
    classmethods; when none exists for ``dtype``, fall back to the generic
    type table, propagating the dtype's nullability.
    """
    if method := getattr(cls, f"_from_ibis_{dtype.name}", None):
        return method(dtype)
    else:
        return sge.DataType(
            this=_to_sqlglot_types[type(dtype)], nullable=dtype.nullable
        )

@classmethod
def from_string(cls, text: str, nullable: bool | None = None) -> dt.DataType:
Expand Down Expand Up @@ -1360,7 +1361,19 @@
dialect = "athena"


TYPE_MAPPERS: dict[str, SqlglotType] = {
_TYPE_MAPPERS: dict[str, type[SqlglotType]] = {
mapper.dialect: mapper
for mapper in set(get_subclasses(SqlglotType)) - {SqlglotType, BigQueryUDFType}
}
_TYPE_MAPPERS["pyspark"] = PySparkType
_TYPE_MAPPERS["druid"] = DruidType


def get_mapper(dialect: str | type[sg.Dialect] | None) -> type[SqlglotType]:
    """Return the SqlglotType mapper registered for *dialect*.

    Parameters
    ----------
    dialect
        A dialect name, a sqlglot ``Dialect`` class, or ``None`` to use
        the dialect of the current default backend.

    Raises
    ------
    KeyError
        If no mapper is registered for the resolved dialect name.
    """
    if dialect is None:
        # Local import avoids a circular import at module load time.
        import ibis

        dialect = ibis.get_backend().dialect
    if not isinstance(dialect, str):
        # A Dialect class: key the registry by its lowercased class name.
        dialect = dialect.__name__.lower()
    return _TYPE_MAPPERS[dialect]
57 changes: 57 additions & 0 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import re

import pytest
Expand Down Expand Up @@ -197,6 +198,62 @@
snapshot.assert_match(ibis.to_sql(expr), "to_sql.sql")


@contextlib.contextmanager
def with_default_backend(backend: str):
    """Temporarily make *backend* the default ibis backend.

    The previously configured default is restored on exit, even if
    ``set_backend`` or the managed block raises.
    """
    original_backend = ibis.get_backend()
    try:
        ibis.set_backend(backend)
        yield
    finally:
        ibis.set_backend(original_backend)

Check warning on line 208 in ibis/backends/tests/test_sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/tests/test_sql.py#L208

Added line #L208 was not covered by tests


@pytest.mark.parametrize(
    "dialect",
    [
        # Just check these two to make sure that everything is plumbed through
        pytest.param("sqlite", marks=pytest.mark.xfail(reason="arrays not supported")),
        "duckdb",
    ],
)
def test_to_sql_dtype_default_backend(dialect):
    """Compiling a bare dtype uses the dialect of the default backend."""
    array_dtype = ibis.dtype("array<int64>")
    with with_default_backend(dialect):
        compiled = ibis.to_sql(array_dtype)
    assert str(compiled) == "BIGINT[]"

Check warning on line 223 in ibis/backends/tests/test_sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/tests/test_sql.py#L220-L223

Added lines #L220 - L223 were not covered by tests


# Expected SQL rendering of the ibis ``string`` dtype, keyed by backend
# name; consumed by test_to_sql_dtype below.
STRING_DTYPES = {
    "athena": "VARCHAR",
    "bigquery": "STRING",
    "clickhouse": "Nullable(String)",
    "databricks": "STRING",
    "datafusion": "VARCHAR",
    "druid": "VARCHAR",
    "duckdb": "TEXT",
    "exasol": "VARCHAR(2000000)",
    "flink": "STRING",
    "impala": "STRING",
    "mssql": "VARCHAR(max)",
    "mysql": "TEXT",
    "oracle": "VARCHAR2(4000)",
    "postgres": "VARCHAR",
    "pyspark": "STRING",
    "risingwave": "VARCHAR",
    "snowflake": "VARCHAR",
    "sqlite": "TEXT",
    "trino": "VARCHAR",
}


@pytest.mark.parametrize("backend_name", _get_backends_to_test(discard=("polars",)))
def test_to_sql_dtype(backend_name):
    """The ibis string dtype compiles to each backend's native text type."""
    string_dtype = ibis.dtype("string")
    compiled = ibis.to_sql(string_dtype, dialect=backend_name)
    assert str(compiled) == STRING_DTYPES[backend_name]


@pytest.mark.notimpl(["polars"], raises=ValueError, reason="not a SQL backend")
def test_many_subqueries(backend_name, snapshot):
def query(t, group_cols):
Expand Down
Loading