Skip to content

Commit 3e06c79

Browse files
committed
fix(databricks/pyspark): unify timestamp/timestamp_ntz behavior
1 parent 66ab2a5 commit 3e06c79

File tree

4 files changed

+41
-7
lines changed

4 files changed

+41
-7
lines changed

ibis/backends/pyspark/datatypes.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from functools import partial
4+
35
import pyspark
46
import pyspark.sql.types as pt
57
from packaging.version import parse as vparse
@@ -27,9 +29,15 @@
2729
pt.NullType: dt.Null,
2830
pt.ShortType: dt.Int16,
2931
pt.StringType: dt.String,
30-
pt.TimestampType: dt.Timestamp,
3132
}
3233

34+
try:
35+
_from_pyspark_dtypes[pt.TimestampNTZType] = dt.Timestamp
36+
except AttributeError:
37+
_from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp
38+
else:
39+
_from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")
40+
3341
_to_pyspark_dtypes = {v: k for k, v in _from_pyspark_dtypes.items()}
3442
_to_pyspark_dtypes[dt.JSON] = pt.StringType
3543
_to_pyspark_dtypes[dt.UUID] = pt.StringType
@@ -54,9 +62,7 @@ def to_ibis(cls, typ, nullable=True):
5462
return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
5563
elif isinstance(typ, pt.MapType):
5664
return dt.Map(
57-
cls.to_ibis(typ.keyType),
58-
cls.to_ibis(typ.valueType),
59-
nullable=nullable,
65+
cls.to_ibis(typ.keyType), cls.to_ibis(typ.valueType), nullable=nullable
6066
)
6167
elif isinstance(typ, pt.StructType):
6268
fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}
@@ -102,6 +108,11 @@ def from_ibis(cls, dtype):
102108
for n, t in dtype.fields.items()
103109
]
104110
return pt.StructType(fields)
111+
elif dtype.is_timestamp():
112+
if dtype.timezone is not None:
113+
return pt.TimestampType()
114+
else:
115+
return pt.TimestampNTZType()
105116
else:
106117
try:
107118
return _to_pyspark_dtypes[type(dtype)]()

ibis/backends/pyspark/tests/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,16 @@ def _load_data(self, **_: Any) -> None:
337337

338338
for name, schema in TEST_TABLES.items():
339339
path = str(self.data_dir / "directory" / "parquet" / name)
340+
sch = ibis.schema(
341+
{
342+
col: dtype.copy(timezone="UTC")
343+
if dtype.is_timestamp()
344+
else dtype
345+
for col, dtype in schema.items()
346+
}
347+
)
340348
t = (
341-
s.readStream.schema(PySparkSchema.from_ibis(schema))
349+
s.readStream.schema(PySparkSchema.from_ibis(sch))
342350
.parquet(path)
343351
.repartition(num_partitions)
344352
)

ibis/backends/pyspark/tests/test_window.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_time_indexed_window(t, spark_table, ibis_windows, spark_range):
5050
F.mean(spark_table["value"]).over(spark_window),
5151
).toPandas()
5252

53-
tm.assert_frame_equal(result, expected)
53+
tm.assert_frame_equal(result, expected, check_dtype=False)
5454

5555

5656
@pytest.mark.parametrize(
@@ -90,7 +90,7 @@ def test_multiple_windows(t, spark_table, ibis_windows, spark_range):
9090
)
9191
.toPandas()
9292
)
93-
tm.assert_frame_equal(result, expected)
93+
tm.assert_frame_equal(result, expected, check_dtype=False)
9494

9595

9696
def test_tumble_window_by_grouped_agg(con_streaming, tmp_path):

ibis/backends/tests/test_temporal.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from ibis.common.annotations import ValidationError
4444
from ibis.conftest import IS_SPARK_REMOTE
45+
from ibis.util import gen_name
4546

4647
np = pytest.importorskip("numpy")
4748
pd = pytest.importorskip("pandas")
@@ -2345,3 +2346,17 @@ def test_simple_unix_date_offset(con):
23452346
result = con.execute(expr)
23462347
delta = datetime.date(2023, 4, 7) - datetime.date(1970, 1, 1)
23472348
assert result == delta.days
2349+
2350+
2351+
def test_basic_timestamp_with_timezone(con):
2352+
name = gen_name("tmp_tz")
2353+
ts = "2023-01-07 13:20:05.561021"
2354+
dtype = dt.Timestamp(timezone="UTC")
2355+
colname = "ts"
2356+
result = con.create_table(
2357+
name, ibis.timestamp(ts).cast(dtype).name(colname).as_table()
2358+
)
2359+
try:
2360+
assert result.schema() == ibis.schema({colname: dtype})
2361+
finally:
2362+
con.drop_table(name, force=True)

0 commit comments

Comments
 (0)