Skip to content

fix: jax reducers returning incorrect output values or lengths #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 168 additions & 31 deletions src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,83 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self:
raise NotImplementedError


def segment_argmin(data, segment_ids):
def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts):
if outlength == 0 or parents.size == 0:
return toptr

identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype)
valid = toptr[:outlength] != identity
safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0)
safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64)
parent_indices = parents[safe_sub_toptr_int]
adjustments = starts[jax.numpy.astype(parent_indices, jax.numpy.int64)]
updated = jax.numpy.where(valid, toptr[:outlength] - adjustments, toptr[:outlength])

return toptr.at[:outlength].set(updated)


def awkward_JAXArray_reduce_adjust_starts_shifts_64(
toptr, outlength, parents, starts, shifts
):
if outlength == 0 or parents.size == 0:
return toptr

identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype)
valid = toptr[:outlength] != identity
safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0)
safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64)
parent_indices = parents[safe_sub_toptr_int]
delta = (
shifts[safe_sub_toptr_int]
- starts[jax.numpy.astype(parent_indices, jax.numpy.int64)]
)
updated = jax.numpy.where(valid, toptr[:outlength] + delta, toptr[:outlength])

return toptr.at[:outlength].set(updated)


def apply_positional_corrections(
reduced: ak.contents.NumpyArray,
parents: ak.index.Index,
starts: ak.index.Index,
shifts: ak.index.Index | None,
) -> ak._nplikes.ArrayLike:
if shifts is None:
assert (
parents.nplike is reduced.backend.nplike
and starts.nplike is reduced.backend.nplike
)
return awkward_JAXArray_reduce_adjust_starts_64(
reduced.data, reduced.length, parents.data, starts.data
)

else:
assert (
parents.nplike is reduced.backend.nplike
and starts.nplike is reduced.backend.nplike
and shifts.nplike is reduced.backend.nplike
)
return awkward_JAXArray_reduce_adjust_starts_shifts_64(
reduced.data,
reduced.length,
parents.data,
starts.data,
shifts.data,
)


def segment_argmin(data, segment_ids, num_segments):
"""
Applies a segmented argmin-style reduction.

Parameters:
data: jax.numpy.ndarray — the values to reduce.
segment_ids: same shape as data — indicates segment groupings.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — indices of min within each segment.
"""
num_segments = int(jax.numpy.max(segment_ids).item()) + 1
indices = jax.numpy.arange(data.shape[0])

# Find the minimum value in each segment
Expand All @@ -68,7 +133,7 @@ def segment_argmin(data, segment_ids):
class ArgMin(JAXReducer):
name: Final = "argmin"
needs_position: Final = True
preferred_dtype: Final = np.int64
preferred_dtype: Final = np.float64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argmin returns the index (i.e. position) of the minimum value. Indices are always integers, not floating-point numbers.


@classmethod
def from_kernel_reducer(cls, reducer: Reducer) -> Self:
Expand All @@ -88,24 +153,27 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmin(array.data, parents.data)
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
result = segment_argmin(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)
result_array = ak.contents.NumpyArray(result, backend=array.backend)
corrected_data = apply_positional_corrections(
result_array, parents, starts, shifts
)
return ak.contents.NumpyArray(corrected_data, backend=array.backend)


def segment_argmax(data, segment_ids):
def segment_argmax(data, segment_ids, num_segments):
"""
Applies a segmented argmax-style reduction.

Parameters:
data: jax.numpy.ndarray — the values to reduce.
segment_ids: same shape as data — indicates segment groupings.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — indices of max within each segment.
"""
num_segments = int(jax.numpy.max(segment_ids).item()) + 1
indices = jax.numpy.arange(data.shape[0])

# Find the maximum value in each segment
Expand All @@ -125,7 +193,7 @@ def segment_argmax(data, segment_ids):
class ArgMax(JAXReducer):
name: Final = "argmax"
needs_position: Final = True
preferred_dtype: Final = np.int64
preferred_dtype: Final = np.float64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same argument here - indices are always integers, not floating-point numbers.


@classmethod
def from_kernel_reducer(cls, reducer: Reducer) -> Self:
Expand All @@ -145,10 +213,13 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmax(array.data, parents.data)
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
result = segment_argmax(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)
result_array = ak.contents.NumpyArray(result, backend=array.backend)
corrected_data = apply_positional_corrections(
result_array, parents, starts, shifts
)
return ak.contents.NumpyArray(corrected_data, backend=array.backend)


