-
Notifications
You must be signed in to change notification settings - Fork 90
fix: jax reducers returning incorrect output values or lengths #3464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c5d85e3
1f74feb
4215ef9
7e681af
51a0945
9874810
28432d6
6423929
74a788f
bfae901
32cf4e9
c35c7d4
1892261
8178f4e
5bc2976
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,18 +37,83 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: | |
raise NotImplementedError | ||
|
||
|
||
def segment_argmin(data, segment_ids): | ||
def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts):
    """
    Subtracts the owning list's start offset from each reduced position.

    Parameters:
        toptr: jax.numpy.ndarray — positions to adjust. Entries equal to the
            int64 maximum (cast to ``toptr.dtype``) are left unchanged.
        outlength: int — number of leading entries of ``toptr`` to adjust.
        parents: jax.numpy.ndarray — indexed by position; gives the parent
            index used to look up ``starts``.
        starts: jax.numpy.ndarray — indexed by parent; the value subtracted
            from the position.

    Returns:
        jax.numpy.ndarray — ``toptr`` with its first ``outlength`` entries
        replaced by the adjusted values, or ``toptr`` itself when
        ``outlength`` is 0 or ``parents`` is empty.
    """
    nothing_to_adjust = outlength == 0 or parents.size == 0
    if nothing_to_adjust:
        return toptr

    int64 = jax.numpy.int64
    head = toptr[:outlength]

    # The sentinel is expressed in toptr's own dtype so the comparison is exact.
    sentinel = jax.numpy.astype(jax.numpy.iinfo(int64).max, toptr.dtype)
    has_position = head != sentinel

    # Sentinel slots are pointed at index 0 so the gathers below stay in
    # bounds; whatever they fetch is discarded by the final `where`.
    lookup = jax.numpy.astype(jax.numpy.where(has_position, head, 0), int64)
    owners = jax.numpy.astype(parents[lookup], int64)
    rebased = jax.numpy.where(has_position, head - starts[owners], head)

    return toptr.at[:outlength].set(rebased)
|
||
|
||
def awkward_JAXArray_reduce_adjust_starts_shifts_64(
    toptr, outlength, parents, starts, shifts
):
    """
    Adjusts each reduced position by its shift minus its list's start offset.

    Parameters:
        toptr: jax.numpy.ndarray — positions to adjust. Entries equal to the
            int64 maximum (cast to ``toptr.dtype``) are left unchanged.
        outlength: int — number of leading entries of ``toptr`` to adjust.
        parents: jax.numpy.ndarray — indexed by position; gives the parent
            index used to look up ``starts``.
        starts: jax.numpy.ndarray — indexed by parent; subtracted from the
            position.
        shifts: jax.numpy.ndarray — indexed by position; added to the
            position.

    Returns:
        jax.numpy.ndarray — ``toptr`` with its first ``outlength`` entries
        replaced by the adjusted values, or ``toptr`` itself when
        ``outlength`` is 0 or ``parents`` is empty.
    """
    nothing_to_adjust = outlength == 0 or parents.size == 0
    if nothing_to_adjust:
        return toptr

    int64 = jax.numpy.int64
    head = toptr[:outlength]

    # The sentinel is expressed in toptr's own dtype so the comparison is exact.
    sentinel = jax.numpy.astype(jax.numpy.iinfo(int64).max, toptr.dtype)
    has_position = head != sentinel

    # Sentinel slots are pointed at index 0 so the gathers below stay in
    # bounds; whatever they fetch is discarded by the final `where`.
    lookup = jax.numpy.astype(jax.numpy.where(has_position, head, 0), int64)
    owners = jax.numpy.astype(parents[lookup], int64)
    offset = shifts[lookup] - starts[owners]
    rebased = jax.numpy.where(has_position, head + offset, head)

    return toptr.at[:outlength].set(rebased)
