Skip to content

Commit 4c79ef2

Browse files
authored
fix(python): fix struct dataset (#5798)
1 parent 0bc2768 commit 4c79ef2

File tree

5 files changed

+79
-6
lines changed

5 files changed

+79
-6
lines changed

polars/polars-core/src/chunked_array/logical/struct_/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::collections::BTreeMap;
44

55
use super::*;
66
use crate::datatypes::*;
7+
use crate::utils::index_to_chunked_index2;
78

89
/// This is logical type [`StructChunked`] that
910
/// dispatches most logic to the `fields` implementations
@@ -191,13 +192,14 @@ impl LogicalType for StructChunked {
191192

192193
/// Gets AnyValue from LogicalType
193194
fn get_any_value(&self, i: usize) -> AnyValue<'_> {
195+
let (chunk_idx, idx) = index_to_chunked_index2(&self.chunks, i);
194196
if let DataType::Struct(flds) = self.dtype() {
195197
// safety: `chunk_idx` is in bounds as long as `i` is within the total chunk length, and we are
196198
// guarded by the type system.
197199
unsafe {
198-
let arr = &**self.chunks.get_unchecked(0);
200+
let arr = &**self.chunks.get_unchecked(chunk_idx);
199201
let arr = &*(arr as *const dyn Array as *const StructArray);
200-
AnyValue::Struct(i, arr, flds)
202+
AnyValue::Struct(idx, arr, flds)
201203
}
202204
} else {
203205
unreachable!()

polars/polars-core/src/utils/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,22 @@ pub(crate) fn index_to_chunked_index<
841841
(current_chunk_idx, index_remainder)
842842
}
843843

844+
/// Translates a flat `index` over the concatenation of `chunks` into a
/// `(chunk index, index within that chunk)` pair.
///
/// Chunks are walked front to back, subtracting each chunk's length from the
/// index until it fits inside the current chunk. Empty chunks are stepped over.
///
/// If `index` is not smaller than the summed length of all chunks, the
/// returned chunk index equals `chunks.len()` and the second element is the
/// leftover offset; callers must not use that pair for unchecked access.
#[cfg(feature = "dtype-struct")]
pub(crate) fn index_to_chunked_index2(chunks: &[ArrayRef], index: usize) -> (usize, usize) {
    let mut remaining = index;
    let mut chunk_pos = 0;

    for chunk_len in chunks.iter().map(|arr| arr.len()) {
        // The index lands inside this chunk: we are done.
        if remaining < chunk_len {
            return (chunk_pos, remaining);
        }
        remaining -= chunk_len;
        chunk_pos += 1;
    }
    (chunk_pos, remaining)
}
859+
844860
/// # SAFETY
845861
/// `dst` must be valid for `dst.len()` elements, and `src` and `dst` may not overlap.
846862
#[inline]

py-polars/polars/internals/construction.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,15 @@ def arrow_to_pyseries(name: str, values: pa.Array, rechunk: bool = True) -> PySe
115115
pys = PySeries.from_arrow(name, array)
116116
else:
117117
if array.num_chunks > 1:
118-
it = array.iterchunks()
119-
pys = PySeries.from_arrow(name, next(it))
120-
for a in it:
121-
pys.append(PySeries.from_arrow(name, a))
118+
# somehow going through ffi with a structarray
119+
# returns the first chunk every time
120+
if isinstance(array.type, pa.StructType):
121+
pys = PySeries.from_arrow(name, array.combine_chunks())
122+
else:
123+
it = array.iterchunks()
124+
pys = PySeries.from_arrow(name, next(it))
125+
for a in it:
126+
pys.append(PySeries.from_arrow(name, a))
122127
elif array.num_chunks == 0:
123128
pys = PySeries.from_arrow(name, pa.array([], array.type))
124129
else:
@@ -816,6 +821,8 @@ def arrow_to_pydf(
816821
# dictionaries cannot be built in different batches (categorical does not allow
817822
# that) so we rechunk them and create them separately.
818823
dictionary_cols = {}
824+
# struct columns don't work properly if they contain multiple chunks.
825+
struct_cols = {}
819826
names = []
820827
for i, column in enumerate(data):
821828
# extract the name before casting
@@ -829,6 +836,9 @@ def arrow_to_pydf(
829836
if pa.types.is_dictionary(column.type):
830837
ps = arrow_to_pyseries(name, column, rechunk)
831838
dictionary_cols[i] = pli.wrap_s(ps)
839+
elif isinstance(column.type, pa.StructType) and column.num_chunks > 1:
840+
ps = arrow_to_pyseries(name, column, rechunk)
841+
struct_cols[i] = pli.wrap_s(ps)
832842
else:
833843
data_dict[name] = column
834844

@@ -850,11 +860,20 @@ def arrow_to_pydf(
850860
if rechunk:
851861
pydf = pydf.rechunk()
852862

863+
reset_order = False
853864
if len(dictionary_cols) > 0:
854865
df = pli.wrap_df(pydf)
855866
df = df.with_columns(
856867
[pli.lit(s).alias(s.name) for s in dictionary_cols.values()]
857868
)
869+
reset_order = True
870+
871+
if len(struct_cols) > 0:
872+
df = pli.wrap_df(pydf)
873+
df = df.with_columns([pli.lit(s).alias(s.name) for s in struct_cols.values()])
874+
reset_order = True
875+
876+
if reset_order:
858877
df = df[names]
859878
pydf = df._df
860879

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
import typing
3+
4+
import pyarrow.dataset as ds
5+
6+
import polars as pl
7+
8+
9+
@typing.no_type_check
def test_struct_pyarrow_dataset_5796() -> None:
    """Round-trip a frame with a struct column through a pyarrow dataset.

    The frame is written to parquet with pyarrow, read back as a pyarrow
    table via ``pyarrow.dataset`` and converted with ``pl.from_arrow``; the
    result must equal the original frame.

    The parquet file lives in a private temporary directory so that parallel
    test runs cannot clobber each other's file and nothing is left behind.
    """
    # The body is not executed on Windows.
    if os.name == "nt":
        return

    import tempfile

    # NOTE(review): the row count is presumably chosen so that the table read
    # back holds the struct column in more than one chunk -- confirm.
    num_rows = 2**17 + 1

    df = pl.from_records(
        [
            dict(  # noqa: C408
                id=i,
                nested=dict(  # noqa: C408
                    a=i,
                ),
            )
            for i in range(num_rows)
        ]
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "out.parquet")
        df.write_parquet(path, use_pyarrow=True)
        tbl = ds.dataset(path).to_table()
        # Compare while the file still exists, in case the table references it.
        assert pl.from_arrow(tbl).frame_equal(df)

py-polars/tests/unit/test_struct.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,11 @@ def test_concat_list_reverse_struct_fields() -> None:
696696
assert df.select(pl.concat_list(["combo", "reverse_combo"])).frame_equal(
697697
df.select(pl.concat_list(["combo", "combo"]))
698698
)
699+
700+
701+
def test_struct_any_value_get_after_append() -> None:
    """Indexing a struct Series after an append returns the right row.

    Each element must be read from the position it was appended at, not
    always from the first piece of the Series.
    """
    first_row = {"a": 1, "b": 2}
    second_row = {"a": 2, "b": 3}

    head = pl.Series("a", [{"a": 1, "b": 2}])
    tail = pl.Series("a", [{"a": 2, "b": 3}])
    combined = head.append(tail)

    for position, expected in enumerate((first_row, second_row)):
        assert combined[position] == expected

0 commit comments

Comments
 (0)