Skip to content

Commit 0556905

Browse files
raulcdpitrou
andauthored
GH-45380: [Python] Expose RankQuantileOptions to Python (#45392)
### Rationale for this change `RankQuantileOptions` are currently not exposed on Pyarrow and CI job breaks when `-W error` is used. ### What changes are included in this PR? Expose `RankQuantileOptions` and test options and kernel from pyarrow. It also includes some minor refactor for the unwrap sort keys logic to move it into a common function. ### Are these changes tested? Yes ### Are there any user-facing changes? The options for the new kernel are exposed on pyarrow. * GitHub Issue: #45380 Lead-authored-by: Raúl Cumplido <[email protected]> Co-authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 8fed34e commit 0556905

File tree

7 files changed

+94
-33
lines changed

7 files changed

+94
-33
lines changed

cpp/src/arrow/compute/api_vector.cc

+1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
270270
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
271271
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
272272
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
273+
DCHECK_OK(registry->AddFunctionOptionsType(kRankQuantileOptionsType));
273274
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
274275
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
275276
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));

python/pyarrow/_acero.pyx

+3-10
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table,
3030
from pyarrow.lib import frombytes, tobytes
3131
from pyarrow._compute cimport (
3232
Expression, FunctionOptions, _ensure_field_ref, _true,
33-
unwrap_null_placement, unwrap_sort_order
33+
unwrap_null_placement, unwrap_sort_keys
3434
)
3535

3636

@@ -234,17 +234,10 @@ class AggregateNodeOptions(_AggregateNodeOptions):
234234
cdef class _OrderByNodeOptions(ExecNodeOptions):
235235

236236
def _set_options(self, sort_keys, null_placement):
237-
cdef:
238-
vector[CSortKey] c_sort_keys
239-
240-
for name, order in sort_keys:
241-
c_sort_keys.push_back(
242-
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
243-
)
244-
245237
self.wrapped.reset(
246238
new COrderByNodeOptions(
247-
COrdering(c_sort_keys, unwrap_null_placement(null_placement))
239+
COrdering(unwrap_sort_keys(sort_keys, allow_str=False),
240+
unwrap_null_placement(null_placement))
248241
)
249242
)
250243

python/pyarrow/_compute.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ cdef CExpression _true
6565

6666
cdef CFieldRef _ensure_field_ref(value) except *
6767

68+
cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=*) except *
69+
6870
cdef CSortOrder unwrap_sort_order(order) except *
6971

7072
cdef CNullPlacement unwrap_null_placement(null_placement) except *

python/pyarrow/_compute.pyx

+49-23
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,20 @@ def _forbid_instantiation(klass, subclasses_instead=True):
7474
raise TypeError(msg)
7575

7676

77+
cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True):
78+
cdef vector[CSortKey] c_sort_keys
79+
if allow_str and isinstance(sort_keys, str):
80+
c_sort_keys.push_back(
81+
CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys))
82+
)
83+
else:
84+
for name, order in sort_keys:
85+
c_sort_keys.push_back(
86+
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
87+
)
88+
return c_sort_keys
89+
90+
7791
cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
7892
"""
7993
Wrap a C++ scalar Function in a ScalarFunction object.
@@ -2093,13 +2107,9 @@ class ArraySortOptions(_ArraySortOptions):
20932107

20942108
cdef class _SortOptions(FunctionOptions):
20952109
def _set_options(self, sort_keys, null_placement):
2096-
cdef vector[CSortKey] c_sort_keys
2097-
for name, order in sort_keys:
2098-
c_sort_keys.push_back(
2099-
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
2100-
)
21012110
self.wrapped.reset(new CSortOptions(
2102-
c_sort_keys, unwrap_null_placement(null_placement)))
2111+
unwrap_sort_keys(sort_keys, allow_str=False),
2112+
unwrap_null_placement(null_placement)))
21032113

21042114

21052115
class SortOptions(_SortOptions):
@@ -2125,12 +2135,7 @@ class SortOptions(_SortOptions):
21252135

21262136
cdef class _SelectKOptions(FunctionOptions):
21272137
def _set_options(self, k, sort_keys):
2128-
cdef vector[CSortKey] c_sort_keys
2129-
for name, order in sort_keys:
2130-
c_sort_keys.push_back(
2131-
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
2132-
)
2133-
self.wrapped.reset(new CSelectKOptions(k, c_sort_keys))
2138+
self.wrapped.reset(new CSelectKOptions(k, unwrap_sort_keys(sort_keys, allow_str=False)))
21342139

21352140

21362141
class SelectKOptions(_SelectKOptions):
@@ -2317,19 +2322,9 @@ cdef class _RankOptions(FunctionOptions):
23172322
}
23182323

23192324
def _set_options(self, sort_keys, null_placement, tiebreaker):
2320-
cdef vector[CSortKey] c_sort_keys
2321-
if isinstance(sort_keys, str):
2322-
c_sort_keys.push_back(
2323-
CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys))
2324-
)
2325-
else:
2326-
for name, order in sort_keys:
2327-
c_sort_keys.push_back(
2328-
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
2329-
)
23302325
try:
23312326
self.wrapped.reset(
2332-
new CRankOptions(c_sort_keys,
2327+
new CRankOptions(unwrap_sort_keys(sort_keys),
23332328
unwrap_null_placement(null_placement),
23342329
self._tiebreaker_map[tiebreaker])
23352330
)
@@ -2370,6 +2365,37 @@ class RankOptions(_RankOptions):
23702365
self._set_options(sort_keys, null_placement, tiebreaker)
23712366

23722367

2368+
cdef class _RankQuantileOptions(FunctionOptions):
2369+
2370+
def _set_options(self, sort_keys, null_placement):
2371+
self.wrapped.reset(
2372+
new CRankQuantileOptions(unwrap_sort_keys(sort_keys),
2373+
unwrap_null_placement(null_placement))
2374+
)
2375+
2376+
2377+
class RankQuantileOptions(_RankQuantileOptions):
2378+
"""
2379+
Options for the `rank_quantile` function.
2380+
2381+
Parameters
2382+
----------
2383+
sort_keys : sequence of (name, order) tuples or str, default "ascending"
2384+
Names of field/column keys to sort the input on,
2385+
along with the order each field/column is sorted in.
2386+
Accepted values for `order` are "ascending", "descending".
2387+
The field name can be a string column name or expression.
2388+
Alternatively, one can simply pass "ascending" or "descending" as a string
2389+
if the input is array-like.
2390+
null_placement : str, default "at_end"
2391+
Where nulls in input should be sorted.
2392+
Accepted values are "at_start", "at_end".
2393+
"""
2394+
2395+
def __init__(self, sort_keys="ascending", *, null_placement="at_end"):
2396+
self._set_options(sort_keys, null_placement)
2397+
2398+
23732399
cdef class Expression(_Weakrefable):
23742400
"""
23752401
A logical expression to be evaluated against some input.

