
fix(databricks/pyspark): unify timestamp/timestamp_ntz behavior #11142

Open

wants to merge 11 commits into main
9 changes: 8 additions & 1 deletion .github/workflows/ibis-backends.yml
@@ -557,6 +557,13 @@ jobs:
- pyspark==3.3.4
- pandas==1.5.3
- numpy==1.23.5
- python-version: "3.9"
pyspark-minor-version: "3.4"
tag: local
deps:
- pyspark==3.4.4
- pandas==1.5.3
- numpy==1.23.5
- python-version: "3.11"
pyspark-minor-version: "3.5"
tag: local
@@ -609,7 +616,7 @@ jobs:

# it requires a version of pandas that pyspark is not compatible with
- name: remove lonboard
if: matrix.pyspark-minor-version == '3.3'
if: matrix.pyspark-minor-version != '3.5'
run: uv remove --group docs --no-sync lonboard

- name: install pyspark-specific dependencies
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/tests/unit/test_compiler.py
@@ -116,7 +116,7 @@ def test_integer_to_timestamp(case, unit, snapshot):


@pytest.mark.parametrize(
("case",),
"case",
[
param("a\\b\\c", id="escape_backslash"),
param("a\ab\bc\fd\ne\rf\tg\vh", id="escape_ascii_sequences"),
2 changes: 1 addition & 1 deletion ibis/backends/polars/__init__.py
@@ -559,7 +559,7 @@ def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any)

@_read_in_memory.register("ibis.expr.types.Table")
def _table(source, table_name, _conn, **kwargs: Any):
_conn._add_table(table_name, source.to_polars())
_conn._add_table(table_name, _conn.to_polars(source))


@_read_in_memory.register("polars.DataFrame")
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_udf.py
@@ -74,7 +74,7 @@ def con_for_udf(con, sql_table_setup, sql_define_udf, sql_define_py_udf, test_da
c.execute(sql_table_setup)
c.execute(sql_define_udf)
c.execute(sql_define_py_udf)
yield con
return con


@pytest.fixture
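Note on the `yield` → `return` tweak above: a pytest fixture only needs `yield` when there is teardown code to run after the test; this fixture has none, so returning the connection is equivalent. A minimal illustration (hypothetical fixtures, not the ones in this file):

```python
import pytest


@pytest.fixture
def conn_without_teardown():
    conn = object()  # stand-in for real connection setup
    return conn  # nothing to clean up afterwards, so a plain return works


@pytest.fixture
def conn_with_teardown():
    conn = object()
    yield conn  # test runs here
    # teardown (e.g. conn.close()) would go after the yield
```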
13 changes: 8 additions & 5 deletions ibis/backends/pyspark/__init__.py
@@ -51,9 +51,12 @@

from ibis.expr.api import Watermark


PYSPARK_VERSION = vparse(pyspark.__version__)
PYSPARK_LT_34 = PYSPARK_VERSION < vparse("3.4")
PYSPARK_LT_35 = PYSPARK_VERSION < vparse("3.5")
PYSPARK_33 = vparse("3.3") <= PYSPARK_VERSION < vparse("3.4")
PYSPARK_35 = vparse("3.5") <= PYSPARK_VERSION
SUPPORTS_TIMESTAMP_NTZ = vparse("3.4") <= PYSPARK_VERSION

ConnectionMode = Literal["streaming", "batch"]


@@ -243,7 +246,7 @@
if catalog is None and db is None:
yield
return
if catalog is not None and PYSPARK_LT_34:
if catalog is not None and PYSPARK_33:
raise com.UnsupportedArgumentError(
"Catalogs are not supported in pyspark < 3.4"
)
@@ -312,7 +315,7 @@

@contextlib.contextmanager
def _active_catalog(self, name: str | None):
if name is None or PYSPARK_LT_34:
if name is None or PYSPARK_33:
yield
return

@@ -407,7 +410,7 @@
spark_udf = F.udf(udf_func, udf_return)
elif udf.__input_type__ == InputType.PYARROW:
# raise not implemented error if running on pyspark < 3.5
if PYSPARK_LT_35:
if not PYSPARK_35:

raise NotImplementedError(
"pyarrow UDFs are only supported in pyspark >= 3.5"
)
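The backend now keys its gating off three explicit flags instead of the old `PYSPARK_LT_34`/`PYSPARK_LT_35` pair. A rough standalone sketch of the intent (`require_catalog_support` is a made-up helper for illustration; the real code raises `com.UnsupportedArgumentError` inside the backend):

```python
from __future__ import annotations

import pyspark
from packaging.version import parse as vparse

PYSPARK_VERSION = vparse(pyspark.__version__)
PYSPARK_33 = vparse("3.3") <= PYSPARK_VERSION < vparse("3.4")  # oldest supported line
PYSPARK_35 = vparse("3.5") <= PYSPARK_VERSION                  # pyarrow UDFs need this
SUPPORTS_TIMESTAMP_NTZ = vparse("3.4") <= PYSPARK_VERSION      # TIMESTAMP_NTZ needs this


def require_catalog_support(catalog: str | None) -> None:
    # Catalog operations are only available on PySpark >= 3.4.
    if catalog is not None and PYSPARK_33:
        raise NotImplementedError("Catalogs are not supported in pyspark < 3.4")
```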
87 changes: 51 additions & 36 deletions ibis/backends/pyspark/datatypes.py
@@ -1,20 +1,15 @@
from __future__ import annotations

import pyspark
from functools import partial
from inspect import isclass

import pyspark.sql.types as pt
from packaging.version import parse as vparse

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

# DayTimeIntervalType introduced in Spark 3.2 (at least) but didn't show up in
# PySpark until version 3.3
PYSPARK_33 = vparse(pyspark.__version__) >= vparse("3.3")
PYSPARK_35 = vparse(pyspark.__version__) >= vparse("3.5")


_from_pyspark_dtypes = {
pt.BinaryType: dt.Binary,
pt.BooleanType: dt.Boolean,
@@ -27,52 +22,64 @@
pt.NullType: dt.Null,
pt.ShortType: dt.Int16,
pt.StringType: dt.String,
pt.TimestampType: dt.Timestamp,
}

_to_pyspark_dtypes = {v: k for k, v in _from_pyspark_dtypes.items()}
try:
_from_pyspark_dtypes[pt.TimestampNTZType] = dt.Timestamp
except AttributeError:
_from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp

else:
_from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")

_to_pyspark_dtypes = {
v: k
for k, v in _from_pyspark_dtypes.items()
if isclass(v) and not issubclass(v, dt.Timestamp) and not isinstance(v, partial)
}
_to_pyspark_dtypes[dt.JSON] = pt.StringType
_to_pyspark_dtypes[dt.UUID] = pt.StringType


if PYSPARK_33:
_pyspark_interval_units = {
pt.DayTimeIntervalType.SECOND: "s",
pt.DayTimeIntervalType.MINUTE: "m",
pt.DayTimeIntervalType.HOUR: "h",
pt.DayTimeIntervalType.DAY: "D",
}


class PySparkType(TypeMapper):
@classmethod
def to_ibis(cls, typ, nullable=True):
"""Convert a pyspark type to an ibis type."""
from ibis.backends.pyspark import PYSPARK_33, SUPPORTS_TIMESTAMP_NTZ


if isinstance(typ, pt.DecimalType):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, pt.ArrayType):
return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
elif isinstance(typ, pt.MapType):
return dt.Map(
cls.to_ibis(typ.keyType),
cls.to_ibis(typ.valueType),
nullable=nullable,
cls.to_ibis(typ.keyType), cls.to_ibis(typ.valueType), nullable=nullable
)
elif isinstance(typ, pt.StructType):
fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}

return dt.Struct(fields, nullable=nullable)
elif PYSPARK_33 and isinstance(typ, pt.DayTimeIntervalType):
pyspark_interval_units = {

pt.DayTimeIntervalType.SECOND: "s",
pt.DayTimeIntervalType.MINUTE: "m",
pt.DayTimeIntervalType.HOUR: "h",
pt.DayTimeIntervalType.DAY: "D",
}

if (
typ.startField == typ.endField
and typ.startField in _pyspark_interval_units
and typ.startField in pyspark_interval_units
):
unit = _pyspark_interval_units[typ.startField]
unit = pyspark_interval_units[typ.startField]

return dt.Interval(unit, nullable=nullable)
else:
raise com.IbisTypeError(f"{typ!r} couldn't be converted to Interval")
elif PYSPARK_35 and isinstance(typ, pt.TimestampNTZType):
return dt.Timestamp(nullable=nullable)
elif isinstance(typ, pt.TimestampNTZType):
if SUPPORTS_TIMESTAMP_NTZ:
return dt.Timestamp(nullable=nullable)
raise com.UnsupportedBackendType(

"PySpark<3.4 doesn't properly support timestamps without a timezone"
)
elif isinstance(typ, pt.UserDefinedType):
return cls.to_ibis(typ.sqlType(), nullable=nullable)
else:
@@ -85,6 +92,8 @@

@classmethod
def from_ibis(cls, dtype):
from ibis.backends.pyspark import SUPPORTS_TIMESTAMP_NTZ


if dtype.is_decimal():
return pt.DecimalType(dtype.precision, dtype.scale)
elif dtype.is_array():
@@ -97,11 +106,21 @@
value_contains_null = dtype.value_type.nullable
return pt.MapType(key_type, value_type, value_contains_null)
elif dtype.is_struct():
fields = [
pt.StructField(n, cls.from_ibis(t), t.nullable)
for n, t in dtype.fields.items()
]
return pt.StructType(fields)
return pt.StructType(

[
pt.StructField(field, cls.from_ibis(dtype), dtype.nullable)
for field, dtype in dtype.fields.items()
]
)
elif dtype.is_timestamp():
if dtype.timezone is not None:
return pt.TimestampType()

else:
if not SUPPORTS_TIMESTAMP_NTZ:
raise com.UnsupportedBackendType(

"PySpark<3.4 doesn't properly support timestamps without a timezone"
)
return pt.TimestampNTZType()

else:
try:
return _to_pyspark_dtypes[type(dtype)]()
@@ -114,11 +133,7 @@
class PySparkSchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema):
fields = [
pt.StructField(name, PySparkType.from_ibis(dtype), dtype.nullable)
for name, dtype in schema.items()
]
return pt.StructType(fields)
return PySparkType.from_ibis(schema.as_struct())


@classmethod
def to_ibis(cls, schema):
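Taken together, the mapping above is meant to give a symmetric timestamp story on PySpark >= 3.4. A rough sketch of the expected round-trips (assuming PySpark >= 3.4, where `TimestampNTZType` exists; exact `nullable` handling follows the defaults):

```python
import pyspark.sql.types as pt

import ibis.expr.datatypes as dt
from ibis.backends.pyspark.datatypes import PySparkType

# Spark TIMESTAMP (local-time-zone semantics) <-> zoned ibis timestamp
assert PySparkType.to_ibis(pt.TimestampType()) == dt.Timestamp(timezone="UTC")
assert PySparkType.from_ibis(dt.Timestamp(timezone="UTC")) == pt.TimestampType()

# Spark TIMESTAMP_NTZ <-> zone-naive ibis timestamp
assert PySparkType.to_ibis(pt.TimestampNTZType()) == dt.Timestamp()
assert PySparkType.from_ibis(dt.Timestamp()) == pt.TimestampNTZType()

# On PySpark 3.3 the naive cases raise UnsupportedBackendType instead.
```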
10 changes: 9 additions & 1 deletion ibis/backends/pyspark/tests/conftest.py
@@ -337,8 +337,16 @@

for name, schema in TEST_TABLES.items():
path = str(self.data_dir / "directory" / "parquet" / name)
sch = ibis.schema(

{
col: dtype.copy(timezone="UTC")
if dtype.is_timestamp()
else dtype
for col, dtype in schema.items()
}
)
t = (
s.readStream.schema(PySparkSchema.from_ibis(schema))
s.readStream.schema(PySparkSchema.from_ibis(sch))
.parquet(path)
.repartition(num_partitions)
)
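The fixture change rewrites each test schema so that every timestamp column is timezone-aware before it is handed to `readStream`. A standalone sketch of the rewrite (hypothetical schema, same pattern as the diff):

```python
import ibis

schema = ibis.schema({"id": "int64", "ts": "timestamp"})  # naive timestamp column

utc_schema = ibis.schema(
    {
        col: dtype.copy(timezone="UTC") if dtype.is_timestamp() else dtype
        for col, dtype in schema.items()
    }
)
# utc_schema["ts"] is now timestamp('UTC'), so PySparkSchema.from_ibis emits
# TimestampType rather than TimestampNTZType for that column.
```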
43 changes: 22 additions & 21 deletions ibis/backends/pyspark/tests/test_basic.py
@@ -9,7 +9,7 @@
from pytest import param

import ibis
from ibis.common.exceptions import IbisTypeError
import ibis.common.exceptions as com


pyspark = pytest.importorskip("pyspark")

@@ -110,7 +110,7 @@
tm.assert_frame_equal(result, df)


def test_alias_after_select(t, df):
def test_alias_after_select(t):

# Regression test for issue 2136
table = t[["id"]]
table = table.mutate(id2=table["id"])
@@ -119,30 +119,31 @@


def test_interval_columns_invalid(con):
df_interval_invalid = con._session.createDataFrame(
[[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
pt.StructType(
[
pt.StructField(
"interval_day_hour",
pt.DayTimeIntervalType(
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
),
)
]
),
data = [[timedelta(days=10, hours=10, minutes=10, seconds=10)]]
schema = pt.StructType(

[
pt.StructField(
"interval_day_hour",
pt.DayTimeIntervalType(
pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND
),
)
]
)

df_interval_invalid.createTempView("invalid_interval_table")
msg = r"DayTimeIntervalType.+ couldn't be converted to Interval"
with pytest.raises(IbisTypeError, match=msg):
con.table("invalid_interval_table")
name = "invalid_interval_table"


con._session.createDataFrame(data, schema).createTempView(name)


with pytest.raises(

com.IbisTypeError, match="DayTimeIntervalType.+ couldn't be converted"
):
con.table(name)



def test_string_literal_backslash_escaping(con):
expr = ibis.literal("\\d\\e")
result = con.execute(expr)
assert result == "\\d\\e"
input = r"\d\e"
assert con.execute(ibis.literal(input)) == input



def test_connect_without_explicit_session():
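The rewritten backslash test relies on raw and escaped string literals denoting the same value, so only the spelling changed, not the behavior under test:

```python
# Both spellings are the same four-character string.
assert r"\d\e" == "\\d\\e"
assert list(r"\d\e") == ["\\", "d", "\\", "e"]
```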
22 changes: 12 additions & 10 deletions ibis/backends/pyspark/tests/test_ddl.py
@@ -9,6 +9,7 @@

import ibis
from ibis import util
from ibis.backends.pyspark import PYSPARK_33

from ibis.backends.tests.errors import PySparkAnalysisException
from ibis.tests.util import assert_equal

@@ -92,12 +93,13 @@

def test_create_empty_table(con, temp_table):
schema = ibis.schema(
[
("a", "string"),
("b", "timestamp"),
("c", "decimal(12, 8)"),
("d", "double"),
]
{
"a": "string",
"b": "timestamp('UTC')",
"c": "decimal(12, 8)",
"d": "double",
}
| ({"e": "timestamp"} if not PYSPARK_33 else {})
)

con.create_table(temp_table, schema=schema)
@@ -181,9 +183,9 @@


@pytest.mark.xfail_version(
pyspark=["pyspark<3.5"],
pyspark=["pyspark<3.4"],
raises=ValueError,
reason="PySparkAnalysisException is not available in PySpark <3.5",
reason="PySparkAnalysisException is not available in PySpark <3.4",
)
def test_create_database_exists(con):
con.create_database(dbname := util.gen_name("dbname"))
@@ -197,9 +199,9 @@


@pytest.mark.xfail_version(
pyspark=["pyspark<3.5"],
pyspark=["pyspark<3.4"],
raises=ValueError,
reason="PySparkAnalysisException is not available in PySpark <3.5",
reason="PySparkAnalysisException is not available in PySpark <3.4",
)
def test_drop_database_exists(con):
con.create_database(dbname := util.gen_name("dbname"))
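For context on the schema strings used in `test_create_empty_table`: `"timestamp('UTC')"` should parse to a timezone-aware ibis timestamp while plain `"timestamp"` stays naive, which is why the naive column is only included on PySpark >= 3.4. A hedged sketch (string-parsing behavior assumed from ibis's dtype syntax; `supports_ntz` stands in for `not PYSPARK_33`):

```python
import ibis
import ibis.expr.datatypes as dt

assert ibis.dtype("timestamp('UTC')") == dt.Timestamp(timezone="UTC")
assert ibis.dtype("timestamp") == dt.Timestamp()

supports_ntz = True  # stand-in for `not PYSPARK_33`
schema = ibis.schema(
    {"a": "string", "b": "timestamp('UTC')", "c": "decimal(12, 8)", "d": "double"}
    | ({"e": "timestamp"} if supports_ntz else {})
)
```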