Skip to content

Commit db580af

Browse files
authored
fix(rust, python): correct invalid type in struct anyvalue access (#5844)
1 parent 36c72c8 commit db580af

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

polars/polars-core/src/chunked_array/ops/any_value.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,16 @@ impl<'a> AnyValue<'a> {
112112
AnyValue::Struct(idx, arr, flds) => {
113113
let idx = *idx;
114114
unsafe {
115-
arr.values()
116-
.iter()
117-
.zip(*flds)
118-
.map(move |(arr, fld)| arr_to_any_value(&**arr, idx, fld.data_type()))
115+
arr.values().iter().zip(*flds).map(move |(arr, fld)| {
116+
// TODO! this is hacky. Investigate whether we should only put physical types
117+
// into structs
118+
if let Some(arr) = arr.as_any().downcast_ref::<DictionaryArray<u32>>() {
119+
let keys = arr.keys();
120+
arr_to_any_value(keys, idx, fld.data_type())
121+
} else {
122+
arr_to_any_value(&**arr, idx, fld.data_type())
123+
}
124+
})
119125
}
120126
}
121127
_ => unreachable!(),

polars/polars-ops/src/series/ops/various.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub trait SeriesMethods: SeriesSealed {
1313
let groups = s.group_tuples(multithreaded, sorted)?;
1414
let values = unsafe { s.agg_first(&groups) };
1515
let counts = groups.group_lengths("counts");
16-
let cols = vec![values.into_series(), counts.into_series()];
16+
let cols = vec![values, counts.into_series()];
1717
let df = DataFrame::new_no_checks(cols);
1818
if sorted {
1919
df.sort(["counts"], true)

py-polars/tests/unit/test_struct.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,17 @@ def test_struct_any_value_get_after_append() -> None:
704704
a = a.append(b)
705705
assert a[0] == {"a": 1, "b": 2}
706706
assert a[1] == {"a": 2, "b": 3}
707+
708+
709+
def test_struct_categorical_5843() -> None:
710+
df = pl.DataFrame({"foo": ["a", "b", "c", "a"]}).with_column(
711+
pl.col("foo").cast(pl.Categorical)
712+
)
713+
result = df.select(pl.col("foo").value_counts(sort=True))
714+
assert result.to_dict(False) == {
715+
"foo": [
716+
{"foo": "a", "counts": 2},
717+
{"foo": "b", "counts": 1},
718+
{"foo": "c", "counts": 1},
719+
]
720+
}

0 commit comments

Comments
 (0)