Skip to content

Commit db0c741

Browse files
authored
fix(rust, python): fix categorical in struct anyvalue issue (#5987)
1 parent 1b04b36 commit db0c741

File tree

11 files changed

+114
-21
lines changed

11 files changed

+114
-21
lines changed

polars/polars-core/src/chunked_array/logical/categorical/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod stringcache;
77
pub use builder::*;
88
pub(crate) use merge::*;
99
pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
10+
use polars_utils::sync::SyncPtr;
1011

1112
use super::*;
1213
use crate::prelude::*;
@@ -147,7 +148,7 @@ impl LogicalType for CategoricalChunked {
147148

148149
unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
149150
match self.logical.0.get_unchecked(i) {
150-
Some(i) => AnyValue::Categorical(i, self.get_rev_map()),
151+
Some(i) => AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()),
151152
None => AnyValue::Null,
152153
}
153154
}
@@ -295,7 +296,7 @@ mod test {
295296
);
296297
assert!(matches!(
297298
s.get(0)?,
298-
AnyValue::Categorical(0, RevMapping::Local(_))
299+
AnyValue::Categorical(0, RevMapping::Local(_), _)
299300
));
300301

301302
let groups = s.group_tuples(false, true);

polars/polars-core/src/chunked_array/ops/any_value.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::convert::TryFrom;
22

3+
#[cfg(feature = "dtype-categorical")]
4+
use polars_utils::sync::SyncPtr;
5+
36
#[cfg(feature = "object")]
47
use crate::chunked_array::object::extension::polars_extension::PolarsExtension;
58
use crate::prelude::*;
@@ -70,7 +73,7 @@ pub(crate) unsafe fn arr_to_any_value<'a>(
7073
DataType::Categorical(rev_map) => {
7174
let arr = &*(arr as *const dyn Array as *const UInt32Array);
7275
let v = arr.value_unchecked(idx);
73-
AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref())
76+
AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null())
7477
}
7578
#[cfg(feature = "dtype-struct")]
7679
DataType::Struct(flds) => {
@@ -120,12 +123,28 @@ impl<'a> AnyValue<'a> {
120123
let idx = *idx;
121124
unsafe {
122125
arr.values().iter().zip(*flds).map(move |(arr, fld)| {
123-
// TODO! this is hacky. Investigate if we only should put physical types
124-
// into structs
125-
if let Some(arr) = arr.as_any().downcast_ref::<DictionaryArray<u32>>() {
126-
let keys = arr.keys();
127-
arr_to_any_value(keys, idx, fld.data_type())
128-
} else {
126+
// The dictionary array's categories don't have to map to the rev-map in the dtype,
127+
// so we set the array pointer to the values of the dictionary array.
128+
#[cfg(feature = "dtype-categorical")]
129+
{
130+
if let Some(arr) = arr.as_any().downcast_ref::<DictionaryArray<u32>>() {
131+
let keys = arr.keys();
132+
let values = arr.values();
133+
let values =
134+
values.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
135+
let arr = &*(keys as *const dyn Array as *const UInt32Array);
136+
let v = arr.value_unchecked(idx);
137+
let DataType::Categorical(Some(rev_map)) = fld.data_type() else {
138+
unimplemented!()
139+
};
140+
AnyValue::Categorical(v, rev_map, SyncPtr::from_const(values))
141+
} else {
142+
arr_to_any_value(&**arr, idx, fld.data_type())
143+
}
144+
}
145+
146+
#[cfg(not(feature = "dtype-categorical"))]
147+
{
129148
arr_to_any_value(&**arr, idx, fld.data_type())
130149
}
131150
})

polars/polars-core/src/datatypes/any_value.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use arrow::types::PrimitiveType;
2+
#[cfg(feature = "dtype-categorical")]
3+
use polars_utils::sync::SyncPtr;
24
use polars_utils::unwrap::UnwrapUncheckedRelease;
35

46
use super::*;
@@ -56,7 +58,9 @@ pub enum AnyValue<'a> {
5658
#[cfg(feature = "dtype-time")]
5759
Time(i64),
5860
#[cfg(feature = "dtype-categorical")]
59-
Categorical(u32, &'a RevMapping),
61+
// If the `SyncPtr` is null, the data is in the rev-map;
62+
// otherwise it is in the array the pointer refers to.
63+
Categorical(u32, &'a RevMapping, SyncPtr<Utf8Array<i64>>),
6064
/// Nested type, contains arrays that are filled with one of the datatypes.
6165
List(Series),
6266
#[cfg(feature = "object")]
@@ -357,7 +361,7 @@ impl<'a> AnyValue<'a> {
357361
Boolean(_) => DataType::Boolean,
358362
Utf8(_) => DataType::Utf8,
359363
#[cfg(feature = "dtype-categorical")]
360-
Categorical(_, _) => DataType::Categorical(None),
364+
Categorical(_, _, _) => DataType::Categorical(None),
361365
List(s) => DataType::List(Box::new(s.dtype().clone())),
362366
#[cfg(feature = "dtype-struct")]
363367
Struct(_, _, fields) => DataType::Struct(fields.to_vec()),
@@ -616,7 +620,7 @@ impl PartialEq for AnyValue<'_> {
616620
// should it?
617621
(Null, Null) => true,
618622
#[cfg(feature = "dtype-categorical")]
619-
(Categorical(idx_l, rev_l), Categorical(idx_r, rev_r)) => match (rev_l, rev_r) {
623+
(Categorical(idx_l, rev_l, _), Categorical(idx_r, rev_r, _)) => match (rev_l, rev_r) {
620624
(RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => {
621625
id_l == id_r && idx_l == idx_r
622626
}

polars/polars-core/src/fmt.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,12 @@ impl Display for AnyValue<'_> {
746746
write!(f, "{nt}")
747747
}
748748
#[cfg(feature = "dtype-categorical")]
749-
AnyValue::Categorical(idx, rev) => {
750-
let s = rev.get(*idx);
749+
AnyValue::Categorical(idx, rev, arr) => {
750+
let s = if arr.is_null() {
751+
rev.get(*idx)
752+
} else {
753+
unsafe { arr.deref_unchecked().value(*idx as usize) }
754+
};
751755
write!(f, "\"{s}\"")
752756
}
753757
AnyValue::List(s) => write!(f, "{}", s.fmt_list()),

polars/polars-core/src/series/any_value.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,15 @@ impl<'a> From<&AnyValue<'a>> for DataType {
237237
Int8(_) => DataType::Int8,
238238
Int16(_) => DataType::Int16,
239239
#[cfg(feature = "dtype-categorical")]
240-
Categorical(_, rev_map) => DataType::Categorical(Some(Arc::new((*rev_map).clone()))),
240+
Categorical(_, rev_map, arr) => {
241+
if arr.is_null() {
242+
DataType::Categorical(Some(Arc::new((*rev_map).clone())))
243+
} else {
244+
let array = unsafe { arr.deref_unchecked().clone() };
245+
let rev_map = RevMapping::Local(array);
246+
DataType::Categorical(Some(Arc::new(rev_map)))
247+
}
248+
}
241249
#[cfg(feature = "object")]
242250
Object(o) => DataType::Object(o.type_name()),
243251
#[cfg(feature = "object")]

polars/polars-core/src/series/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,13 @@ impl Series {
833833
AnyValue::Utf8(s) => Cow::Borrowed(s),
834834
AnyValue::Null => Cow::Borrowed("null"),
835835
#[cfg(feature = "dtype-categorical")]
836-
AnyValue::Categorical(idx, rev) => Cow::Borrowed(rev.get(idx)),
836+
AnyValue::Categorical(idx, rev, arr) => {
837+
if arr.is_null() {
838+
Cow::Borrowed(rev.get(idx))
839+
} else {
840+
unsafe { Cow::Borrowed(arr.deref_unchecked().value(idx as usize)) }
841+
}
842+
}
837843
av => Cow::Owned(format!("{av}")),
838844
};
839845
Ok(out)

polars/polars-io/src/csv/write_impl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
6969
AnyValue::Boolean(v) => write!(f, "{v}"),
7070
AnyValue::Utf8(v) => fmt_and_escape_str(f, v, options),
7171
#[cfg(feature = "dtype-categorical")]
72-
AnyValue::Categorical(idx, rev_map) => {
72+
AnyValue::Categorical(idx, rev_map, _) => {
7373
let v = rev_map.get(idx);
7474
fmt_and_escape_str(f, v, options)
7575
}

polars/polars-lazy/polars-plan/src/logical_plan/lit.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,17 @@ impl TryFrom<AnyValue<'_>> for LiteralValue {
201201
AnyValue::List(l) => Ok(Self::Series(SpecialEq::new(l))),
202202
AnyValue::Utf8Owned(o) => Ok(Self::Utf8(o.into())),
203203
#[cfg(feature = "dtype-categorical")]
204-
AnyValue::Categorical(c, rev_mapping) => Ok(Self::Utf8(rev_mapping.get(c).to_string())),
204+
AnyValue::Categorical(c, rev_mapping, arr) => {
205+
if arr.is_null() {
206+
Ok(Self::Utf8(rev_mapping.get(c).to_string()))
207+
} else {
208+
unsafe {
209+
Ok(Self::Utf8(
210+
arr.deref_unchecked().value(c as usize).to_string(),
211+
))
212+
}
213+
}
214+
}
205215
_ => Err(PolarsError::ComputeError(
206216
"Unsupported AnyValue type variant, cannot convert to Literal".into(),
207217
)),

polars/polars-utils/src/sync.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/// Utility that allows us to send pointers to another thread.
22
/// This is better than going through `usize` as MIRI can follow these.
3-
#[derive(Copy, Clone)]
3+
#[derive(Copy, Clone, Debug)]
4+
#[repr(transparent)]
45
pub struct SyncPtr<T>(*mut T);
56

67
impl<T> SyncPtr<T> {
@@ -12,10 +13,32 @@ impl<T> SyncPtr<T> {
1213
Self(ptr)
1314
}
1415

16+
/// # Safety
17+
///
18+
/// This will make a pointer `Sync` and `Send`.
19+
/// Ensure that you don't break aliasing rules.
20+
pub unsafe fn from_const(ptr: *const T) -> Self {
21+
Self(ptr as *mut T)
22+
}
23+
24+
pub fn new_null() -> Self {
25+
Self(std::ptr::null_mut())
26+
}
27+
1528
#[inline(always)]
1629
pub fn get(self) -> *mut T {
1730
self.0
1831
}
32+
33+
pub fn is_null(&self) -> bool {
34+
self.0.is_null()
35+
}
36+
37+
/// # Safety
38+
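/// Dereferences the raw pointer without any checks: the caller must guarantee it is non-null and points to a valid `T` for as long as the returned `'static` reference is used.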
39+
pub unsafe fn deref_unchecked(&self) -> &'static T {
40+
&*(self.0 as *const T)
41+
}
1942
}
2043

2144
unsafe impl<T> Sync for SyncPtr<T> {}

py-polars/src/conversion.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,12 @@ impl IntoPy<PyObject> for Wrap<AnyValue<'_>> {
200200
AnyValue::Boolean(v) => v.into_py(py),
201201
AnyValue::Utf8(v) => v.into_py(py),
202202
AnyValue::Utf8Owned(v) => v.into_py(py),
203-
AnyValue::Categorical(idx, rev) => {
204-
let s = rev.get(idx);
203+
AnyValue::Categorical(idx, rev, arr) => {
204+
let s = if arr.is_null() {
205+
rev.get(idx)
206+
} else {
207+
unsafe { arr.deref_unchecked().value(idx as usize) }
208+
};
205209
s.into_py(py)
206210
}
207211
AnyValue::Date(v) => {

0 commit comments

Comments
 (0)