|
||
|
||
def apply_positional_corrections(
    reduced: ak.contents.NumpyArray,
    parents: ak.index.Index,
    starts: ak.index.Index,
    shifts: ak.index.Index | None,
) -> ak._nplikes.ArrayLike:
    """
    Dispatches to the start (and, when given, shift) adjustment routine.

    Parameters:
        reduced: NumpyArray — output of a positional reducer; its ``data``
            and ``length`` are forwarded to the adjustment routine.
        parents: Index — forwarded as ``parents.data``.
        starts: Index — forwarded as ``starts.data``.
        shifts: Index or None — when not None, the shift-aware routine is
            used and ``shifts.data`` is forwarded as well.

    Returns:
        The array returned by the selected adjustment routine.
    """
    if shifts is not None:
        # Every index must live on the same nplike as the reduced data.
        assert (
            parents.nplike is reduced.backend.nplike
            and starts.nplike is reduced.backend.nplike
            and shifts.nplike is reduced.backend.nplike
        )
        return awkward_JAXArray_reduce_adjust_starts_shifts_64(
            reduced.data,
            reduced.length,
            parents.data,
            starts.data,
            shifts.data,
        )

    # Every index must live on the same nplike as the reduced data.
    assert (
        parents.nplike is reduced.backend.nplike
        and starts.nplike is reduced.backend.nplike
    )
    return awkward_JAXArray_reduce_adjust_starts_64(
        reduced.data, reduced.length, parents.data, starts.data
    )
