@@ -4,6 +4,8 @@ use std::fmt::Write;
44use polars_arrow:: kernels:: list:: sublist_get;
55use polars_arrow:: prelude:: ValueSize ;
66use polars_core:: chunked_array:: builder:: get_list_builder;
7+ #[ cfg( feature = "list_take" ) ]
8+ use polars_core:: export:: num:: { NumCast , Signed , Zero } ;
79#[ cfg( feature = "diff" ) ]
810use polars_core:: series:: ops:: NullBehavior ;
911use polars_core:: utils:: { try_get_supertype, CustomIterTools } ;
@@ -213,6 +215,72 @@ pub trait ListNameSpaceImpl: AsList {
213215 Series :: try_from ( ( ca. name ( ) , chunks) )
214216 }
215217
218+ #[ cfg( feature = "list_take" ) ]
219+ fn lst_take ( & self , idx : & Series ) -> PolarsResult < Series > {
220+ let list_ca = self . as_list ( ) ;
221+
222+ let index_typed_index = |idx : & Series | {
223+ let other = idx. cast ( & IDX_DTYPE ) . unwrap ( ) ;
224+ let idx = other. idx ( ) . unwrap ( ) ;
225+ list_ca
226+ . amortized_iter ( )
227+ . map ( |s| s. map ( |s| s. as_ref ( ) . take ( idx) ) . transpose ( ) )
228+ . collect :: < PolarsResult < ListChunked > > ( )
229+ . map ( |mut ca| {
230+ ca. rename ( list_ca. name ( ) ) ;
231+ ca. into_series ( )
232+ } )
233+ } ;
234+
235+ use DataType :: * ;
236+ match idx. dtype ( ) {
237+ List ( _) => {
238+ let idx_ca = idx. list ( ) . unwrap ( ) ;
239+ let mut out = list_ca
240+ . amortized_iter ( )
241+ . zip ( idx_ca. into_iter ( ) )
242+ . map ( |( opt_s, opt_idx) | {
243+ {
244+ match ( opt_s, opt_idx) {
245+ ( Some ( s) , Some ( idx) ) => take_series ( s. as_ref ( ) , idx) ,
246+ _ => None ,
247+ }
248+ }
249+ . transpose ( )
250+ } )
251+ . collect :: < PolarsResult < ListChunked > > ( ) ?;
252+ out. rename ( list_ca. name ( ) ) ;
253+
254+ Ok ( out. into_series ( ) )
255+ }
256+ UInt32 | UInt64 => index_typed_index ( idx) ,
257+ dt if dt. is_signed ( ) => {
258+ if let Some ( min) = idx. min :: < i64 > ( ) {
259+ if min > 0 {
260+ let idx = idx. cast ( & IDX_DTYPE ) . unwrap ( ) ;
261+ index_typed_index ( & idx)
262+ } else {
263+ let mut out = list_ca
264+ . amortized_iter ( )
265+ . map ( |opt_s| {
266+ opt_s
267+ . and_then ( |s| take_series ( s. as_ref ( ) , idx. clone ( ) ) )
268+ . transpose ( )
269+ } )
270+ . collect :: < PolarsResult < ListChunked > > ( ) ?;
271+ out. rename ( list_ca. name ( ) ) ;
272+ Ok ( out. into_series ( ) )
273+ }
274+ } else {
275+ Err ( PolarsError :: ComputeError ( "All indices are null" . into ( ) ) )
276+ }
277+ }
278+ dt => Err ( PolarsError :: ComputeError (
279+ format ! ( "Cannot use dtype: '{dt}' as index." ) . into ( ) ,
280+ ) ) ,
281+ }
282+ }
283+
216284 fn lst_concat ( & self , other : & [ Series ] ) -> PolarsResult < ListChunked > {
217285 let ca = self . as_list ( ) ;
218286 let other_len = other. len ( ) ;
@@ -360,3 +428,57 @@ pub trait ListNameSpaceImpl: AsList {
360428}
361429
362430impl ListNameSpaceImpl for ListChunked { }
431+
432+ #[ cfg( feature = "list_take" ) ]
433+ fn take_series ( s : & Series , idx : Series ) -> Option < PolarsResult < Series > > {
434+ let len = s. len ( ) ;
435+ let idx = cast_index ( idx, len) ;
436+ let idx = idx. idx ( ) . unwrap ( ) ;
437+ Some ( s. take ( idx) )
438+ }
439+
440+ #[ cfg( feature = "list_take" ) ]
441+ fn cast_index_ca < T : PolarsNumericType > ( idx : & ChunkedArray < T > , len : usize ) -> Series
442+ where
443+ T :: Native : Copy + PartialOrd + PartialEq + NumCast + Signed + Zero ,
444+ {
445+ idx. into_iter ( )
446+ . map ( |opt_idx| opt_idx. and_then ( |idx| idx. negative_to_usize ( len) . map ( |idx| idx as IdxSize ) ) )
447+ . collect :: < IdxCa > ( )
448+ . into_series ( )
449+ }
450+
451+ #[ cfg( feature = "list_take" ) ]
452+ fn cast_index ( idx : Series , len : usize ) -> Series {
453+ use DataType :: * ;
454+ match idx. dtype ( ) {
455+ #[ cfg( feature = "big_idx" ) ]
456+ UInt32 => idx. cast ( & IDX_DTYPE ) . unwrap ( ) ,
457+ #[ cfg( feature = "big_idx" ) ]
458+ UInt64 => idx,
459+ #[ cfg( not( feature = "big_idx" ) ) ]
460+ UInt64 => idx. cast ( & IDX_DTYPE ) . unwrap ( ) ,
461+ #[ cfg( not( feature = "big_idx" ) ) ]
462+ UInt32 => idx,
463+ dt if dt. is_unsigned ( ) => idx. cast ( & IDX_DTYPE ) . unwrap ( ) ,
464+ Int8 => {
465+ let a = idx. i8 ( ) . unwrap ( ) ;
466+ cast_index_ca ( a, len)
467+ }
468+ Int16 => {
469+ let a = idx. i16 ( ) . unwrap ( ) ;
470+ cast_index_ca ( a, len)
471+ }
472+ Int32 => {
473+ let a = idx. i32 ( ) . unwrap ( ) ;
474+ cast_index_ca ( a, len)
475+ }
476+ Int64 => {
477+ let a = idx. i64 ( ) . unwrap ( ) ;
478+ cast_index_ca ( a, len)
479+ }
480+ _ => {
481+ unreachable ! ( )
482+ }
483+ }
484+ }
0 commit comments