python/pyarrow/compute.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
QuantileOptions,
5757
RandomOptions,
5858
RankOptions,
59+
RankQuantileOptions,
5960
ReplaceSliceOptions,
6061
ReplaceSubstringOptions,
6162
RoundBinaryOptions,

python/pyarrow/includes/libarrow.pxd

+6
Original file line numberDiff line numberDiff line change
@@ -2788,6 +2788,12 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
27882788
CNullPlacement null_placement
27892789
CRankOptionsTiebreaker tiebreaker
27902790

2791+
cdef cppclass CRankQuantileOptions \
2792+
"arrow::compute::RankQuantileOptions"(CFunctionOptions):
2793+
CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement)
2794+
vector[CSortKey] sort_keys
2795+
CNullPlacement null_placement
2796+
27912797
cdef enum DatumType" arrow::Datum::type":
27922798
DatumType_NONE" arrow::Datum::NONE"
27932799
DatumType_SCALAR" arrow::Datum::SCALAR"

python/pyarrow/tests/test_compute.py

+32
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def test_option_class_equality(request):
172172
pc.RandomOptions(),
173173
pc.RankOptions(sort_keys="ascending",
174174
null_placement="at_start", tiebreaker="max"),
175+
pc.RankQuantileOptions(sort_keys="ascending",
176+
null_placement="at_start"),
175177
pc.ReplaceSliceOptions(0, 1, "a"),
176178
pc.ReplaceSubstringOptions("a", "b"),
177179
pc.RoundOptions(2, "towards_infinity"),
@@ -3360,6 +3362,36 @@ def test_rank_options():
33603362
tiebreaker="NonExisting")
33613363

33623364

3365+
def test_rank_quantile_options():
3366+
arr = pa.array([None, 1, None, 2, None])
3367+
expected = pa.array([0.7, 0.1, 0.7, 0.3, 0.7], type=pa.float64())
3368+
3369+
# Ensure rank_quantile can be called without specifying options
3370+
result = pc.rank_quantile(arr)
3371+
assert result.equals(expected)
3372+
3373+
# Ensure default RankOptions
3374+
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions())
3375+
assert result.equals(expected)
3376+
3377+
# Ensure sort_keys tuple usage
3378+
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions(
3379+
sort_keys=[("b", "ascending")])
3380+
)
3381+
assert result.equals(expected)
3382+
3383+
result = pc.rank_quantile(arr, null_placement="at_start")
3384+
expected_at_start = pa.array([0.3, 0.7, 0.3, 0.9, 0.3], type=pa.float64())
3385+
assert result.equals(expected_at_start)
3386+
3387+
result = pc.rank_quantile(arr, sort_keys="descending")
3388+
expected_descending = pa.array([0.7, 0.3, 0.7, 0.1, 0.7], type=pa.float64())
3389+
assert result.equals(expected_descending)
3390+
3391+
with pytest.raises(ValueError, match="not a valid sort order"):
3392+
pc.rank_quantile(arr, sort_keys="XXX")
3393+
3394+
33633395
def create_sample_expressions():
33643396
# We need a schema for substrait conversion
33653397
schema = pa.schema([pa.field("i64", pa.int64()), pa.field(

0 commit comments

Comments
 (0)