Skip to content

Commit 05447ab

Browse files
authored
feat(rust, python): add arr.take expression (#6116)
1 parent 9565988 commit 05447ab

File tree

18 files changed

+232
-9
lines changed

18 files changed

+232
-9
lines changed

polars/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ list_eval = ["polars-lazy/list_eval"]
124124
cumulative_eval = ["polars-lazy/cumulative_eval"]
125125
chunked_ids = ["polars-core/chunked_ids", "polars-lazy/chunked_ids", "polars-core/chunked_ids"]
126126
to_dummies = ["polars-ops/to_dummies"]
127-
bigidx = ["polars-core/bigidx", "polars-lazy/bigidx"]
127+
bigidx = ["polars-core/bigidx", "polars-lazy/bigidx", "polars-ops/big_idx"]
128128
list_to_struct = ["polars-ops/list_to_struct", "polars-lazy/list_to_struct"]
129+
list_take = ["polars-ops/list_take", "polars-lazy/list_take"]
129130
describe = ["polars-core/describe"]
130131
timezones = ["polars-core/timezones", "polars-lazy/timezones"]
131132
string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"]

polars/polars-arrow/src/index.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,27 @@
22
use arrow::array::UInt32Array;
33
#[cfg(feature = "bigidx")]
44
use arrow::array::UInt64Array;
5+
use num::{NumCast, Signed, Zero};
56

67
pub trait IndexToUsize {
78
/// Translate the negative index to an offset.
8-
fn negative_to_usize(self, index: usize) -> Option<usize>;
9+
fn negative_to_usize(self, len: usize) -> Option<usize>;
910
}
1011

11-
impl IndexToUsize for i64 {
12-
fn negative_to_usize(self, index: usize) -> Option<usize> {
13-
if self >= 0 && (self as usize) < index {
14-
Some(self as usize)
12+
impl<I> IndexToUsize for I
13+
where
14+
I: PartialOrd + PartialEq + NumCast + Signed + Zero,
15+
{
16+
#[inline]
17+
fn negative_to_usize(self, len: usize) -> Option<usize> {
18+
if self >= Zero::zero() && (self.to_usize().unwrap()) < len {
19+
Some(self.to_usize().unwrap())
1520
} else {
16-
let subtract = self.unsigned_abs() as usize;
17-
if subtract > index {
21+
let subtract = self.abs().to_usize().unwrap();
22+
if subtract > len {
1823
None
1924
} else {
20-
Some(index - subtract)
25+
Some(len - subtract)
2126
}
2227
}
2328
}

polars/polars-lazy/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ date_offset = ["polars-plan/date_offset"]
6464
trigonometry = ["polars-plan/trigonometry"]
6565
sign = ["polars-plan/sign"]
6666
timezones = ["polars-plan/timezones"]
67+
list_take = ["polars-ops/list_take", "polars-plan/list_take"]
6768

6869
true_div = ["polars-plan/true_div"]
6970

polars/polars-lazy/polars-plan/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ dtype-struct = ["polars-core/dtype-struct"]
5252
dtype-binary = ["polars-core/dtype-binary"]
5353
object = ["polars-core/object"]
5454
date_offset = ["polars-time"]
55+
list_take = ["polars-ops/list_take"]
5556
trigonometry = []
5657
sign = []
5758
timezones = ["polars-time/timezones", "polars-core/timezones"]

polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ pub enum ListFunction {
1010
Contains,
1111
Slice,
1212
Get,
13+
#[cfg(feature = "list_take")]
14+
Take,
1315
}
1416

1517
impl Display for ListFunction {
@@ -22,6 +24,8 @@ impl Display for ListFunction {
2224
Contains => "contains",
2325
Slice => "slice",
2426
Get => "get",
27+
#[cfg(feature = "list_take")]
28+
Take => "take",
2529
};
2630
write!(f, "{name}")
2731
}
@@ -185,3 +189,20 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Series> {
185189

186190
}
187191
}
192+
193+
#[cfg(feature = "list_take")]
194+
pub(super) fn take(args: &[Series]) -> PolarsResult<Series> {
195+
let ca = &args[0];
196+
let idx = &args[1];
197+
let ca = ca.list()?;
198+
199+
if idx.len() == 1 {
200+
// fast path
201+
let idx = idx.get(0)?.try_extract::<i64>()?;
202+
let out = ca.lst_get(idx)?;
203+
// make sure we return a list
204+
out.reshape(&[-1, 1])
205+
} else {
206+
ca.lst_take(idx)
207+
}
208+
}

polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
326326
Contains => wrap!(list::contains),
327327
Slice => wrap!(list::slice),
328328
Get => wrap!(list::get),
329+
#[cfg(feature = "list_take")]
330+
Take => map_as_slice!(list::take),
329331
}
330332
}
331333
#[cfg(feature = "dtype-struct")]

polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ impl FunctionExpr {
184184
Contains => with_dtype(DataType::Boolean),
185185
Slice => same_type(),
186186
Get => inner_type_list(),
187+
#[cfg(feature = "list_take")]
188+
Take => same_type(),
187189
}
188190
}
189191
#[cfg(feature = "dtype-struct")]

polars/polars-lazy/polars-plan/src/dsl/list.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ impl ListNameSpace {
122122
.map_many_private(FunctionExpr::ListExpr(ListFunction::Get), &[index], false)
123123
}
124124

125+
/// Get items in every sublist by multiple indexes.
126+
#[cfg(feature = "list_take")]
127+
#[cfg_attr(docsrs, doc(cfg(feature = "list_take")))]
128+
pub fn take(self, index: Expr) -> Expr {
129+
self.0
130+
.map_many_private(FunctionExpr::ListExpr(ListFunction::Take), &[index], false)
131+
}
132+
125133
/// Get first item of every sublist.
126134
pub fn first(self) -> Expr {
127135
self.get(lit(0i64))

polars/polars-ops/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dtype-i16 = ["polars-core/dtype-i16"]
3535
object = ["polars-core/object"]
3636
propagate_nans = []
3737
performant = ["polars-core/performant"]
38+
big_idx = ["polars-core/bigidx"]
3839

3940
# ops
4041
to_dummies = []
@@ -56,3 +57,4 @@ cross_join = ["polars-core/cross_join"]
5657
chunked_ids = ["polars-core/chunked_ids"]
5758
asof_join = ["polars-core/asof_join"]
5859
semi_anti_join = ["polars-core/semi_anti_join"]
60+
list_take = []

polars/polars-ops/src/chunked_array/list/namespace.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::fmt::Write;
44
use polars_arrow::kernels::list::sublist_get;
55
use polars_arrow::prelude::ValueSize;
66
use polars_core::chunked_array::builder::get_list_builder;
7+
#[cfg(feature = "list_take")]
8+
use polars_core::export::num::{NumCast, Signed, Zero};
79
#[cfg(feature = "diff")]
810
use polars_core::series::ops::NullBehavior;
911
use polars_core::utils::{try_get_supertype, CustomIterTools};
@@ -213,6 +215,72 @@ pub trait ListNameSpaceImpl: AsList {
213215
Series::try_from((ca.name(), chunks))
214216
}
215217

218+
#[cfg(feature = "list_take")]
219+
fn lst_take(&self, idx: &Series) -> PolarsResult<Series> {
220+
let list_ca = self.as_list();
221+
222+
let index_typed_index = |idx: &Series| {
223+
let other = idx.cast(&IDX_DTYPE).unwrap();
224+
let idx = other.idx().unwrap();
225+
list_ca
226+
.amortized_iter()
227+
.map(|s| s.map(|s| s.as_ref().take(idx)).transpose())
228+
.collect::<PolarsResult<ListChunked>>()
229+
.map(|mut ca| {
230+
ca.rename(list_ca.name());
231+
ca.into_series()
232+
})
233+
};
234+
235+
use DataType::*;
236+
match idx.dtype() {
237+
List(_) => {
238+
let idx_ca = idx.list().unwrap();
239+
let mut out = list_ca
240+
.amortized_iter()
241+
.zip(idx_ca.into_iter())
242+
.map(|(opt_s, opt_idx)| {
243+
{
244+
match (opt_s, opt_idx) {
245+
(Some(s), Some(idx)) => take_series(s.as_ref(), idx),
246+
_ => None,
247+
}
248+
}
249+
.transpose()
250+
})
251+
.collect::<PolarsResult<ListChunked>>()?;
252+
out.rename(list_ca.name());
253+
254+
Ok(out.into_series())
255+
}
256+
UInt32 | UInt64 => index_typed_index(idx),
257+
dt if dt.is_signed() => {
258+
if let Some(min) = idx.min::<i64>() {
259+
if min > 0 {
260+
let idx = idx.cast(&IDX_DTYPE).unwrap();
261+
index_typed_index(&idx)
262+
} else {
263+
let mut out = list_ca
264+
.amortized_iter()
265+
.map(|opt_s| {
266+
opt_s
267+
.and_then(|s| take_series(s.as_ref(), idx.clone()))
268+
.transpose()
269+
})
270+
.collect::<PolarsResult<ListChunked>>()?;
271+
out.rename(list_ca.name());
272+
Ok(out.into_series())
273+
}
274+
} else {
275+
Err(PolarsError::ComputeError("All indices are null".into()))
276+
}
277+
}
278+
dt => Err(PolarsError::ComputeError(
279+
format!("Cannot use dtype: '{dt}' as index.").into(),
280+
)),
281+
}
282+
}
283+
216284
fn lst_concat(&self, other: &[Series]) -> PolarsResult<ListChunked> {
217285
let ca = self.as_list();
218286
let other_len = other.len();
@@ -360,3 +428,57 @@ pub trait ListNameSpaceImpl: AsList {
360428
}
361429

362430
impl ListNameSpaceImpl for ListChunked {}
431+
432+
#[cfg(feature = "list_take")]
433+
fn take_series(s: &Series, idx: Series) -> Option<PolarsResult<Series>> {
434+
let len = s.len();
435+
let idx = cast_index(idx, len);
436+
let idx = idx.idx().unwrap();
437+
Some(s.take(idx))
438+
}
439+
440+
#[cfg(feature = "list_take")]
441+
fn cast_index_ca<T: PolarsNumericType>(idx: &ChunkedArray<T>, len: usize) -> Series
442+
where
443+
T::Native: Copy + PartialOrd + PartialEq + NumCast + Signed + Zero,
444+
{
445+
idx.into_iter()
446+
.map(|opt_idx| opt_idx.and_then(|idx| idx.negative_to_usize(len).map(|idx| idx as IdxSize)))
447+
.collect::<IdxCa>()
448+
.into_series()
449+
}
450+
451+
#[cfg(feature = "list_take")]
452+
fn cast_index(idx: Series, len: usize) -> Series {
453+
use DataType::*;
454+
match idx.dtype() {
455+
#[cfg(feature = "big_idx")]
456+
UInt32 => idx.cast(&IDX_DTYPE).unwrap(),
457+
#[cfg(feature = "big_idx")]
458+
UInt64 => idx,
459+
#[cfg(not(feature = "big_idx"))]
460+
UInt64 => idx.cast(&IDX_DTYPE).unwrap(),
461+
#[cfg(not(feature = "big_idx"))]
462+
UInt32 => idx,
463+
dt if dt.is_unsigned() => idx.cast(&IDX_DTYPE).unwrap(),
464+
Int8 => {
465+
let a = idx.i8().unwrap();
466+
cast_index_ca(a, len)
467+
}
468+
Int16 => {
469+
let a = idx.i16().unwrap();
470+
cast_index_ca(a, len)
471+
}
472+
Int32 => {
473+
let a = idx.i32().unwrap();
474+
cast_index_ca(a, len)
475+
}
476+
Int64 => {
477+
let a = idx.i64().unwrap();
478+
cast_index_ca(a, len)
479+
}
480+
_ => {
481+
unreachable!()
482+
}
483+
}
484+
}

0 commit comments

Comments
 (0)