
fix(databricks/pyspark): unify timestamp/timestamp_ntz behavior #11142

Open · wants to merge 1 commit into main
9 changes: 8 additions & 1 deletion .github/workflows/ibis-backends.yml
@@ -557,6 +557,13 @@ jobs:
- pyspark==3.3.4
- pandas==1.5.3
- numpy==1.23.5
- python-version: "3.9"
pyspark-minor-version: "3.4"
tag: local
deps:
- pyspark==3.4.4
- pandas==1.5.3
- numpy==1.23.5
- python-version: "3.11"
pyspark-minor-version: "3.5"
tag: local
@@ -609,7 +616,7 @@ jobs:

# it requires a version of pandas that pyspark is not compatible with
- name: remove lonboard
if: matrix.pyspark-minor-version == '3.3'
if: matrix.pyspark-minor-version != '3.5'
run: uv remove --group docs --no-sync lonboard

- name: install pyspark-specific dependencies
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/tests/unit/test_compiler.py
@@ -116,7 +116,7 @@ def test_integer_to_timestamp(case, unit, snapshot):


@pytest.mark.parametrize(
("case",),
"case",
[
param("a\\b\\c", id="escape_backslash"),
param("a\ab\bc\fd\ne\rf\tg\vh", id="escape_ascii_sequences"),
2 changes: 1 addition & 1 deletion ibis/backends/polars/__init__.py
@@ -559,7 +559,7 @@ def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any)

@_read_in_memory.register("ibis.expr.types.Table")
def _table(source, table_name, _conn, **kwargs: Any):
_conn._add_table(table_name, source.to_polars())
_conn._add_table(table_name, _conn.to_polars(source))
Member Author commented:
This was a hidden use of the default backend.



@_read_in_memory.register("polars.DataFrame")
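The comment above points at the two call paths that differ. A minimal sketch of the distinction, assuming an in-memory table and an explicit Polars connection (the variable names are illustrative, not from the PR):

```python
import ibis

con = ibis.polars.connect()
t = ibis.memtable({"a": [1, 2, 3]})

# Table.to_polars() resolves the expression through ibis' *default* backend
# (ibis.options.default_backend), which need not be `con` at all.
df_via_default = t.to_polars()

# Backend.to_polars() executes the same expression on this specific
# connection, so registering the result no longer depends on global state.
df_via_con = con.to_polars(t)
```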
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_udf.py
@@ -74,7 +74,7 @@ def con_for_udf(con, sql_table_setup, sql_define_udf, sql_define_py_udf, test_da
c.execute(sql_table_setup)
c.execute(sql_define_udf)
c.execute(sql_define_py_udf)
yield con
return con


@pytest.fixture
13 changes: 8 additions & 5 deletions ibis/backends/pyspark/__init__.py
@@ -51,9 +51,12 @@ class SparkConnectGrpcException(Exception):

from ibis.expr.api import Watermark


PYSPARK_VERSION = vparse(pyspark.__version__)
PYSPARK_LT_34 = PYSPARK_VERSION < vparse("3.4")
PYSPARK_LT_35 = PYSPARK_VERSION < vparse("3.5")
PYSPARK_33 = vparse("3.3") <= PYSPARK_VERSION < vparse("3.4")
PYSPARK_35 = vparse("3.5") <= PYSPARK_VERSION
SUPPORTS_TIMESTAMP_NTZ = vparse("3.4") <= PYSPARK_VERSION

ConnectionMode = Literal["streaming", "batch"]


@@ -244,7 +247,7 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
if catalog is None and db is None:
yield
return
if catalog is not None and PYSPARK_LT_34:
if catalog is not None and PYSPARK_33:
raise com.UnsupportedArgumentError(
"Catalogs are not supported in pyspark < 3.4"
)
@@ -313,7 +316,7 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):

@contextlib.contextmanager
def _active_catalog(self, name: str | None):
if name is None or PYSPARK_LT_34:
if name is None or PYSPARK_33:
yield
return

@@ -408,7 +411,7 @@ def _register_udfs(self, expr: ir.Expr) -> None:
spark_udf = F.udf(udf_func, udf_return)
elif udf.__input_type__ == InputType.PYARROW:
# raise not implemented error if running on pyspark < 3.5
if PYSPARK_LT_35:
if not PYSPARK_35:
raise NotImplementedError(
"pyarrow UDFs are only supported in pyspark >= 3.5"
)
89 changes: 52 additions & 37 deletions ibis/backends/pyspark/datatypes.py
@@ -1,20 +1,15 @@
from __future__ import annotations

import pyspark
from functools import partial
from inspect import isclass

import pyspark.sql.types as pt
from packaging.version import parse as vparse

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

# DayTimeIntervalType introduced in Spark 3.2 (at least) but didn't show up in
# PySpark until version 3.3
PYSPARK_33 = vparse(pyspark.__version__) >= vparse("3.3")
PYSPARK_35 = vparse(pyspark.__version__) >= vparse("3.5")


_from_pyspark_dtypes = {
pt.BinaryType: dt.Binary,
pt.BooleanType: dt.Boolean,
@@ -27,52 +22,64 @@
pt.NullType: dt.Null,
pt.ShortType: dt.Int16,
pt.StringType: dt.String,
pt.TimestampType: dt.Timestamp,
}

_to_pyspark_dtypes = {v: k for k, v in _from_pyspark_dtypes.items()}
try:
_from_pyspark_dtypes[pt.TimestampNTZType] = dt.Timestamp
except AttributeError:
_from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp

else:
_from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")

_to_pyspark_dtypes = {
v: k
for k, v in _from_pyspark_dtypes.items()
if isclass(v) and not issubclass(v, dt.Timestamp) and not isinstance(v, partial)
}
_to_pyspark_dtypes[dt.JSON] = pt.StringType
_to_pyspark_dtypes[dt.UUID] = pt.StringType


if PYSPARK_33:
_pyspark_interval_units = {
pt.DayTimeIntervalType.SECOND: "s",
pt.DayTimeIntervalType.MINUTE: "m",
pt.DayTimeIntervalType.HOUR: "h",
pt.DayTimeIntervalType.DAY: "D",
}


class PySparkType(TypeMapper):
@classmethod
def to_ibis(cls, typ, nullable=True):
"""Convert a pyspark type to an ibis type."""
from ibis.backends.pyspark import SUPPORTS_TIMESTAMP_NTZ

if isinstance(typ, pt.DecimalType):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, pt.ArrayType):
return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
elif isinstance(typ, pt.MapType):
return dt.Map(
cls.to_ibis(typ.keyType),
cls.to_ibis(typ.valueType),
nullable=nullable,
cls.to_ibis(typ.keyType), cls.to_ibis(typ.valueType), nullable=nullable
)
elif isinstance(typ, pt.StructType):
fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}

return dt.Struct(fields, nullable=nullable)
elif PYSPARK_33 and isinstance(typ, pt.DayTimeIntervalType):
elif isinstance(typ, pt.DayTimeIntervalType):
pyspark_interval_units = {
pt.DayTimeIntervalType.SECOND: "s",
pt.DayTimeIntervalType.MINUTE: "m",
pt.DayTimeIntervalType.HOUR: "h",
pt.DayTimeIntervalType.DAY: "D",
}

if (
typ.startField == typ.endField
and typ.startField in _pyspark_interval_units
and typ.startField in pyspark_interval_units
):
unit = _pyspark_interval_units[typ.startField]
unit = pyspark_interval_units[typ.startField]

return dt.Interval(unit, nullable=nullable)
else:
raise com.IbisTypeError(f"{typ!r} couldn't be converted to Interval")
elif PYSPARK_35 and isinstance(typ, pt.TimestampNTZType):
return dt.Timestamp(nullable=nullable)
elif isinstance(typ, pt.TimestampNTZType):
if SUPPORTS_TIMESTAMP_NTZ:
return dt.Timestamp(nullable=nullable)
raise com.UnsupportedBackendType(

"PySpark<3.4 doesn't properly support timestamps without a timezone"
)
elif isinstance(typ, pt.UserDefinedType):
return cls.to_ibis(typ.sqlType(), nullable=nullable)
else:
@@ -85,6 +92,8 @@

@classmethod
def from_ibis(cls, dtype):
from ibis.backends.pyspark import SUPPORTS_TIMESTAMP_NTZ

if dtype.is_decimal():
return pt.DecimalType(dtype.precision, dtype.scale)
elif dtype.is_array():
@@ -97,11 +106,21 @@
value_contains_null = dtype.value_type.nullable
return pt.MapType(key_type, value_type, value_contains_null)
elif dtype.is_struct():
fields = [
pt.StructField(n, cls.from_ibis(t), t.nullable)
for n, t in dtype.fields.items()
]
return pt.StructType(fields)
return pt.StructType(
[
pt.StructField(field, cls.from_ibis(dtype), dtype.nullable)
for field, dtype in dtype.fields.items()
]
)
elif dtype.is_timestamp():
if dtype.timezone is not None:
return pt.TimestampType()
else:
if not SUPPORTS_TIMESTAMP_NTZ:
raise com.UnsupportedBackendType(
"PySpark<3.4 doesn't properly support timestamps without a timezone"
)
return pt.TimestampNTZType()
else:
try:
return _to_pyspark_dtypes[type(dtype)]()
@@ -114,11 +133,7 @@
class PySparkSchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema):
fields = [
pt.StructField(name, PySparkType.from_ibis(dtype), dtype.nullable)
for name, dtype in schema.items()
]
return pt.StructType(fields)
return PySparkType.from_ibis(schema.as_struct())

@classmethod
def to_ibis(cls, schema):
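A minimal sketch of the mapping this module now encodes, assuming PySpark >= 3.4 (where `TimestampNTZType` is available):

```python
import pyspark.sql.types as pt

import ibis.expr.datatypes as dt
from ibis.backends.pyspark.datatypes import PySparkType

# Spark's TimestampType is timezone-aware (values are normalized to UTC),
# so it round-trips as an ibis timestamp carrying a UTC timezone...
assert PySparkType.to_ibis(pt.TimestampType()) == dt.Timestamp(timezone="UTC")

# ...while TimestampNTZType maps to a naive ibis timestamp.
assert PySparkType.to_ibis(pt.TimestampNTZType()) == dt.Timestamp()

# Going the other way, the presence of a timezone picks the Spark type.
assert PySparkType.from_ibis(dt.Timestamp(timezone="UTC")) == pt.TimestampType()
assert PySparkType.from_ibis(dt.Timestamp()) == pt.TimestampNTZType()
```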
10 changes: 9 additions & 1 deletion ibis/backends/pyspark/tests/conftest.py
@@ -337,8 +337,16 @@ def _load_data(self, **_: Any) -> None:

for name, schema in TEST_TABLES.items():
path = str(self.data_dir / "directory" / "parquet" / name)
sch = ibis.schema(
{
col: dtype.copy(timezone="UTC")
if dtype.is_timestamp()
else dtype
for col, dtype in schema.items()
}
)
t = (
s.readStream.schema(PySparkSchema.from_ibis(schema))
s.readStream.schema(PySparkSchema.from_ibis(sch))
.parquet(path)
.repartition(num_partitions)
)
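The fixture change above retags naive timestamp columns as UTC before building the streaming reader schema, since a naive ibis timestamp would otherwise map to `TimestampNTZType`. A small illustrative sketch, using a made-up two-column schema rather than an actual `TEST_TABLES` entry:

```python
import ibis

# Hypothetical schema standing in for one of the test tables.
schema = ibis.schema({"id": "int32", "ts": "timestamp"})

utc_schema = ibis.schema(
    {
        col: dtype.copy(timezone="UTC") if dtype.is_timestamp() else dtype
        for col, dtype in schema.items()
    }
)

assert utc_schema["ts"].timezone == "UTC"  # now maps to Spark's TimestampType
```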
41 changes: 21 additions & 20 deletions ibis/backends/pyspark/tests/test_basic.py
@@ -9,7 +9,7 @@
from pytest import param

import ibis
from ibis.common.exceptions import IbisTypeError
import ibis.common.exceptions as com

pyspark = pytest.importorskip("pyspark")

@@ -119,30 +119,31 @@ def test_alias_after_select(t):


def test_interval_columns_invalid(con):
df_interval_invalid = con._session.createDataFrame(
[[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
pt.StructType(
[
pt.StructField(
"interval_day_hour",
pt.DayTimeIntervalType(
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
),
)
]
),
data = [[timedelta(days=10, hours=10, minutes=10, seconds=10)]]
schema = pt.StructType(
[
pt.StructField(
"interval_day_hour",
pt.DayTimeIntervalType(
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
),
)
]
)

df_interval_invalid.createTempView("invalid_interval_table")
msg = r"DayTimeIntervalType.+ couldn't be converted to Interval"
with pytest.raises(IbisTypeError, match=msg):
con.table("invalid_interval_table")
name = "invalid_interval_table"

con._session.createDataFrame(data, schema).createTempView(name)

with pytest.raises(
com.IbisTypeError, match="DayTimeIntervalType.+ couldn't be converted"
):
con.table(name)


def test_string_literal_backslash_escaping(con):
expr = ibis.literal("\\d\\e")
result = con.execute(expr)
assert result == "\\d\\e"
input = r"\d\e"
assert con.execute(ibis.literal(input)) == input


def test_connect_without_explicit_session():
22 changes: 12 additions & 10 deletions ibis/backends/pyspark/tests/test_ddl.py
@@ -9,6 +9,7 @@

import ibis
from ibis import util
from ibis.backends.pyspark import PYSPARK_33
from ibis.backends.tests.errors import PySparkAnalysisException
from ibis.tests.util import assert_equal

@@ -92,12 +93,13 @@ def test_ctas_from_table_expr(con, alltypes, temp_table_db):

def test_create_empty_table(con, temp_table):
schema = ibis.schema(
[
("a", "string"),
("b", "timestamp"),
("c", "decimal(12, 8)"),
("d", "double"),
]
{
"a": "string",
"b": "timestamp('UTC')",
"c": "decimal(12, 8)",
"d": "double",
}
| ({"e": "timestamp"} if not PYSPARK_33 else {})
)

con.create_table(temp_table, schema=schema)
@@ -181,9 +183,9 @@ def test_create_table_reserved_identifier(con, alltypes, keyword_t):


@pytest.mark.xfail_version(
pyspark=["pyspark<3.5"],
pyspark=["pyspark<3.4"],
raises=ValueError,
reason="PySparkAnalysisException is not available in PySpark <3.5",
reason="PySparkAnalysisException is not available in PySpark <3.4",
)
def test_create_database_exists(con):
con.create_database(dbname := util.gen_name("dbname"))
Expand All @@ -197,9 +199,9 @@ def test_create_database_exists(con):


@pytest.mark.xfail_version(
pyspark=["pyspark<3.5"],
pyspark=["pyspark<3.4"],
raises=ValueError,
reason="PySparkAnalysisException is not available in PySpark <3.5",
reason="PySparkAnalysisException is not available in PySpark <3.4",
)
def test_drop_database_exists(con):
con.create_database(dbname := util.gen_name("dbname"))