Skip to content

Commit bfbc34d

Browse files
feat: consistently return Python scalars from Series reductions for PyArrow (#1471)
--------- Co-authored-by: Francesco Bruzzesi <[email protected]>
1 parent 635434e commit bfbc34d

19 files changed

+137
-98
lines changed

narwhals/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,21 @@
8686
"Field",
8787
"Float32",
8888
"Float64",
89+
"Int8",
8990
"Int16",
9091
"Int32",
9192
"Int64",
92-
"Int8",
9393
"LazyFrame",
9494
"List",
9595
"Object",
9696
"Schema",
9797
"Series",
9898
"String",
9999
"Struct",
100+
"UInt8",
100101
"UInt16",
101102
"UInt32",
102103
"UInt64",
103-
"UInt8",
104104
"Unknown",
105105
"all",
106106
"all_horizontal",
@@ -113,8 +113,8 @@
113113
"exceptions",
114114
"from_arrow",
115115
"from_dict",
116-
"from_numpy",
117116
"from_native",
117+
"from_numpy",
118118
"generate_temporary_column_name",
119119
"get_level",
120120
"get_native_namespace",

narwhals/_arrow/dataframe.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,8 @@ def is_empty(self: Self) -> bool:
541541
return self.shape[0] == 0
542542

543543
def item(self: Self, row: int | None, column: int | str | None) -> Any:
544+
from narwhals._arrow.series import maybe_extract_py_scalar
545+
544546
if row is None and column is None:
545547
if self.shape != (1, 1):
546548
msg = (
@@ -549,14 +551,18 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any:
549551
f" frame has shape {self.shape!r}"
550552
)
551553
raise ValueError(msg)
552-
return self._native_frame[0][0]
554+
return maybe_extract_py_scalar(
555+
self._native_frame[0][0], return_py_scalar=True
556+
)
553557

554558
elif row is None or column is None:
555559
msg = "cannot call `.item()` with only one of `row` or `column`"
556560
raise ValueError(msg)
557561

558562
_col = self.columns.index(column) if isinstance(column, str) else column
559-
return self._native_frame[_col][row]
563+
return maybe_extract_py_scalar(
564+
self._native_frame[_col][row], return_py_scalar=True
565+
)
560566

561567
def rename(self: Self, mapping: dict[str, str]) -> Self:
562568
df = self._native_frame

narwhals/_arrow/expr.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626

2727
class ArrowExpr:
28+
_implementation: Implementation = Implementation.PYARROW
29+
2830
def __init__(
2931
self: Self,
3032
call: Callable[[ArrowDataFrame], list[ArrowSeries]],

narwhals/_arrow/series.py

+53-34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from narwhals._arrow.utils import native_to_narwhals_dtype
1515
from narwhals._arrow.utils import parse_datetime_format
1616
from narwhals._arrow.utils import validate_column_comparand
17-
from narwhals.translate import to_py_scalar
1817
from narwhals.utils import Implementation
1918
from narwhals.utils import generate_temporary_column_name
2019

@@ -32,6 +31,12 @@
3231
from narwhals.typing import DTypes
3332

3433

34+
def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001
35+
if return_py_scalar:
36+
return getattr(value, "as_py", lambda: value)()
37+
return value
38+
39+
3540
class ArrowSeries:
3641
def __init__(
3742
self: Self,
@@ -241,8 +246,8 @@ def __invert__(self: Self) -> Self:
241246

242247
return self._from_native_series(pc.invert(self._native_series))
243248

244-
def len(self: Self) -> int:
245-
return len(self._native_series)
249+
def len(self: Self, *, _return_py_scalar: bool = True) -> int:
250+
return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
246251

247252
def filter(self: Self, other: Any) -> Self:
248253
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
@@ -251,12 +256,12 @@ def filter(self: Self, other: Any) -> Self:
251256
ser = self._native_series
252257
return self._from_native_series(ser.filter(other))
253258

254-
def mean(self: Self) -> int:
259+
def mean(self: Self, *, _return_py_scalar: bool = True) -> int:
255260
import pyarrow.compute as pc # ignore-banned-import()
256261

257-
return pc.mean(self._native_series) # type: ignore[no-any-return]
262+
return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
258263

259-
def median(self: Self) -> int:
264+
def median(self: Self, *, _return_py_scalar: bool = True) -> int:
260265
import pyarrow.compute as pc # ignore-banned-import()
261266

262267
from narwhals.exceptions import InvalidOperationError
@@ -265,22 +270,24 @@ def median(self: Self) -> int:
265270
msg = "`median` operation not supported for non-numeric input type."
266271
raise InvalidOperationError(msg)
267272

268-
return pc.approximate_median(self._native_series) # type: ignore[no-any-return]
273+
return maybe_extract_py_scalar( # type: ignore[no-any-return]
274+
pc.approximate_median(self._native_series), _return_py_scalar
275+
)
269276

270-
def min(self: Self) -> int:
277+
def min(self: Self, *, _return_py_scalar: bool = True) -> int:
271278
import pyarrow.compute as pc # ignore-banned-import()
272279

273-
return pc.min(self._native_series) # type: ignore[no-any-return]
280+
return maybe_extract_py_scalar(pc.min(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
274281

275-
def max(self: Self) -> int:
282+
def max(self: Self, *, _return_py_scalar: bool = True) -> int:
276283
import pyarrow.compute as pc # ignore-banned-import()
277284

278-
return pc.max(self._native_series) # type: ignore[no-any-return]
285+
return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
279286

280-
def sum(self: Self) -> int:
287+
def sum(self: Self, *, _return_py_scalar: bool = True) -> int:
281288
import pyarrow.compute as pc # ignore-banned-import()
282289

283-
return pc.sum(self._native_series) # type: ignore[no-any-return]
290+
return maybe_extract_py_scalar(pc.sum(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
284291

285292
def drop_nulls(self: Self) -> ArrowSeries:
286293
import pyarrow.compute as pc # ignore-banned-import()
@@ -300,12 +307,14 @@ def shift(self: Self, n: int) -> Self:
300307
result = ca
301308
return self._from_native_series(result)
302309

303-
def std(self: Self, ddof: int) -> float:
310+
def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float:
304311
import pyarrow.compute as pc # ignore-banned-import()
305312

306-
return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return]
313+
return maybe_extract_py_scalar( # type: ignore[no-any-return]
314+
pc.stddev(self._native_series, ddof=ddof), _return_py_scalar
315+
)
307316

308-
def skew(self: Self) -> float | None:
317+
def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None:
309318
import pyarrow.compute as pc # ignore-banned-import()
310319

311320
ser = self._native_series
@@ -321,18 +330,22 @@ def skew(self: Self) -> float | None:
321330
m2 = pc.mean(pc.power(m, 2))
322331
m3 = pc.mean(pc.power(m, 3))
323332
# Biased population skewness
324-
return pc.divide(m3, pc.power(m2, 1.5)) # type: ignore[no-any-return]
333+
return maybe_extract_py_scalar( # type: ignore[no-any-return]
334+
pc.divide(m3, pc.power(m2, 1.5)), _return_py_scalar
335+
)
325336

326-
def count(self: Self) -> int:
337+
def count(self: Self, *, _return_py_scalar: bool = True) -> int:
327338
import pyarrow.compute as pc # ignore-banned-import()
328339

329-
return pc.count(self._native_series) # type: ignore[no-any-return]
340+
return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
330341

331-
def n_unique(self: Self) -> int:
342+
def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int:
332343
import pyarrow.compute as pc # ignore-banned-import()
333344

334345
unique_values = pc.unique(self._native_series)
335-
return pc.count(unique_values, mode="all") # type: ignore[no-any-return]
346+
return maybe_extract_py_scalar( # type: ignore[no-any-return]
347+
pc.count(unique_values, mode="all"), _return_py_scalar
348+
)
336349

337350
def __native_namespace__(self: Self) -> ModuleType:
338351
if self._implementation is Implementation.PYARROW:
@@ -430,15 +443,15 @@ def diff(self: Self) -> Self:
430443
pc.pairwise_diff(self._native_series.combine_chunks())
431444
)
432445

433-
def any(self: Self) -> bool:
446+
def any(self: Self, *, _return_py_scalar: bool = True) -> bool:
434447
import pyarrow.compute as pc # ignore-banned-import()
435448

436-
return to_py_scalar(pc.any(self._native_series)) # type: ignore[no-any-return]
449+
return maybe_extract_py_scalar(pc.any(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
437450

438-
def all(self: Self) -> bool:
451+
def all(self: Self, *, _return_py_scalar: bool = True) -> bool:
439452
import pyarrow.compute as pc # ignore-banned-import()
440453

441-
return to_py_scalar(pc.all(self._native_series)) # type: ignore[no-any-return]
454+
return maybe_extract_py_scalar(pc.all(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
442455

443456
def is_between(
444457
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
@@ -480,8 +493,8 @@ def cast(self: Self, dtype: DType) -> Self:
480493
dtype = narwhals_to_native_dtype(dtype, self._dtypes)
481494
return self._from_native_series(pc.cast(ser, dtype))
482495

483-
def null_count(self: Self) -> int:
484-
return self._native_series.null_count # type: ignore[no-any-return]
496+
def null_count(self: Self, *, _return_py_scalar: bool = True) -> int:
497+
return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) # type: ignore[no-any-return]
485498

486499
def head(self: Self, n: int) -> Self:
487500
ser = self._native_series
@@ -527,8 +540,8 @@ def item(self: Self, index: int | None = None) -> Any:
527540
f" or an explicit index is provided (Series is of length {len(self)})"
528541
)
529542
raise ValueError(msg)
530-
return self._native_series[0]
531-
return self._native_series[index]
543+
return maybe_extract_py_scalar(self._native_series[0], return_py_scalar=True)
544+
return maybe_extract_py_scalar(self._native_series[index], return_py_scalar=True)
532545

533546
def value_counts(
534547
self: Self,
@@ -718,7 +731,7 @@ def is_sorted(self: Self, *, descending: bool) -> bool:
718731
result = pc.all(pc.greater_equal(ser[:-1], ser[1:]))
719732
else:
720733
result = pc.all(pc.less_equal(ser[:-1], ser[1:]))
721-
return to_py_scalar(result) # type: ignore[no-any-return]
734+
return maybe_extract_py_scalar(result, return_py_scalar=True) # type: ignore[no-any-return]
722735

723736
def unique(self: Self, *, maintain_order: bool) -> ArrowSeries:
724737
# The param `maintain_order` is only here for compatibility with the Polars API
@@ -798,12 +811,15 @@ def quantile(
798811
self: Self,
799812
quantile: float,
800813
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
814+
*,
815+
_return_py_scalar: bool = True,
801816
) -> Any:
802817
import pyarrow.compute as pc # ignore-banned-import()
803818

804-
return pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[
805-
0
806-
]
819+
return maybe_extract_py_scalar(
820+
pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0],
821+
_return_py_scalar,
822+
)
807823

808824
def gather_every(self: Self, n: int, offset: int = 0) -> Self:
809825
return self._from_native_series(self._native_series[offset::n])
@@ -994,7 +1010,10 @@ def rolling_mean(
9941010
return result
9951011

9961012
def __iter__(self: Self) -> Iterator[Any]:
997-
yield from self._native_series.__iter__()
1013+
yield from (
1014+
maybe_extract_py_scalar(x, return_py_scalar=True)
1015+
for x in self._native_series.__iter__()
1016+
)
9981017

9991018
@property
10001019
def shape(self: Self) -> tuple[int]:

narwhals/_dask/expr.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030

3131
class DaskExpr:
32+
_implementation: Implementation = Implementation.DASK
33+
3234
def __init__(
3335
self,
3436
call: Callable[[DaskLazyFrame], list[dask_expr.Series]],

narwhals/_expression_parsing.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from narwhals.dependencies import is_numpy_array
1616
from narwhals.exceptions import InvalidIntoExprError
17+
from narwhals.utils import Implementation
1718

1819
if TYPE_CHECKING:
1920
from narwhals._arrow.dataframe import ArrowDataFrame
@@ -223,9 +224,17 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]:
223224
for arg_name, arg_value in kwargs.items()
224225
}
225226

227+
# For PyArrow.Series, we return Python Scalars (like Polars does) instead of PyArrow Scalars.
228+
# However, when working with expressions, we keep everything PyArrow-native.
229+
extra_kwargs = (
230+
{"_return_py_scalar": False}
231+
if returns_scalar and expr._implementation is Implementation.PYARROW
232+
else {}
233+
)
234+
226235
out: list[CompliantSeries] = [
227236
plx._create_series_from_scalar(
228-
getattr(series, attr)(*_args, **_kwargs),
237+
getattr(series, attr)(*_args, **extra_kwargs, **_kwargs),
229238
reference_series=series, # type: ignore[arg-type]
230239
)
231240
if returns_scalar

narwhals/dependencies.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -302,30 +302,30 @@ def is_into_dataframe(native_dataframe: Any) -> bool:
302302

303303

304304
__all__ = [
305-
"get_polars",
306-
"get_pandas",
307-
"get_modin",
308305
"get_cudf",
309-
"get_pyarrow",
310-
"get_numpy",
311306
"get_ibis",
307+
"get_modin",
308+
"get_numpy",
309+
"get_pandas",
310+
"get_polars",
311+
"get_pyarrow",
312+
"is_cudf_dataframe",
313+
"is_cudf_series",
314+
"is_dask_dataframe",
312315
"is_ibis_table",
316+
"is_into_dataframe",
317+
"is_into_series",
318+
"is_modin_dataframe",
319+
"is_modin_series",
320+
"is_numpy_array",
313321
"is_pandas_dataframe",
314-
"is_pandas_series",
315322
"is_pandas_index",
323+
"is_pandas_like_dataframe",
324+
"is_pandas_like_series",
325+
"is_pandas_series",
316326
"is_polars_dataframe",
317327
"is_polars_lazyframe",
318328
"is_polars_series",
319-
"is_modin_dataframe",
320-
"is_modin_series",
321-
"is_cudf_dataframe",
322-
"is_cudf_series",
323-
"is_pyarrow_table",
324329
"is_pyarrow_chunked_array",
325-
"is_numpy_array",
326-
"is_dask_dataframe",
327-
"is_pandas_like_dataframe",
328-
"is_pandas_like_series",
329-
"is_into_dataframe",
330-
"is_into_series",
330+
"is_pyarrow_table",
331331
]

narwhals/selectors.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,10 @@ def all() -> Expr:
271271

272272

273273
__all__ = [
274+
"all",
275+
"boolean",
274276
"by_dtype",
277+
"categorical",
275278
"numeric",
276-
"boolean",
277279
"string",
278-
"categorical",
279-
"all",
280280
]

narwhals/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def median(self) -> Any:
656656
>>> my_library_agnostic_function(s_pl)
657657
5.0
658658
>>> my_library_agnostic_function(s_pa)
659-
<pyarrow.DoubleScalar: 5.0>
659+
5.0
660660
"""
661661
return self._compliant_series.median()
662662

0 commit comments

Comments
 (0)