Skip to content

Commit b5f72dd

Browse files
fix truncate for timezone-aware timestamps for DuckDB (#2577)
* fix `truncate` for timezone-aware timestamps for DuckDB
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 22da333 commit b5f72dd

File tree

3 files changed: +34 −8 lines changed

narwhals/_duckdb/expr_dt.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from datetime import datetime
43
from typing import TYPE_CHECKING
54

65
from duckdb import FunctionExpression
@@ -10,6 +9,8 @@
109
from narwhals.utils import not_implemented
1110

1211
if TYPE_CHECKING:
12+
from duckdb import Expression
13+
1314
from narwhals._duckdb.expr import DuckDBExpr
1415

1516

@@ -109,15 +110,19 @@ def total_microseconds(self) -> DuckDBExpr:
109110

110111
def truncate(self, every: str) -> DuckDBExpr:
111112
multiple, unit = parse_interval_string(every)
113+
if multiple != 1:
114+
# https://github.com/duckdb/duckdb/issues/17554
115+
msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
116+
raise ValueError(msg)
112117
if unit == "ns":
113118
msg = "Truncating to nanoseconds is not yet supported for DuckDB."
114119
raise NotImplementedError(msg)
115-
every = f"{multiple!s} {UNITS_DICT[unit]}"
116-
return self._compliant_expr._with_callable(
117-
lambda expr: FunctionExpression(
118-
"time_bucket", lit(every), expr, lit(datetime(1970, 1, 1))
119-
)
120-
)
120+
format = lit(UNITS_DICT[unit])
121+
122+
def _truncate(expr: Expression) -> Expression:
123+
return FunctionExpression("date_trunc", format, expr)
124+
125+
return self._compliant_expr._with_callable(_truncate)
121126

122127
def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
123128
if time_zone is None:

tests/expr_and_series/dt/truncate_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_truncate_multiples(
108108
every: str,
109109
expected: list[datetime],
110110
) -> None:
111-
if any(x in str(constructor) for x in ("sqlframe", "cudf", "pyspark")):
111+
if any(x in str(constructor) for x in ("sqlframe", "cudf", "pyspark", "duckdb")):
112112
# Reasons:
113113
# - sqlframe: https://github.com/eakmanrq/sqlframe/issues/383
114114
# - cudf: https://github.com/rapidsai/cudf/issues/18654
@@ -194,3 +194,20 @@ def test_pandas_numpy_nat() -> None:
194194
expected = {"a": [datetime(2020, 1, 1), None, datetime(2020, 1, 1)]}
195195
assert_equal_data(result, expected)
196196
assert result.item(1, 0) is pd.NaT
197+
198+
199+
def test_truncate_tz_aware_duckdb() -> None:
200+
pytest.importorskip("duckdb")
201+
pytest.importorskip("zoneinfo")
202+
import duckdb
203+
from zoneinfo import ZoneInfo
204+
205+
duckdb.sql("""set timezone = 'Europe/Amsterdam'""")
206+
rel = duckdb.sql("""select * from values (timestamptz '2020-10-25') df(a)""")
207+
result = nw.from_native(rel).with_columns(a_truncated=nw.col("a").dt.truncate("1mo"))
208+
expected = {
209+
"a": [datetime(2020, 10, 25, tzinfo=ZoneInfo("Europe/Amsterdam"))],
210+
"a_truncated": [datetime(2020, 10, 1, tzinfo=ZoneInfo("Europe/Amsterdam"))],
211+
}
212+
assert_equal_data(result, expected)
213+
duckdb.sql("""set timezone = 'UTC'""")

tests/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import warnings
7+
from datetime import date, datetime
78
from pathlib import Path
89
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence
910

@@ -116,8 +117,11 @@ def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None:
116117
)
117118
elif pd.isna(lhs):
118119
are_equivalent_values = pd.isna(rhs)
120+
elif type(lhs) is date and type(rhs) is datetime:
121+
are_equivalent_values = datetime(lhs.year, lhs.month, lhs.day) == rhs
119122
else:
120123
are_equivalent_values = lhs == rhs
124+
121125
assert are_equivalent_values, (
122126
f"Mismatch at index {i}: {lhs} != {rhs}\nExpected: {expected}\nGot: {result}"
123127
)

0 commit comments

Comments
 (0)