Skip to content

fix: jax reducers returning incorrect output values or lengths #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c5d85e3
add tests and use outlength
ikrommyd Apr 14, 2025
1f74feb
rename test and skip if jax is not installed
ikrommyd Apr 14, 2025
4215ef9
count and argmax should be returning integers
ikrommyd Apr 14, 2025
7e681af
maybe make ak.any cpu and jax behavior match?
ikrommyd Apr 14, 2025
51a0945
skip product checks with negative values
ikrommyd Apr 14, 2025
9874810
segment product that deals with negative and zero values too
ikrommyd Apr 14, 2025
28432d6
simplify the function
ikrommyd Apr 14, 2025
6423929
my first attempt at apply_positional_corrections for jax
ikrommyd Apr 14, 2025
74a788f
Merge branch 'main' into ikrommyd/jax-backend-reducers
ikrommyd Apr 14, 2025
bfae901
.at is better than concatenate
ikrommyd Apr 14, 2025
32cf4e9
fix cases like ak.any([[0, 0], [1, 0], [2, 3, 4]]) maybe?
ikrommyd Apr 14, 2025
c35c7d4
fix sum and count with bools?
ikrommyd Apr 14, 2025
1892261
Merge branch 'main' into ikrommyd/jax-backend-reducers
ianna Apr 14, 2025
8178f4e
Merge branch 'main' into ikrommyd/jax-backend-reducers
ikrommyd Apr 15, 2025
5bc2976
Merge branch 'main' into ikrommyd/jax-backend-reducers
ikrommyd Apr 16, 2025
bdfa953
make jax reducers return the same dtype as the numpy ones at the cost…
ikrommyd Apr 21, 2025
8338000
Merge branch 'main' into ikrommyd/jax-backend-reducers
ikrommyd Apr 21, 2025
df106d7
add more tests
ikrommyd Apr 27, 2025
ace8829
Merge branch 'main' into ikrommyd/jax-backend-reducers
ikrommyd Apr 27, 2025
64b0a9b
Merge branch 'main' into ikrommyd/jax-backend-reducers
ianna Apr 27, 2025
7e70aa9
Merge branch 'main' into ikrommyd/jax-backend-reducers
ianna Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 170 additions & 32 deletions src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,78 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self:
raise NotImplementedError


def segment_argmin(data, segment_ids):
def awkward_JAXNumpyArray_reduce_adjust_starts_64(toptr, outlength, parents, starts):
    """
    Rebase the positions stored in the first `outlength` slots of `toptr`.

    Each valid slot holds a position `k`; it is replaced by
    `k - starts[parents[k]]`. Slots equal to the int64 maximum are treated
    as "no position" markers and are left untouched.

    Parameters:
        toptr: array of positions (a JAX array; `.at[...].set` is used on it).
        outlength: number of leading slots of `toptr` to process.
        parents: array indexed by position.
        starts: array indexed by the values found in `parents`.

    Returns:
        `toptr` itself when `outlength == 0` or `parents` is empty; otherwise
        the result of `toptr.at[:outlength].set(...)` with the rebased values.
    """
    if outlength == 0 or parents.size == 0:
        return toptr

    head = toptr[:outlength]
    sentinel = jax.numpy.iinfo(jax.numpy.int64).max
    is_real = head != sentinel

    # Marker slots are redirected to index 0 so the gathers below stay in
    # bounds; their gathered values are discarded by the final `where`.
    lookup = jax.numpy.where(is_real, head, 0)
    offsets = starts[parents[lookup]]
    rebased = jax.numpy.where(is_real, head - offsets, head)

    return toptr.at[:outlength].set(rebased)


def awkward_JAXNumpyArray_reduce_adjust_starts_shifts_64(
    toptr, outlength, parents, starts, shifts
):
    """
    Rebase the positions stored in the first `outlength` slots of `toptr`,
    additionally applying a per-position shift.

    Each valid slot holds a position `k`; it is replaced by
    `k + shifts[k] - starts[parents[k]]`. Slots equal to the int64 maximum
    are treated as "no position" markers and are left untouched.

    Parameters:
        toptr: array of positions (a JAX array; `.at[...].set` is used on it).
        outlength: number of leading slots of `toptr` to process.
        parents: array indexed by position.
        starts: array indexed by the values found in `parents`.
        shifts: array indexed by position.

    Returns:
        `toptr` itself when `outlength == 0` or `parents` is empty; otherwise
        the result of `toptr.at[:outlength].set(...)` with the rebased values.
    """
    if outlength == 0 or parents.size == 0:
        return toptr

    head = toptr[:outlength]
    sentinel = jax.numpy.iinfo(jax.numpy.int64).max
    is_real = head != sentinel

    # Marker slots are redirected to index 0 so the gathers below stay in
    # bounds; their gathered values are discarded by the final `where`.
    lookup = jax.numpy.where(is_real, head, 0)
    correction = shifts[lookup] - starts[parents[lookup]]
    rebased = jax.numpy.where(is_real, head + correction, head)

    return toptr.at[:outlength].set(rebased)