|
||
|
||
def segment_argmin(data, segment_ids, num_segments): | ||
""" | ||
Applies a segmented argmin-style reduction. | ||
|
||
Parameters: | ||
data: jax.numpy.ndarray — the values to reduce. | ||
segment_ids: same shape as data — indicates segment groupings. | ||
num_segments: int — total number of segments. | ||
|
||
Returns: | ||
jax.numpy.ndarray — indices of min within each segment. | ||
""" | ||
num_segments = int(jax.numpy.max(segment_ids).item()) + 1 | ||
indices = jax.numpy.arange(data.shape[0]) | ||
|
||
# Find the minimum value in each segment | ||
|
@@ -68,7 +133,7 @@ def segment_argmin(data, segment_ids): | |
class ArgMin(JAXReducer): | ||
name: Final = "argmin" | ||
needs_position: Final = True | ||
preferred_dtype: Final = np.int64 | ||
preferred_dtype: Final = np.float64 | ||
|
||
@classmethod | ||
def from_kernel_reducer(cls, reducer: Reducer) -> Self: | ||
|
@@ -88,24 +153,27 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = segment_argmin(array.data, parents.data) | ||
result = jax.numpy.asarray(result, dtype=array.dtype) | ||
|
||
return ak.contents.NumpyArray(result, backend=array.backend) | ||
result = segment_argmin(array.data, parents.data, outlength) | ||
result = jax.numpy.asarray(result, dtype=self.preferred_dtype) | ||
result_array = ak.contents.NumpyArray(result, backend=array.backend) | ||
corrected_data = apply_positional_corrections( | ||
result_array, parents, starts, shifts | ||
) | ||
return ak.contents.NumpyArray(corrected_data, backend=array.backend) | ||
|
||
|
||
def segment_argmax(data, segment_ids): | ||
def segment_argmax(data, segment_ids, num_segments): | ||
""" | ||
Applies a segmented argmax-style reduction. | ||
|
||
Parameters: | ||
data: jax.numpy.ndarray — the values to reduce. | ||
segment_ids: same shape as data — indicates segment groupings. | ||
num_segments: int — total number of segments. | ||
|
||
Returns: | ||
jax.numpy.ndarray — indices of max within each segment. | ||
""" | ||
num_segments = int(jax.numpy.max(segment_ids).item()) + 1 | ||
indices = jax.numpy.arange(data.shape[0]) | ||
|
||
# Find the maximum value in each segment | ||
|
@@ -125,7 +193,7 @@ def segment_argmax(data, segment_ids): | |
class ArgMax(JAXReducer): | ||
name: Final = "argmax" | ||
needs_position: Final = True | ||
preferred_dtype: Final = np.int64 | ||
preferred_dtype: Final = np.float64 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The same argument applies here: indices are always integers, not floating-point numbers. |
||
|
||
@classmethod | ||
def from_kernel_reducer(cls, reducer: Reducer) -> Self: | ||
|
@@ -145,10 +213,13 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = segment_argmax(array.data, parents.data) | ||
result = jax.numpy.asarray(result, dtype=array.dtype) | ||
|
||
return ak.contents.NumpyArray(result, backend=array.backend) | ||
result = segment_argmax(array.data, parents.data, outlength) | ||
result = jax.numpy.asarray(result, dtype=self.preferred_dtype) | ||
result_array = ak.contents.NumpyArray(result, backend=array.backend) | ||
corrected_data = apply_positional_corrections( | ||
result_array, parents, starts, shifts | ||
) | ||
return ak.contents.NumpyArray(corrected_data, backend=array.backend) | ||
|
||
|
||
@overloads(_reducers.Count) | ||
|
@@ -175,8 +246,9 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = jax.numpy.ones_like(array.data, dtype=array.dtype) | ||
result = jax.ops.segment_sum(result, parents.data) | ||
result = jax.numpy.ones_like(array.data, dtype=self.preferred_dtype) | ||
result = jax.ops.segment_sum(result, parents.data, outlength) | ||
result = jax.numpy.asarray(result, dtype=self.preferred_dtype) | ||
|
||
if np.issubdtype(array.dtype, np.complexfloating): | ||
return ak.contents.NumpyArray( | ||
|
@@ -186,21 +258,18 @@ def apply( | |
return ak.contents.NumpyArray(result, backend=array.backend) | ||
|
||
|
||
def segment_count_nonzero(data, segment_ids, num_segments=None): | ||
def segment_count_nonzero(data, segment_ids, num_segments): | ||
""" | ||
Counts the number of non-zero elements in `data` per segment. | ||
|
||
Parameters: | ||
data: jax.numpy.ndarray — input values to count. | ||
segment_ids: jax.numpy.ndarray — same shape as data, segment assignment. | ||
num_segments: int (optional) — total number of segments. | ||
num_segments: int — total number of segments. | ||
|
||
Returns: | ||
jax.numpy.ndarray — count of non-zero values per segment. | ||
""" | ||
if num_segments is None: | ||
num_segments = int(jax.numpy.max(segment_ids).item()) + 1 | ||
|
||
# Create a binary mask where non-zero entries become 1 | ||
nonzero_mask = jax.numpy.where(data != 0, 1, 0) | ||
|
||
|
@@ -232,7 +301,7 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = segment_count_nonzero(array.data, parents.data) | ||
result = segment_count_nonzero(array.data, parents.data, outlength) | ||
result = jax.numpy.asarray(result, dtype=self.preferred_dtype) | ||
|
||
return ak.contents.NumpyArray(result, backend=array.backend) | ||
|
@@ -261,7 +330,11 @@ def apply( | |
if array.dtype.kind == "M": | ||
raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}") | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @ikrommyd - if we want to allow the sum of booleans, I think we should view the data as integers here:
    if array.dtype == np.bool_:
        data = array.data.astype(jax.numpy.int32)
    else:
        data = array.data |
||
result = jax.ops.segment_sum(array.data, parents.data) | ||
if array.dtype.kind == "b": | ||
input_array = array.data.astype(np.int64) | ||
else: | ||
input_array = array.data | ||
result = jax.ops.segment_sum(input_array, parents.data, outlength) | ||
|
||
if array.dtype.kind == "m": | ||
return ak.contents.NumpyArray( | ||
|
@@ -273,6 +346,50 @@ def apply( | |
return ak.contents.NumpyArray(result, backend=array.backend) | ||
|
||
|
||
def segment_prod_with_negatives(data, segment_ids, num_segments):
    """
    Computes the product of elements in each segment, handling zeros,
    negatives, booleans and complex values.

    The product is evaluated as ``exp(sum(log(x)))`` rather than with a
    direct multiply, so zeros and signs (which ``log`` cannot represent for
    real inputs) are tracked separately and re-applied at the end.

    Parameters:
        data: jax.numpy.ndarray — input values to reduce.
        segment_ids: jax.numpy.ndarray — same shape as data, segment assignment.
        num_segments: int — total number of segments.

    Returns:
        jax.numpy.ndarray — product of values per segment. Boolean input
        yields an int32 array (logical AND per segment); real input yields a
        floating-point array; complex input yields a complex array. Segments
        that receive no elements from the log-space path evaluate to 1.
    """
    # Handle boolean arrays: the product of booleans is their logical AND,
    # i.e. the minimum of the 0/1 values.
    if data.dtype == jax.numpy.bool_:
        return jax.ops.segment_min(
            data.astype(jax.numpy.int32), segment_ids, num_segments
        )

    # Track zeros to set the product to zero if any segment has zeros
    is_zero = data == 0
    has_zeros = (
        jax.ops.segment_sum(is_zero.astype(jax.numpy.int32), segment_ids, num_segments)
        > 0
    )

    # Complex values: the complex logarithm carries the phase, so summing the
    # logs and exponentiating reproduces the full product. Splitting into
    # |x| and a +/-1 sign (as done for real input below) would discard the
    # phase, e.g. 1j * 1j would come out as 1 instead of -1.
    if jax.numpy.issubdtype(data.dtype, jax.numpy.complexfloating):
        # Zeros are replaced by 1 so that log() stays finite; they are
        # re-applied through `has_zeros`.
        log_values = jax.numpy.log(jax.numpy.where(is_zero, 1.0, data))
        products = jax.numpy.exp(
            jax.ops.segment_sum(log_values, segment_ids, num_segments)
        )
        return jax.numpy.where(has_zeros, 0.0, products)

    # Track signs to determine final sign of product
    is_negative = data < 0
    neg_count = jax.ops.segment_sum(
        is_negative.astype(jax.numpy.int32), segment_ids, num_segments
    )
    sign_products = 1 - 2 * (
        neg_count % 2
    )  # +1 for even negatives, -1 for odd negatives

    # Calculate product of absolute values in log space; zeros are replaced
    # by 1 so that log() stays finite.
    log_abs = jax.numpy.log(jax.numpy.where(is_zero, 1.0, jax.numpy.abs(data)))
    log_products = jax.ops.segment_sum(
        jax.numpy.where(is_zero, 0.0, log_abs), segment_ids, num_segments
    )
    abs_products = jax.numpy.exp(log_products)

    # Apply zeros and signs
    return jax.numpy.where(has_zeros, 0.0, sign_products * abs_products)
|
||
|
||
@overloads(_reducers.Prod) | ||
class Prod(JAXReducer): | ||
name: Final = "prod" | ||
|
@@ -294,9 +411,7 @@ def apply( | |
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
# See issue https://github.com/google/jax/issues/9296 | ||
result = jax.numpy.exp( | ||
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data) | ||
) | ||
result = segment_prod_with_negatives(array.data, parents.data, outlength) | ||
|
||
if np.issubdtype(array.dtype, np.complexfloating): | ||
return ak.contents.NumpyArray( | ||
|
@@ -321,6 +436,25 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: | |
def _return_dtype(cls, given_dtype): | ||
return np.bool_ | ||
|
||
@staticmethod | ||
def _max_initial(initial, type): | ||
if initial is None: | ||
if type in ( | ||
np.int8, | ||
np.int16, | ||
np.int32, | ||
np.int64, | ||
np.uint8, | ||
np.uint16, | ||
np.uint32, | ||
np.uint64, | ||
): | ||
return np.iinfo(type).min | ||
else: | ||
return -np.inf | ||
|
||
return initial | ||
|
||
def apply( | ||
self, | ||
array: ak.contents.NumpyArray, | ||
|
@@ -330,7 +464,10 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = jax.ops.segment_max(array.data, parents.data) | ||
result = jax.ops.segment_max(array.data, parents.data, outlength) | ||
if array.dtype is not np.dtype(bool): | ||
result = result.at[result == 0].set(self._max_initial(None, array.dtype)) | ||
result = result > self._max_initial(None, array.dtype) | ||
result = jax.numpy.asarray(result, dtype=bool) | ||
|
||
return ak.contents.NumpyArray(result, backend=array.backend) | ||
|
@@ -360,7 +497,7 @@ def apply( | |
outlength: ShapeItem, | ||
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
result = jax.ops.segment_min(array.data, parents.data) | ||
result = jax.ops.segment_min(array.data, parents.data, outlength) | ||
result = jax.numpy.asarray(result, dtype=bool) | ||
|
||
return ak.contents.NumpyArray(result, backend=array.backend) | ||
|
@@ -413,7 +550,7 @@ def apply( | |
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
|
||
result = jax.ops.segment_min(array.data, parents.data) | ||
result = jax.ops.segment_min(array.data, parents.data, outlength) | ||
result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype)) | ||
|
||
if np.issubdtype(array.dtype, np.complexfloating): | ||
|
@@ -474,9 +611,9 @@ def apply( | |
) -> ak.contents.NumpyArray: | ||
assert isinstance(array, ak.contents.NumpyArray) | ||
|
||
result = jax.ops.segment_max(array.data, parents.data) | ||
|
||
result = jax.ops.segment_max(array.data, parents.data, outlength) | ||
result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype)) | ||
|
||
if np.issubdtype(array.dtype, np.complexfloating): | ||
return ak.contents.NumpyArray( | ||
array.backend.nplike.asarray( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`argmin` returns the index (i.e. position) of the minimum value. Indices are always integers, not floating-point numbers.