@overloads(_reducers.Count)
Expand All @@ -175,8 +246,9 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.numpy.ones_like(array.data, dtype=array.dtype)
result = jax.ops.segment_sum(result, parents.data)
result = jax.numpy.ones_like(array.data, dtype=self.preferred_dtype)
result = jax.ops.segment_sum(result, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
Expand All @@ -186,21 +258,18 @@ def apply(
return ak.contents.NumpyArray(result, backend=array.backend)


def segment_count_nonzero(data, segment_ids, num_segments=None):
def segment_count_nonzero(data, segment_ids, num_segments):
"""
Counts the number of non-zero elements in `data` per segment.

Parameters:
data: jax.numpy.ndarray — input values to count.
segment_ids: jax.numpy.ndarray — same shape as data, segment assignment.
num_segments: int (optional) — total number of segments.
num_segments: int — total number of segments.

Returns:
jax.numpy.ndarray — count of non-zero values per segment.
"""
if num_segments is None:
num_segments = int(jax.numpy.max(segment_ids).item()) + 1

# Create a binary mask where non-zero entries become 1
nonzero_mask = jax.numpy.where(data != 0, 1, 0)

Expand Down Expand Up @@ -232,7 +301,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_count_nonzero(array.data, parents.data)
result = segment_count_nonzero(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -261,7 +330,11 @@ def apply(
if array.dtype.kind == "M":
raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ikrommyd - if we want to allow sum of boolean, I think, we should view the data as integers here:

    if array.dtype == np.bool_:
        data = array.data.astype(jax.numpy.int32)
    else:
        data = array.data

result = jax.ops.segment_sum(array.data, parents.data)
if array.dtype.kind == "b":
input_array = array.data.astype(np.int64)
else:
input_array = array.data
result = jax.ops.segment_sum(input_array, parents.data, outlength)

if array.dtype.kind == "m":
return ak.contents.NumpyArray(
Expand All @@ -273,6 +346,50 @@ def apply(
return ak.contents.NumpyArray(result, backend=array.backend)


def segment_prod_with_negatives(data, segment_ids, num_segments):
"""
Computes the product of elements in each segment, handling negatives and booleans.
Parameters:
data: jax.numpy.ndarray — input values to reduce.
segment_ids: jax.numpy.ndarray — same shape as data, segment assignment.
num_segments: int — total number of segments.
Returns:
jax.numpy.ndarray — product of values per segment.
"""
# Handle boolean arrays
if data.dtype == jax.numpy.bool_:
return jax.ops.segment_min(
data.astype(jax.numpy.int32), segment_ids, num_segments
)

# For numeric arrays, handle negative values and zeros
# Track zeros to set product to zero if any segment has zeros
is_zero = data == 0
has_zeros = (
jax.ops.segment_sum(is_zero.astype(jax.numpy.int32), segment_ids, num_segments)
> 0
)

# Track signs to determine final sign of product
is_negative = data < 0
neg_count = jax.ops.segment_sum(
is_negative.astype(jax.numpy.int32), segment_ids, num_segments
)
sign_products = 1 - 2 * (
neg_count % 2
) # +1 for even negatives, -1 for odd negatives

# Calculate product of absolute values in log space
log_abs = jax.numpy.log(jax.numpy.where(is_zero, 1.0, jax.numpy.abs(data)))
log_products = jax.ops.segment_sum(
jax.numpy.where(is_zero, 0.0, log_abs), segment_ids, num_segments
)
abs_products = jax.numpy.exp(log_products)

# Apply zeros and signs
return jax.numpy.where(has_zeros, 0.0, sign_products * abs_products)


@overloads(_reducers.Prod)
class Prod(JAXReducer):
name: Final = "prod"
Expand All @@ -294,9 +411,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
# See issue https://github.com/google/jax/issues/9296
result = jax.numpy.exp(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
)
result = segment_prod_with_negatives(array.data, parents.data, outlength)

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
Expand All @@ -321,6 +436,25 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self:
def _return_dtype(cls, given_dtype):
return np.bool_

@staticmethod
def _max_initial(initial, type):
if initial is None:
if type in (
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
):
return np.iinfo(type).min
else:
return -np.inf

return initial

def apply(
self,
array: ak.contents.NumpyArray,
Expand All @@ -330,7 +464,10 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_max(array.data, parents.data)
result = jax.ops.segment_max(array.data, parents.data, outlength)
if array.dtype is not np.dtype(bool):
result = result.at[result == 0].set(self._max_initial(None, array.dtype))
result = result > self._max_initial(None, array.dtype)
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -360,7 +497,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(array.data, parents.data, outlength)
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -413,7 +550,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(array.data, parents.data, outlength)
result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype))

if np.issubdtype(array.dtype, np.complexfloating):
Expand Down Expand Up @@ -474,9 +611,9 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_max(array.data, parents.data)

result = jax.ops.segment_max(array.data, parents.data, outlength)
result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype))

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
array.backend.nplike.asarray(
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def apply(

class CountNonzero(KernelReducer):
name: Final = "count_nonzero"
preferred_dtype: Final = np.float64
preferred_dtype: Final = np.int64
needs_position: Final = False

def apply(
Expand Down
Loading
Loading