def apply_positional_corrections(
    reduced: ak.contents.NumpyArray,
    parents: ak.index.Index,
    starts: ak.index.Index,
    shifts: ak.index.Index | None,
) -> ak._nplikes.ArrayLike:
    """
    Adjust the positions held in `reduced` using `parents`, `starts` and,
    when provided, `shifts`.

    Dispatches to `awkward_JAXNumpyArray_reduce_adjust_starts_shifts_64` when
    `shifts` is given and to `awkward_JAXNumpyArray_reduce_adjust_starts_64`
    otherwise, passing `reduced.data`, `reduced.length` and the `.data` of
    each index.

    Parameters:
        reduced: the array whose data holds the positions to adjust.
        parents: index passed through as the `parents` argument.
        starts: index passed through as the `starts` argument.
        shifts: optional index passed through as the `shifts` argument.

    Returns:
        Whatever the selected helper returns.
    """
    if shifts is not None:
        # Every index must live on the same nplike as the reduced array.
        assert (
            parents.nplike is reduced.backend.nplike
            and starts.nplike is reduced.backend.nplike
            and shifts.nplike is reduced.backend.nplike
        )
        return awkward_JAXNumpyArray_reduce_adjust_starts_shifts_64(
            reduced.data,
            reduced.length,
            parents.data,
            starts.data,
            shifts.data,
        )

    # Every index must live on the same nplike as the reduced array.
    assert (
        parents.nplike is reduced.backend.nplike
        and starts.nplike is reduced.backend.nplike
    )
    return awkward_JAXNumpyArray_reduce_adjust_starts_64(
        reduced.data, reduced.length, parents.data, starts.data
    )


def segment_argmin(data, segment_ids, num_segments):
"""
Applies a segmented argmin-style reduction.

Parameters:
data: jax.numpy.ndarray — the values to reduce.
segment_ids: same shape as data — indicates segment groupings.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — indices of min within each segment.
"""
num_segments = int(jax.numpy.max(segment_ids).item()) + 1
indices = jax.numpy.arange(data.shape[0])

# Find the minimum value in each segment
Expand Down Expand Up @@ -88,24 +148,27 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmin(array.data, parents.data)
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
result = segment_argmin(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)
result_array = ak.contents.NumpyArray(result, backend=array.backend)
corrected_data = apply_positional_corrections(
result_array, parents, starts, shifts
)
return ak.contents.NumpyArray(corrected_data, backend=array.backend)


def segment_argmax(data, segment_ids):
def segment_argmax(data, segment_ids, num_segments):
"""
Applies a segmented argmax-style reduction.

Parameters:
data: jax.numpy.ndarray — the values to reduce.
segment_ids: same shape as data — indicates segment groupings.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — indices of max within each segment.
"""
num_segments = int(jax.numpy.max(segment_ids).item()) + 1
indices = jax.numpy.arange(data.shape[0])

# Find the maximum value in each segment
Expand Down Expand Up @@ -145,16 +208,19 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmax(array.data, parents.data)
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
result = segment_argmax(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)
result_array = ak.contents.NumpyArray(result, backend=array.backend)
corrected_data = apply_positional_corrections(
result_array, parents, starts, shifts
)
return ak.contents.NumpyArray(corrected_data, backend=array.backend)


@overloads(_reducers.Count)
class Count(JAXReducer):
name: Final = "count"
preferred_dtype: Final = np.float64
preferred_dtype: Final = np.int64
needs_position: Final = False

