14
14
from narwhals ._arrow .utils import native_to_narwhals_dtype
15
15
from narwhals ._arrow .utils import parse_datetime_format
16
16
from narwhals ._arrow .utils import validate_column_comparand
17
- from narwhals .translate import to_py_scalar
18
17
from narwhals .utils import Implementation
19
18
from narwhals .utils import generate_temporary_column_name
20
19
32
31
from narwhals .typing import DTypes
33
32
34
33
34
def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any:  # noqa: FBT001
    """Optionally unwrap ``value`` into a plain Python object.

    Args:
        value: Any object. Only objects exposing an ``as_py`` attribute
            (presumably PyArrow scalars) are actually converted.
        return_py_scalar: When true, return ``value.as_py()`` if ``value``
            has an ``as_py`` attribute. When false, return ``value`` as-is.

    Returns:
        The result of ``value.as_py()`` when unwrapping is requested and
        available, otherwise ``value`` itself.
    """
    if not return_py_scalar:
        return value
    try:
        to_python = value.as_py
    except AttributeError:
        # Nothing to unwrap: the object has no ``as_py`` attribute.
        return value
    return to_python()
38
+
39
+
35
40
class ArrowSeries :
36
41
def __init__ (
37
42
self : Self ,
@@ -241,8 +246,8 @@ def __invert__(self: Self) -> Self:
241
246
242
247
return self ._from_native_series (pc .invert (self ._native_series ))
243
248
244
def len(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the number of elements in the underlying native series.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``, which
            calls ``as_py()`` on the result when the flag is true and the
            result exposes that method.

    Returns:
        The length reported by ``len(self._native_series)``.
    """
    n_elements = len(self._native_series)
    return maybe_extract_py_scalar(n_elements, _return_py_scalar)  # type: ignore[no-any-return]
246
251
247
252
def filter (self : Self , other : Any ) -> Self :
248
253
if not (isinstance (other , list ) and all (isinstance (x , bool ) for x in other )):
@@ -251,12 +256,12 @@ def filter(self: Self, other: Any) -> Self:
251
256
ser = self ._native_series
252
257
return self ._from_native_series (ser .filter (other ))
253
258
254
def mean(self: Self, *, _return_py_scalar: bool = True) -> float:
    """Return the mean of the series, computed with ``pyarrow.compute.mean``.

    Args:
        _return_py_scalar: When true (the default) the PyArrow result is
            unwrapped through ``maybe_extract_py_scalar`` (i.e. ``as_py()``
            is called on it). When false the value produced by
            ``pc.mean`` is returned untouched.

    Returns:
        The mean value. The annotation is ``float`` because a mean is not,
        in general, an integer; NOTE(review): for decimal input PyArrow is
        documented to return a decimal scalar - confirm whether that case
        matters to callers.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar)  # type: ignore[no-any-return]
258
263
259
- def median (self : Self ) -> int :
264
+ def median (self : Self , * , _return_py_scalar : bool = True ) -> int :
260
265
import pyarrow .compute as pc # ignore-banned-import()
261
266
262
267
from narwhals .exceptions import InvalidOperationError
@@ -265,22 +270,24 @@ def median(self: Self) -> int:
265
270
msg = "`median` operation not supported for non-numeric input type."
266
271
raise InvalidOperationError (msg )
267
272
268
- return pc .approximate_median (self ._native_series ) # type: ignore[no-any-return]
273
+ return maybe_extract_py_scalar ( # type: ignore[no-any-return]
274
+ pc .approximate_median (self ._native_series ), _return_py_scalar
275
+ )
269
276
270
def min(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the minimum of the series, computed with ``pyarrow.compute.min``.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.min``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    smallest = pc.min(self._native_series)
    return maybe_extract_py_scalar(smallest, _return_py_scalar)  # type: ignore[no-any-return]
274
281
275
def max(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the maximum of the series, computed with ``pyarrow.compute.max``.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.max``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    largest = pc.max(self._native_series)
    return maybe_extract_py_scalar(largest, _return_py_scalar)  # type: ignore[no-any-return]
279
286
280
def sum(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the sum of the series, computed with ``pyarrow.compute.sum``.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.sum``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    total = pc.sum(self._native_series)
    return maybe_extract_py_scalar(total, _return_py_scalar)  # type: ignore[no-any-return]
284
291
285
292
def drop_nulls (self : Self ) -> ArrowSeries :
286
293
import pyarrow .compute as pc # ignore-banned-import()
@@ -300,12 +307,14 @@ def shift(self: Self, n: int) -> Self:
300
307
result = ca
301
308
return self ._from_native_series (result )
302
309
303
def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float:
    """Return the standard deviation via ``pyarrow.compute.stddev``.

    Args:
        ddof: Delta degrees of freedom, passed straight to ``pc.stddev``.
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.stddev``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    deviation = pc.stddev(self._native_series, ddof=ddof)
    return maybe_extract_py_scalar(deviation, _return_py_scalar)  # type: ignore[no-any-return]
307
316
308
- def skew (self : Self ) -> float | None :
317
+ def skew (self : Self , * , _return_py_scalar : bool = True ) -> float | None :
309
318
import pyarrow .compute as pc # ignore-banned-import()
310
319
311
320
ser = self ._native_series
@@ -321,18 +330,22 @@ def skew(self: Self) -> float | None:
321
330
m2 = pc .mean (pc .power (m , 2 ))
322
331
m3 = pc .mean (pc .power (m , 3 ))
323
332
# Biased population skewness
324
- return pc .divide (m3 , pc .power (m2 , 1.5 )) # type: ignore[no-any-return]
333
+ return maybe_extract_py_scalar ( # type: ignore[no-any-return]
334
+ pc .divide (m3 , pc .power (m2 , 1.5 )), _return_py_scalar
335
+ )
325
336
326
def count(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the element count reported by ``pyarrow.compute.count``.

    ``pc.count`` is called with its default mode.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.count``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    n_counted = pc.count(self._native_series)
    return maybe_extract_py_scalar(n_counted, _return_py_scalar)  # type: ignore[no-any-return]
330
341
331
def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the number of distinct entries in the series.

    The distinct entries come from ``pyarrow.compute.unique`` and are then
    counted with ``pyarrow.compute.count(..., mode="all")``.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.count``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    distinct = pc.unique(self._native_series)
    n_distinct = pc.count(distinct, mode="all")
    return maybe_extract_py_scalar(n_distinct, _return_py_scalar)  # type: ignore[no-any-return]
336
349
337
350
def __native_namespace__ (self : Self ) -> ModuleType :
338
351
if self ._implementation is Implementation .PYARROW :
@@ -430,15 +443,15 @@ def diff(self: Self) -> Self:
430
443
pc .pairwise_diff (self ._native_series .combine_chunks ())
431
444
)
432
445
433
def any(self: Self, *, _return_py_scalar: bool = True) -> bool:
    """Return the result of ``pyarrow.compute.any`` over the series.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.any``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    any_true = pc.any(self._native_series)
    return maybe_extract_py_scalar(any_true, _return_py_scalar)  # type: ignore[no-any-return]
437
450
438
def all(self: Self, *, _return_py_scalar: bool = True) -> bool:
    """Return the result of ``pyarrow.compute.all`` over the series.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``; when
            true the PyArrow result is unwrapped via ``as_py()``, when
            false it is returned as produced by ``pc.all``.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    all_true = pc.all(self._native_series)
    return maybe_extract_py_scalar(all_true, _return_py_scalar)  # type: ignore[no-any-return]
442
455
443
456
def is_between (
444
457
self , lower_bound : Any , upper_bound : Any , closed : str = "both"
@@ -480,8 +493,8 @@ def cast(self: Self, dtype: DType) -> Self:
480
493
dtype = narwhals_to_native_dtype (dtype , self ._dtypes )
481
494
return self ._from_native_series (pc .cast (ser , dtype ))
482
495
483
def null_count(self: Self, *, _return_py_scalar: bool = True) -> int:
    """Return the ``null_count`` attribute of the underlying native series.

    Args:
        _return_py_scalar: Forwarded to ``maybe_extract_py_scalar``, which
            calls ``as_py()`` on the value when the flag is true and the
            value exposes that method.
    """
    n_nulls = self._native_series.null_count
    return maybe_extract_py_scalar(n_nulls, _return_py_scalar)  # type: ignore[no-any-return]
485
498
486
499
def head (self : Self , n : int ) -> Self :
487
500
ser = self ._native_series
@@ -527,8 +540,8 @@ def item(self: Self, index: int | None = None) -> Any:
527
540
f" or an explicit index is provided (Series is of length { len (self )} )"
528
541
)
529
542
raise ValueError (msg )
530
- return self ._native_series [0 ]
531
- return self ._native_series [index ]
543
+ return maybe_extract_py_scalar ( self ._native_series [0 ], return_py_scalar = True )
544
+ return maybe_extract_py_scalar ( self ._native_series [index ], return_py_scalar = True )
532
545
533
546
def value_counts (
534
547
self : Self ,
@@ -718,7 +731,7 @@ def is_sorted(self: Self, *, descending: bool) -> bool:
718
731
result = pc .all (pc .greater_equal (ser [:- 1 ], ser [1 :]))
719
732
else :
720
733
result = pc .all (pc .less_equal (ser [:- 1 ], ser [1 :]))
721
- return to_py_scalar (result ) # type: ignore[no-any-return]
734
+ return maybe_extract_py_scalar (result , return_py_scalar = True ) # type: ignore[no-any-return]
722
735
723
736
def unique (self : Self , * , maintain_order : bool ) -> ArrowSeries :
724
737
# The param `maintain_order` is only here for compatibility with the Polars API
@@ -798,12 +811,15 @@ def quantile(
798
811
self : Self ,
799
812
quantile : float ,
800
813
interpolation : Literal ["nearest" , "higher" , "lower" , "midpoint" , "linear" ],
814
+ * ,
815
+ _return_py_scalar : bool = True ,
801
816
) -> Any :
802
817
import pyarrow .compute as pc # ignore-banned-import()
803
818
804
- return pc .quantile (self ._native_series , q = quantile , interpolation = interpolation )[
805
- 0
806
- ]
819
+ return maybe_extract_py_scalar (
820
+ pc .quantile (self ._native_series , q = quantile , interpolation = interpolation )[0 ],
821
+ _return_py_scalar ,
822
+ )
807
823
808
824
def gather_every(self: Self, n: int, offset: int = 0) -> Self:
    """Take every ``n``-th element of the series, starting at ``offset``.

    Args:
        n: Step between selected elements.
        offset: Index of the first selected element (defaults to 0).

    Returns:
        A new series built by ``_from_native_series`` from the strided
        slice of the native series.
    """
    stride = slice(offset, None, n)  # same selection as ``[offset::n]``
    selected = self._native_series[stride]
    return self._from_native_series(selected)
@@ -994,7 +1010,10 @@ def rolling_mean(
994
1010
return result
995
1011
996
1012
def __iter__(self: Self) -> Iterator[Any]:
    """Iterate over the elements of the native series.

    Each element is passed through ``maybe_extract_py_scalar`` with
    ``return_py_scalar=True``, so elements exposing ``as_py`` are yielded
    as the result of that call and all others are yielded unchanged.
    """
    for element in self._native_series.__iter__():
        yield maybe_extract_py_scalar(element, return_py_scalar=True)
998
1017
999
1018
@property
1000
1019
def shape (self : Self ) -> tuple [int ]:
0 commit comments