diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index b1649cbfbaa..ae11d288f51 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -113,6 +113,7 @@ def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: Name.IsoYear, Name.MonthStart, Name.MonthEnd, + Name.CastTimeUnit, } def __init__( @@ -148,6 +149,14 @@ def do_evaluate( for child in self.children ] (column,) = columns + if self.name is TemporalFunction.Name.CastTimeUnit: + (unit,) = self.options + if plc.traits.is_timestamp(column.obj.type()): + dtype = plc.interop.from_arrow(pa.timestamp(unit)) + elif plc.traits.is_duration(column.obj.type()): + dtype = plc.interop.from_arrow(pa.duration(unit)) + result = plc.unary.cast(column.obj, dtype) + return Column(result) if self.name == TemporalFunction.Name.ToString: return Column( plc.strings.convert.convert_datetime.from_timestamps( diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py index 187e51fe9af..0c7d5c66e0a 100644 --- a/python/cudf_polars/tests/expressions/test_datetime_basic.py +++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py @@ -296,3 +296,48 @@ def test_isoyear(): q = df.with_columns(pl.col("date").dt.iso_year().alias("isoyear")) assert_gpu_result_equal(q) + + +@pytest.mark.parametrize( + "dtype", [pl.Date(), pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")] +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_cast_time_unit_datetime(dtype, time_unit): + sr = pl.Series( + "date", + [ + datetime.datetime(1970, 1, 1, 0, 0, 0), + datetime.datetime(1999, 12, 31, 23, 59, 59), + datetime.datetime(2001, 1, 1, 12, 0, 0), + datetime.datetime(2020, 2, 29, 23, 59, 59), + datetime.datetime(2024, 12, 31, 23, 59, 59, 999999), + ], + dtype=dtype, + ) + df = pl.DataFrame({"date": sr}).lazy() + + q = df.select(pl.col("date").dt.cast_time_unit(time_unit).alias("time_unit_ms")) + + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize( + "dtype", [pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")] +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_cast_time_unit_duration(dtype, time_unit): + sr = pl.Series( + "date", + [ + datetime.timedelta(days=1), + datetime.timedelta(days=2), + datetime.timedelta(days=3), + datetime.timedelta(days=4), + datetime.timedelta(days=5), + ], + dtype=dtype, + ) + df = pl.DataFrame({"date": sr}).lazy() + + q = df.select(pl.col("date").dt.cast_time_unit(time_unit).alias("time_unit_ms")) + assert_gpu_result_equal(q)