@classmethod
Expand All @@ -175,8 +241,9 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.numpy.ones_like(array.data, dtype=array.dtype)
result = jax.ops.segment_sum(result, parents.data)
result = jax.numpy.ones_like(array.data, dtype=self.preferred_dtype)
result = jax.ops.segment_sum(result, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
Expand All @@ -186,21 +253,18 @@ def apply(
return ak.contents.NumpyArray(result, backend=array.backend)


def segment_count_nonzero(data, segment_ids, num_segments=None):
def segment_count_nonzero(data, segment_ids, num_segments):
"""
Counts the number of non-zero elements in `data` per segment.

Parameters:
data: jax.numpy.ndarray — input values to count.
segment_ids: jax.numpy.ndarray — same shape as data, segment assignment.
num_segments: int (optional) — total number of segments.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — count of non-zero values per segment.
"""
if num_segments is None:
num_segments = int(jax.numpy.max(segment_ids).item()) + 1

# Create a binary mask where non-zero entries become 1
nonzero_mask = jax.numpy.where(data != 0, 1, 0)

Expand All @@ -211,7 +275,7 @@ def segment_count_nonzero(data, segment_ids, num_segments=None):
@overloads(_reducers.CountNonzero)
class CountNonzero(JAXReducer):
name: Final = "count_nonzero"
preferred_dtype: Final = np.float64
preferred_dtype: Final = np.int64
needs_position: Final = False

@classmethod
Expand All @@ -232,7 +296,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_count_nonzero(array.data, parents.data)
result = segment_count_nonzero(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -261,7 +325,11 @@ def apply(
if array.dtype.kind == "M":
raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}")

result = jax.ops.segment_sum(array.data, parents.data)
if array.dtype.kind == "b":
input_array = array.data.astype(np.int64)
else:
input_array = array.data
result = jax.ops.segment_sum(input_array, parents.data, outlength)

if array.dtype.kind == "m":
return ak.contents.NumpyArray(
Expand All @@ -273,10 +341,60 @@ def apply(
return ak.contents.NumpyArray(result, backend=array.backend)


def segment_prod_with_negatives(data, segment_ids, num_segments):
    """
    Computes the product of elements in each segment, handling negatives,
    zeros, non-finite values and booleans.

    Parameters:
        data: jax.numpy.ndarray — input values to reduce.
        segment_ids: jax.numpy.ndarray — same shape as data, segment assignment.
        num_segments: int — total number of segments.

    Returns:
        jax.numpy.ndarray — product of values per segment. Boolean input
        yields the int32 result of a segmented minimum; integer input is
        rounded and cast back to `data.dtype`; other input keeps the
        floating-point result of the log-space computation.
    """
    # Handle boolean arrays: the product of booleans is a logical AND,
    # i.e. the minimum over their 0/1 values.
    if data.dtype == jax.numpy.bool_:
        return jax.ops.segment_min(
            data.astype(jax.numpy.int32), segment_ids, num_segments
        )

    # Track zeros to set product to zero if any segment has zeros
    is_zero = data == 0
    has_zeros = (
        jax.ops.segment_sum(is_zero.astype(jax.numpy.int32), segment_ids, num_segments)
        > 0
    )

    # Track NaN/inf: multiplying them by zero must give NaN, not zero, so a
    # segment containing both a zero and a non-finite value cannot simply be
    # overridden with 0. For integer input this mask is all False.
    is_nonfinite = jax.numpy.logical_not(jax.numpy.isfinite(data))
    has_nonfinite = (
        jax.ops.segment_sum(
            is_nonfinite.astype(jax.numpy.int32), segment_ids, num_segments
        )
        > 0
    )

    # Track signs to determine final sign of product
    is_negative = data < 0
    neg_count = jax.ops.segment_sum(
        is_negative.astype(jax.numpy.int32), segment_ids, num_segments
    )
    sign_products = 1 - 2 * (
        neg_count % 2
    )  # +1 for even negatives, -1 for odd negatives

    # Calculate product of absolute values in log space. Zeros are replaced
    # by 1 before the log, so they contribute exactly 0 to the sum.
    log_abs = jax.numpy.log(jax.numpy.where(is_zero, 1.0, jax.numpy.abs(data)))
    log_products = jax.ops.segment_sum(log_abs, segment_ids, num_segments)
    abs_products = jax.numpy.exp(log_products)

    # Apply zeros and signs
    signed_products = sign_products * abs_products
    zero_override = jax.numpy.where(
        has_nonfinite, jax.numpy.nan, jax.numpy.zeros_like(signed_products)
    )
    product = jax.numpy.where(has_zeros, zero_override, signed_products)

    # floating point accuracy doesn't let us directly cast to integers
    if np.issubdtype(data.dtype, np.integer):
        result = jax.numpy.round(product).astype(data.dtype)
    else:
        result = product
    return result


@overloads(_reducers.Prod)
class Prod(JAXReducer):
name: Final = "prod"
preferred_dtype: Final = np.int64
preferred_dtype: Final = np.float64
needs_position: Final = False

@classmethod
Expand All @@ -294,9 +412,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
# See issue https://github.com/google/jax/issues/9296
result = jax.numpy.exp(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
)
result = segment_prod_with_negatives(array.data, parents.data, outlength)

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
Expand All @@ -321,6 +437,25 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self:
def _return_dtype(cls, given_dtype):
return np.bool_

@staticmethod
def _max_initial(initial, type):
if initial is None:
if type in (
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
):
return np.iinfo(type).min
else:
return -np.inf

return initial

def apply(
self,
array: ak.contents.NumpyArray,
Expand All @@ -330,7 +465,10 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_max(array.data, parents.data)
result = jax.ops.segment_max(array.data, parents.data, outlength)
if array.dtype is not np.dtype(bool):
result = result.at[result == 0].set(self._max_initial(None, array.dtype))
result = result > self._max_initial(None, array.dtype)
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -360,7 +498,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -413,7 +551,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(array.data, parents.data, outlength)
result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype))

if np.issubdtype(array.dtype, np.complexfloating):
Expand Down Expand Up @@ -474,9 +612,9 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_max(array.data, parents.data)

result = jax.ops.segment_max(array.data, parents.data, outlength)
result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype))

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
array.backend.nplike.asarray(
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def apply(

class CountNonzero(KernelReducer):
name: Final = "count_nonzero"
preferred_dtype: Final = np.float64
preferred_dtype: Final = np.int64
needs_position: Final = False

def apply(
Expand Down Expand Up @@ -437,7 +437,7 @@ def apply(

class Prod(KernelReducer):
name: Final = "prod"
preferred_dtype: Final = np.int64
preferred_dtype: Final = np.float64
needs_position: Final = False

def apply(
Expand Down
Loading
Loading