Skip to content

Commit 25fab78

Browse files
authored
fix: Fix rolling aggregations for various integer types (#20512)
1 parent 4c14e70 commit 25fab78

File tree

6 files changed

+29
-7
lines changed

6 files changed

+29
-7
lines changed

Diff for: crates/polars-core/src/frame/group_by/aggregations/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ impl_take_extremum!(i8);
286286
impl_take_extremum!(i16);
287287
impl_take_extremum!(i32);
288288
impl_take_extremum!(i64);
289-
#[cfg(feature = "dtype-decimal")]
289+
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
290290
impl_take_extremum!(i128);
291291
impl_take_extremum!(float: f32);
292292
impl_take_extremum!(float: f64);

Diff for: crates/polars-core/src/hashing/vector_hasher.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ vec_hash_numeric!(UInt16Chunked);
167167
vec_hash_numeric!(UInt8Chunked);
168168
vec_hash_numeric!(Float64Chunked);
169169
vec_hash_numeric!(Float32Chunked);
170-
#[cfg(feature = "dtype-decimal")]
170+
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
171171
vec_hash_numeric!(Int128Chunked);
172172

173173
impl VecHash for StringChunked {

Diff for: crates/polars-expr/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dtype-full = [
4949
"dtype-decimal",
5050
"dtype-duration",
5151
"dtype-i16",
52+
"dtype-i128",
5253
"dtype-i8",
5354
"dtype-struct",
5455
"dtype-time",

Diff for: crates/polars-lazy/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ dtype-full = [
9797
"dtype-decimal",
9898
"dtype-duration",
9999
"dtype-i16",
100+
"dtype-i128",
100101
"dtype-i8",
101102
"dtype-struct",
102103
"dtype-time",
@@ -144,6 +145,7 @@ dtype-duration = [
144145
"polars-mem-engine/dtype-duration",
145146
]
146147
dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe?/dtype-i16", "polars-expr/dtype-i16", "polars-mem-engine/dtype-i16"]
148+
dtype-i128 = ["polars-plan/dtype-i128", "polars-pipe?/dtype-i128", "polars-expr/dtype-i128"]
147149
dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe?/dtype-i8", "polars-expr/dtype-i8", "polars-mem-engine/dtype-i8"]
148150
dtype-struct = [
149151
"polars-plan/dtype-struct",

Diff for: crates/polars/Cargo.toml

+11
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ dtype-full = [
265265
"dtype-array",
266266
"dtype-i8",
267267
"dtype-i16",
268+
"dtype-i128",
268269
"dtype-decimal",
269270
"dtype-u8",
270271
"dtype-u16",
@@ -318,12 +319,20 @@ dtype-i8 = [
318319
"polars-io/dtype-i8",
319320
"polars-lazy?/dtype-i8",
320321
"polars-ops/dtype-i8",
322+
"polars-time?/dtype-i8",
321323
]
322324
dtype-i16 = [
323325
"polars-core/dtype-i16",
324326
"polars-io/dtype-i16",
325327
"polars-lazy?/dtype-i16",
326328
"polars-ops/dtype-i16",
329+
"polars-time?/dtype-i16",
330+
]
331+
dtype-i128 = [
332+
"polars-core/dtype-i128",
333+
"polars-lazy?/dtype-i128",
334+
"polars-ops/dtype-i128",
335+
"polars-time?/dtype-i128",
327336
]
328337
dtype-decimal = [
329338
"polars-core/dtype-decimal",
@@ -337,12 +346,14 @@ dtype-u8 = [
337346
"polars-io/dtype-u8",
338347
"polars-lazy?/dtype-u8",
339348
"polars-ops/dtype-u8",
349+
"polars-time?/dtype-u8",
340350
]
341351
dtype-u16 = [
342352
"polars-core/dtype-u16",
343353
"polars-io/dtype-u16",
344354
"polars-lazy?/dtype-u16",
345355
"polars-ops/dtype-u16",
356+
"polars-time?/dtype-u16",
346357
]
347358
dtype-categorical = [
348359
"polars-core/dtype-categorical",

Diff for: py-polars/tests/unit/operations/rolling/test_rolling.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from polars.testing import assert_frame_equal, assert_series_equal
1616
from polars.testing.parametric import column, dataframes
1717
from polars.testing.parametric.strategies.dtype import _time_units
18+
from tests.unit.conftest import INTEGER_DTYPES
1819

1920
if TYPE_CHECKING:
2021
from hypothesis.strategies import SearchStrategy
@@ -739,11 +740,18 @@ def test_rolling_aggregations_with_over_11225() -> None:
739740
assert_frame_equal(result, expected)
740741

741742

742-
def test_rolling() -> None:
743-
s = pl.Series("a", [1, 2, 3, 2, 1])
744-
assert_series_equal(s.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1]))
745-
assert_series_equal(s.rolling_max(2), pl.Series("a", [None, 2, 3, 3, 2]))
746-
assert_series_equal(s.rolling_sum(2), pl.Series("a", [None, 3, 5, 5, 3]))
743+
@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
744+
def test_rolling(dtype: PolarsDataType) -> None:
745+
s = pl.Series("a", [1, 2, 3, 2, 1], dtype=dtype)
746+
assert_series_equal(
747+
s.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1], dtype=dtype)
748+
)
749+
assert_series_equal(
750+
s.rolling_max(2), pl.Series("a", [None, 2, 3, 3, 2], dtype=dtype)
751+
)
752+
assert_series_equal(
753+
s.rolling_sum(2), pl.Series("a", [None, 3, 5, 5, 3], dtype=dtype)
754+
)
747755
assert_series_equal(s.rolling_mean(2), pl.Series("a", [None, 1.5, 2.5, 2.5, 1.5]))
748756

749757
assert s.rolling_std(2).to_list()[1] == pytest.approx(0.7071067811865476)

0 commit comments

Comments
 (0)