Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 83 additions & 7 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,7 +1810,7 @@ def searchsorted(
Value(s) to insert into `self`.
side : {'left', 'right'}, optional
If 'left', the index of the first suitable location found is given.
If 'right', return the last such index. If there is no suitable
If 'right', return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `self`).
sorter : 1-D array-like, optional
Optional array of integer indices that sort array a into ascending
Expand All @@ -1837,19 +1837,95 @@ def searchsorted(
"searchsorted requires array to be sorted, which is impossible "
"with NAs present."
)
if isinstance(value, ExtensionArray):
value = value.astype(object)
# Base class searchsorted would cast to object, which is *much* slower.

if sorter is not None:
dtype = None

if isinstance(self.dtype, ArrowDtype):
pa_dtype = self.dtype.pyarrow_dtype

if (
pa.types.is_timestamp(pa_dtype) or pa.types.is_duration(pa_dtype)
) and pa_dtype.unit == "ns":
dtype = object

return self.to_numpy(dtype=dtype).searchsorted(
value,
side=side,
sorter=sorter,
)

arr = self._pa_array.combine_chunks()
pa_dtype = arr.type

# Fast Arrow-native path for strings
if pa.types.is_string(pa_dtype) or pa.types.is_large_string(pa_dtype):
offsets_buf = arr.buffers()[1]
data_buf = arr.buffers()[2]

offset_dtype = np.int64 if pa.types.is_large_string(pa_dtype) else np.int32

offsets = np.frombuffer(offsets_buf, dtype=offset_dtype)
data = memoryview(data_buf)

def get_string(i: int) -> bytes:
start = offsets[i]
end = offsets[i + 1]
return data[start:end].tobytes()

def binary_search(target, side_local):
if isinstance(target, str):
target = target.encode()
elif not isinstance(target, bytes):
target = str(target).encode()

left = 0
right = len(arr)

while left < right:
mid = (left + right) // 2
mid_value = get_string(mid)

if side_local == "left":
if mid_value < target:
left = mid + 1
else:
right = mid
elif mid_value <= target:
left = mid + 1
else:
right = mid

return left

# scalar input
if is_scalar(value):
return np.intp(binary_search(value, side))

# vector input
result = np.empty(len(value), dtype=np.intp)

for i, val in enumerate(value):
result[i] = binary_search(val, side)

return result

# Fallback for non-string dtypes
dtype = None

if isinstance(self.dtype, ArrowDtype):
pa_dtype = self.dtype.pyarrow_dtype

if (
pa.types.is_timestamp(pa_dtype) or pa.types.is_duration(pa_dtype)
) and pa_dtype.unit == "ns":
# np.array[datetime/timedelta].searchsorted(datetime/timedelta)
# erroneously fails when numpy type resolution is nanoseconds
dtype = object
return self.to_numpy(dtype=dtype).searchsorted(value, side=side, sorter=sorter)

return self.to_numpy(dtype=dtype).searchsorted(
value,
side=side,
sorter=sorter,
)

def take(
self,
Expand Down
Loading