
Commit 479f888

fix(databricks/pyspark): unify timestamp/timestamp_ntz behavior
1 parent 66ab2a5

File tree

6 files changed: +82, -19 lines

ibis/backends/polars/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any)
 
 @_read_in_memory.register("ibis.expr.types.Table")
 def _table(source, table_name, _conn, **kwargs: Any):
-    _conn._add_table(table_name, source.to_polars())
+    _conn._add_table(table_name, _conn.to_polars(source))
 
 
 @_read_in_memory.register("polars.DataFrame")
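The one-line change matters for dispatch: the in-memory table is now converted by the connection doing the registering, rather than by whatever backend the source expression happens to be bound to. A minimal sketch of the two call styles (the memtable contents are illustrative):

import ibis

con = ibis.polars.connect()
t = ibis.memtable({"a": [1, 2, 3]})

# Old path: `t.to_polars()` resolves `t` against its own (default) backend.
# New path: the registering connection performs the conversion itself.
df = con.to_polars(t)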

ibis/backends/pyspark/datatypes.py

Lines changed: 28 additions & 15 deletions
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from functools import partial
+from inspect import isclass
+
 import pyspark
 import pyspark.sql.types as pt
 from packaging.version import parse as vparse
@@ -27,10 +30,20 @@
     pt.NullType: dt.Null,
     pt.ShortType: dt.Int16,
     pt.StringType: dt.String,
-    pt.TimestampType: dt.Timestamp,
 }
 
-_to_pyspark_dtypes = {v: k for k, v in _from_pyspark_dtypes.items()}
+try:
+    _from_pyspark_dtypes[pt.TimestampNTZType] = dt.Timestamp
+except AttributeError:
+    _from_pyspark_dtypes[pt.TimestampType] = dt.Timestamp
+else:
+    _from_pyspark_dtypes[pt.TimestampType] = partial(dt.Timestamp, timezone="UTC")
+
+_to_pyspark_dtypes = {
+    v: k
+    for k, v in _from_pyspark_dtypes.items()
+    if isclass(v) and not issubclass(v, dt.Timestamp) and not isinstance(v, partial)
+}
 _to_pyspark_dtypes[dt.JSON] = pt.StringType
 _to_pyspark_dtypes[dt.UUID] = pt.StringType

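The try/except is a feature probe: `pt.TimestampNTZType` only exists on newer PySpark versions. Where it exists, a plain Spark `TimestampType` is read back as a UTC-aware ibis timestamp and `TimestampNTZType` as a naive one; on older PySpark, `TimestampType` stays naive. The reverse map then filters out both timestamp entries (hence the `isclass`/`partial` checks), since timestamps get an explicit branch in `from_ibis` below. A sketch of the resulting behavior, assuming a PySpark build that has `TimestampNTZType`:

import pyspark.sql.types as pt

import ibis.expr.datatypes as dt
from ibis.backends.pyspark.datatypes import PySparkType

# Spark -> ibis: plain TimestampType is now interpreted as UTC-aware.
assert PySparkType.to_ibis(pt.TimestampType()) == dt.Timestamp(timezone="UTC")
assert PySparkType.to_ibis(pt.TimestampNTZType()) == dt.Timestamp()

# ibis -> Spark: the timezone decides which Spark type is emitted.
assert isinstance(PySparkType.from_ibis(dt.Timestamp(timezone="UTC")), pt.TimestampType)
assert isinstance(PySparkType.from_ibis(dt.Timestamp()), pt.TimestampNTZType)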
@@ -54,9 +67,7 @@ def to_ibis(cls, typ, nullable=True):
             return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
         elif isinstance(typ, pt.MapType):
             return dt.Map(
-                cls.to_ibis(typ.keyType),
-                cls.to_ibis(typ.valueType),
-                nullable=nullable,
+                cls.to_ibis(typ.keyType), cls.to_ibis(typ.valueType), nullable=nullable
             )
         elif isinstance(typ, pt.StructType):
             fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}
@@ -97,11 +108,17 @@ def from_ibis(cls, dtype):
             value_contains_null = dtype.value_type.nullable
             return pt.MapType(key_type, value_type, value_contains_null)
         elif dtype.is_struct():
-            fields = [
-                pt.StructField(n, cls.from_ibis(t), t.nullable)
-                for n, t in dtype.fields.items()
-            ]
-            return pt.StructType(fields)
+            return pt.StructType(
+                [
+                    pt.StructField(field, cls.from_ibis(dtype), dtype.nullable)
+                    for field, dtype in dtype.fields.items()
+                ]
+            )
+        elif dtype.is_timestamp():
+            if dtype.timezone is not None:
+                return pt.TimestampType()
+            else:
+                return pt.TimestampNTZType()
         else:
             try:
                 return _to_pyspark_dtypes[type(dtype)]()
@@ -114,11 +131,7 @@ def from_ibis(cls, dtype):
 class PySparkSchema(SchemaMapper):
     @classmethod
     def from_ibis(cls, schema):
-        fields = [
-            pt.StructField(name, PySparkType.from_ibis(dtype), dtype.nullable)
-            for name, dtype in schema.items()
-        ]
-        return pt.StructType(fields)
+        return PySparkType.from_ibis(schema.as_struct())
 
     @classmethod
     def to_ibis(cls, schema):
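Since an ibis schema is structurally a struct type, the schema mapper can now delegate entirely to the type mapper via `Schema.as_struct()`, so the timestamp logic above applies to top-level columns and nested struct fields alike. A small sketch (column names illustrative):

import ibis
from ibis.backends.pyspark.datatypes import PySparkSchema

sch = ibis.schema({"id": "int64", "ts": "timestamp('UTC')", "naive": "timestamp"})

# One code path: the schema is viewed as a single struct dtype, so "ts"
# becomes TimestampType and "naive" becomes TimestampNTZType in the result.
spark_schema = PySparkSchema.from_ibis(sch)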

ibis/backends/pyspark/tests/conftest.py

Lines changed: 9 additions & 1 deletion
@@ -337,8 +337,16 @@ def _load_data(self, **_: Any) -> None:
 
         for name, schema in TEST_TABLES.items():
             path = str(self.data_dir / "directory" / "parquet" / name)
+            sch = ibis.schema(
+                {
+                    col: dtype.copy(timezone="UTC")
+                    if dtype.is_timestamp()
+                    else dtype
+                    for col, dtype in schema.items()
+                }
+            )
             t = (
-                s.readStream.schema(PySparkSchema.from_ibis(schema))
+                s.readStream.schema(PySparkSchema.from_ibis(sch))
                 .parquet(path)
                 .repartition(num_partitions)
            )
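The fixture now coerces every naive timestamp column in the test schemas to UTC before building the streaming reader, matching the new convention that plain Spark timestamps are UTC-aware. The `dtype.copy(timezone=...)` idiom in isolation, as a self-contained sketch:

import ibis
import ibis.expr.datatypes as dt

schema = ibis.schema({"id": "int64", "created_at": "timestamp"})

# Rewrite naive timestamp columns as UTC-aware; leave everything else alone.
utc_schema = ibis.schema(
    {
        col: dtype.copy(timezone="UTC") if dtype.is_timestamp() else dtype
        for col, dtype in schema.items()
    }
)
assert utc_schema["created_at"] == dt.Timestamp(timezone="UTC")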

ibis/backends/pyspark/tests/test_window.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def test_time_indexed_window(t, spark_table, ibis_windows, spark_range):
         F.mean(spark_table["value"]).over(spark_window),
     ).toPandas()
 
-    tm.assert_frame_equal(result, expected)
+    tm.assert_frame_equal(result, expected, check_dtype=False)
 
 
 @pytest.mark.parametrize(
@@ -90,7 +90,7 @@ def test_multiple_windows(t, spark_table, ibis_windows, spark_range):
         )
         .toPandas()
     )
-    tm.assert_frame_equal(result, expected)
+    tm.assert_frame_equal(result, expected, check_dtype=False)
 
 
 def test_tumble_window_by_grouped_agg(con_streaming, tmp_path):
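`check_dtype=False` relaxes only the dtype comparison, not the value comparison; it is likely needed here because the unified timestamp handling changes the dtype of frames coming back from Spark while the values still agree. The flag in isolation:

import pandas as pd
import pandas.testing as tm

a = pd.DataFrame({"x": pd.array([1.0, 2.0], dtype="float32")})
b = pd.DataFrame({"x": pd.array([1.0, 2.0], dtype="float64")})

# Fails under the default check_dtype=True; passes when only values matter.
tm.assert_frame_equal(a, b, check_dtype=False)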

ibis/backends/sql/datatypes.py

Lines changed: 9 additions & 0 deletions
@@ -1355,6 +1355,15 @@ def _from_sqlglot_VARIANT(cls, nullable: bool | None = None) -> sge.DataType:
 
     _from_sqlglot_JSON = _from_sqlglot_VARIANT
 
+    @classmethod
+    def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
+        code = typecode.TIMESTAMPNTZ if dtype.timezone is None else typecode.TIMESTAMPTZ
+        if dtype.scale is not None:
+            scale = sge.DataTypeParam(this=sge.Literal.number(dtype.scale))
+            return sge.DataType(this=code, expressions=[scale])
+        else:
+            return sge.DataType(this=code)
+
 
 class AthenaType(SqlglotType):
     dialect = "athena"
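The new `_from_ibis_Timestamp` picks `TIMESTAMPNTZ` for naive timestamps and `TIMESTAMPTZ` for zoned ones, attaching the scale as a type parameter when present. Assuming `typecode` is an alias for `sqlglot.expressions.DataType.Type` (as used elsewhere in this module), an equivalent standalone construction looks like:

import sqlglot.expressions as sge

typecode = sge.DataType.Type

# ibis timestamp('UTC', 6) -> TIMESTAMPTZ with an explicit precision param;
# a naive, unscaled timestamp would be a bare DataType(this=typecode.TIMESTAMPNTZ).
scale = sge.DataTypeParam(this=sge.Literal.number(6))
typ = sge.DataType(this=typecode.TIMESTAMPTZ, expressions=[scale])
print(typ.sql())  # dialect-specific rendering, e.g. typ.sql("duckdb")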

ibis/backends/tests/test_temporal.py

Lines changed: 33 additions & 0 deletions
@@ -42,6 +42,7 @@
 )
 from ibis.common.annotations import ValidationError
 from ibis.conftest import IS_SPARK_REMOTE
+from ibis.util import gen_name
 
 np = pytest.importorskip("numpy")
 pd = pytest.importorskip("pandas")
@@ -2345,3 +2346,35 @@ def test_simple_unix_date_offset(con):
     result = con.execute(expr)
     delta = datetime.date(2023, 4, 7) - datetime.date(1970, 1, 1)
     assert result == delta.days
+
+
+timestamp_with_timezone_params = {
+    "clickhouse": ("UTC", 0),
+    "datafusion": ("+00:00", 9),
+    "duckdb": ("UTC", 6),
+    "impala": (None, None),
+    "oracle": ("UTC", 6),
+    "trino": ("UTC", 3),
+}
+
+
+@pytest.mark.notyet(
+    ["druid"],
+    raises=NotImplementedError,
+    reason="druid doesn't implement `create_table`",
+)
+def test_basic_timestamp_with_timezone(con):
+    name = gen_name("tmp_tz")
+    ts = "2023-01-07 13:20:05.561021"
+    dtype = dt.Timestamp(timezone="UTC")
+    colname = "ts"
+    timezone, scale = timestamp_with_timezone_params.get(con.name, ("UTC", None))
+    result = con.create_table(
+        name, ibis.timestamp(ts).cast(dtype).name(colname).as_table()
+    )
+    try:
+        assert result.schema() == ibis.schema(
+            {colname: dtype.copy(timezone=timezone, scale=scale)}
+        )
+    finally:
+        con.drop_table(name, force=True)
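The params table encodes how each backend normalizes a UTC timestamp on round-trip: DuckDB reports ("UTC", 6) (microseconds), Trino ("UTC", 3) (milliseconds), and Impala drops the zone entirely. A standalone version of the round-trip check against DuckDB (the table name here is hypothetical):

import ibis
import ibis.expr.datatypes as dt

con = ibis.duckdb.connect()
expr = (
    ibis.timestamp("2023-01-07 13:20:05.561021")
    .cast(dt.Timestamp(timezone="UTC"))
    .name("ts")
    .as_table()
)
t = con.create_table("tmp_tz_demo", expr)
try:
    # DuckDB stores TIMESTAMPTZ at microsecond precision.
    assert t.schema() == ibis.schema({"ts": dt.Timestamp(timezone="UTC", scale=6)})
finally:
    con.drop_table("tmp_tz_demo", force=True)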
