diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 4d9b0344c6..2a356e18e0 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -37,18 +37,83 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: raise NotImplementedError -def segment_argmin(data, segment_ids): +def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts): + if outlength == 0 or parents.size == 0: + return toptr + + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + valid = toptr[:outlength] != identity + safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) + safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) + parent_indices = parents[safe_sub_toptr_int] + adjustments = starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] + updated = jax.numpy.where(valid, toptr[:outlength] - adjustments, toptr[:outlength]) + + return toptr.at[:outlength].set(updated) + + +def awkward_JAXArray_reduce_adjust_starts_shifts_64( + toptr, outlength, parents, starts, shifts +): + if outlength == 0 or parents.size == 0: + return toptr + + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + valid = toptr[:outlength] != identity + safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) + safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) + parent_indices = parents[safe_sub_toptr_int] + delta = ( + shifts[safe_sub_toptr_int] + - starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] + ) + updated = jax.numpy.where(valid, toptr[:outlength] + delta, toptr[:outlength]) + + return toptr.at[:outlength].set(updated) + + +def apply_positional_corrections( + reduced: ak.contents.NumpyArray, + parents: ak.index.Index, + starts: ak.index.Index, + shifts: ak.index.Index | None, +) -> ak._nplikes.ArrayLike: + if shifts is None: + assert ( + parents.nplike is reduced.backend.nplike + and starts.nplike is reduced.backend.nplike + ) + return awkward_JAXArray_reduce_adjust_starts_64( + reduced.data, reduced.length, parents.data, starts.data + ) + + else: + assert ( + parents.nplike is reduced.backend.nplike + and starts.nplike is reduced.backend.nplike + and shifts.nplike is reduced.backend.nplike + ) + return awkward_JAXArray_reduce_adjust_starts_shifts_64( + reduced.data, + reduced.length, + parents.data, + starts.data, + shifts.data, + ) + + +def segment_argmin(data, segment_ids, num_segments): """ Applies a segmented argmin-style reduction. Parameters: data: jax.numpy.ndarray — the values to reduce. segment_ids: same shape as data — indicates segment groupings. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — indices of min within each segment. """ - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 indices = jax.numpy.arange(data.shape[0]) # Find the minimum value in each segment @@ -68,7 +133,7 @@ def segment_argmin(data, segment_ids): class ArgMin(JAXReducer): name: Final = "argmin" needs_position: Final = True - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -88,24 +153,27 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_argmin(array.data, parents.data) - result = jax.numpy.asarray(result, dtype=array.dtype) - - return ak.contents.NumpyArray(result, backend=array.backend) + result = segment_argmin(array.data, parents.data, outlength) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) + result_array = ak.contents.NumpyArray(result, backend=array.backend) + corrected_data = apply_positional_corrections( + result_array, parents, starts, shifts + ) + return ak.contents.NumpyArray(corrected_data, backend=array.backend) -def segment_argmax(data, segment_ids): +def segment_argmax(data, segment_ids, num_segments): """ Applies a segmented argmax-style reduction. Parameters: data: jax.numpy.ndarray — the values to reduce. segment_ids: same shape as data — indicates segment groupings. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — indices of max within each segment. """ - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 indices = jax.numpy.arange(data.shape[0]) # Find the maximum value in each segment @@ -125,7 +193,7 @@ def segment_argmax(data, segment_ids): class ArgMax(JAXReducer): name: Final = "argmax" needs_position: Final = True - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -145,10 +213,13 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_argmax(array.data, parents.data) - result = jax.numpy.asarray(result, dtype=array.dtype) - - return ak.contents.NumpyArray(result, backend=array.backend) + result = segment_argmax(array.data, parents.data, outlength) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) + result_array = ak.contents.NumpyArray(result, backend=array.backend) + corrected_data = apply_positional_corrections( + result_array, parents, starts, shifts + ) + return ak.contents.NumpyArray(corrected_data, backend=array.backend) @overloads(_reducers.Count) @@ -175,8 +246,9 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.numpy.ones_like(array.data, dtype=array.dtype) - result = jax.ops.segment_sum(result, parents.data) + result = jax.numpy.ones_like(array.data, dtype=self.preferred_dtype) + result = jax.ops.segment_sum(result, parents.data, outlength) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( @@ -186,21 +258,18 @@ def apply( return ak.contents.NumpyArray(result, backend=array.backend) -def segment_count_nonzero(data, segment_ids, num_segments=None): +def segment_count_nonzero(data, segment_ids, num_segments): """ Counts the number of non-zero elements in `data` per segment. Parameters: data: jax.numpy.ndarray — input values to count. segment_ids: jax.numpy.ndarray — same shape as data, segment assignment. - num_segments: int (optional) — total number of segments. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — count of non-zero values per segment. """ - if num_segments is None: - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 - # Create a binary mask where non-zero entries become 1 nonzero_mask = jax.numpy.where(data != 0, 1, 0) @@ -232,7 +301,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_count_nonzero(array.data, parents.data) + result = segment_count_nonzero(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=self.preferred_dtype) return ak.contents.NumpyArray(result, backend=array.backend) @@ -261,7 +330,11 @@ def apply( if array.dtype.kind == "M": raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}") - result = jax.ops.segment_sum(array.data, parents.data) + if array.dtype.kind == "b": + input_array = array.data.astype(np.int64) + else: + input_array = array.data + result = jax.ops.segment_sum(input_array, parents.data, outlength) if array.dtype.kind == "m": return ak.contents.NumpyArray( @@ -273,6 +346,50 @@ def apply( return ak.contents.NumpyArray(result, backend=array.backend) +def segment_prod_with_negatives(data, segment_ids, num_segments): + """ + Computes the product of elements in each segment, handling negatives and booleans. + Parameters: + data: jax.numpy.ndarray — input values to reduce. + segment_ids: jax.numpy.ndarray — same shape as data, segment assignment. + num_segments: int — total number of segments. + Returns: + jax.numpy.ndarray — product of values per segment. + """ + # Handle boolean arrays + if data.dtype == jax.numpy.bool_: + return jax.ops.segment_min( + data.astype(jax.numpy.int32), segment_ids, num_segments + ) + + # For numeric arrays, handle negative values and zeros + # Track zeros to set product to zero if any segment has zeros + is_zero = data == 0 + has_zeros = ( + jax.ops.segment_sum(is_zero.astype(jax.numpy.int32), segment_ids, num_segments) + > 0 + ) + + # Track signs to determine final sign of product + is_negative = data < 0 + neg_count = jax.ops.segment_sum( + is_negative.astype(jax.numpy.int32), segment_ids, num_segments + ) + sign_products = 1 - 2 * ( + neg_count % 2 + ) # +1 for even negatives, -1 for odd negatives + + # Calculate product of absolute values in log space + log_abs = jax.numpy.log(jax.numpy.where(is_zero, 1.0, jax.numpy.abs(data))) + log_products = jax.ops.segment_sum( + jax.numpy.where(is_zero, 0.0, log_abs), segment_ids, num_segments + ) + abs_products = jax.numpy.exp(log_products) + + # Apply zeros and signs + return jax.numpy.where(has_zeros, 0.0, sign_products * abs_products) + + @overloads(_reducers.Prod) class Prod(JAXReducer): name: Final = "prod" @@ -294,9 +411,7 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) # See issue https://github.com/google/jax/issues/9296 - result = jax.numpy.exp( - jax.ops.segment_sum(jax.numpy.log(array.data), parents.data) - ) + result = segment_prod_with_negatives(array.data, parents.data, outlength) if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( @@ -321,6 +436,25 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: def _return_dtype(cls, given_dtype): return np.bool_ + @staticmethod + def _max_initial(initial, type): + if initial is None: + if type in ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ): + return np.iinfo(type).min + else: + return -np.inf + + return initial + def apply( self, array: ak.contents.NumpyArray, @@ -330,7 +464,10 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_max(array.data, parents.data) + result = jax.ops.segment_max(array.data, parents.data, outlength) + if array.dtype is not np.dtype(bool): + result = result.at[result == 0].set(self._max_initial(None, array.dtype)) + result = result > self._max_initial(None, array.dtype) result = jax.numpy.asarray(result, dtype=bool) return ak.contents.NumpyArray(result, backend=array.backend) @@ -360,7 +497,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_min(array.data, parents.data) + result = jax.ops.segment_min(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=bool) return ak.contents.NumpyArray(result, backend=array.backend) @@ -413,7 +550,7 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_min(array.data, parents.data) + result = jax.ops.segment_min(array.data, parents.data, outlength) result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype)) if np.issubdtype(array.dtype, np.complexfloating): @@ -474,9 +611,9 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_max(array.data, parents.data) - + result = jax.ops.segment_max(array.data, parents.data, outlength) result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype)) + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( diff --git a/src/awkward/_reducers.py b/src/awkward/_reducers.py index 3ad690571a..3942fe98f8 100644 --- a/src/awkward/_reducers.py +++ b/src/awkward/_reducers.py @@ -280,7 +280,7 @@ def apply( class CountNonzero(KernelReducer): name: Final = "count_nonzero" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False def apply( diff --git a/tests/test_3464_jax_reducers.py b/tests/test_3464_jax_reducers.py new file mode 100644 index 0000000000..354cc5f707 --- /dev/null +++ b/tests/test_3464_jax_reducers.py @@ -0,0 +1,229 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy as np +import pytest + +import awkward as ak + +jax = pytest.importorskip("jax") +ak.jax.register_and_check() + +# Define all reducers to test +REDUCERS = [ + (ak.argmin, {}), + (ak.argmax, {}), + (ak.min, {}), + (ak.max, {}), + (ak.sum, {}), + (ak.prod, {"mask_identity": True}), # mask_identity for prod to handle empty arrays + (ak.any, {}), + (ak.all, {}), + (ak.count, {}), + (ak.count_nonzero, {}), +] + +# Define test arrays (single jagged) +SINGLE_JAGGED = [ + # Normal array + [[1, 2, 3], [4, 5], [6, 7, 8, 9]], + # Array with first empty + [[], [1, 2], [3, 4, 5]], + # Array with middle empty + [[1, 2], [], [3, 4, 5]], + # Array with last empty + [[1, 2], [3, 4, 5], []], + # Array with multiple empty elements + [[], [1, 2], [], [3, 4], []], + # Array with negative numbers + [[-1, -2], [-3], [4, 5, -6]], + # Array with zeros + [[0, 0], [1, 0], [2, 3, 4]], +] + +# Define test arrays (double jagged) +DOUBLE_JAGGED = [ + # Normal double jagged array + [[[1, 2], [3]], [[4, 5, 6]], [[7], [8, 9]]], + # Double jagged with empty at first level + [[], [[1, 2], [3, 4]], [[5, 6]]], + # Double jagged with empty at second level + [[[1, 2], []], [[3, 4], [5]], [[6]]], + # Double jagged with various empty elements + [[[]], [[], [1, 2]], [[], [], [3, 4]]], + # Double jagged with negative numbers + [[[-1, -2], [-3]], [[4, 5, -6]], [[7], [-8, 9]]], + # Double jagged with zeros + [[[0, 0], [1]], [[2, 3]], [[4], [5, 6]]], +] + +# Define axes to test +AXES = [1, None] # axis=1 for first dimension, None for flattened reduction +DOUBLE_JAGGED_AXES = [1, 2, None] # axis=1 and axis=2 for double jagged + +RTOL = 1e-5 # Relative tolerance for floating point comparison +ATOL = 1e-8 # Absolute tolerance for floating point comparison + + +def compare_results(cpu_list, jax_list): + """Compare results with tolerance for numeric values.""" + if isinstance(cpu_list, (int, float)) and isinstance(jax_list, (int, float)): + # Direct numeric comparison with tolerance + np.testing.assert_allclose(cpu_list, jax_list, rtol=RTOL, atol=ATOL) + elif isinstance(cpu_list, list) and isinstance(jax_list, list): + # Lists should have the same length + assert len(cpu_list) == len(jax_list), ( + f"Lists have different lengths: {len(cpu_list)} vs {len(jax_list)}" + ) + + # Compare each element + for cpu_item, jax_item in zip(cpu_list, jax_list): + compare_results(cpu_item, jax_item) + else: + # For non-numeric types, use exact equality + assert cpu_list == jax_list + + +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +@pytest.mark.parametrize("arr", SINGLE_JAGGED) +@pytest.mark.parametrize("axis", AXES) +def test_single_jagged_arrays(reducer, kwargs, arr, axis): + """Test reducers on single jagged arrays with different axes.""" + + # Create arrays with different backends + cpu_array = ak.Array(arr, backend="cpu") + jax_array = ak.Array(arr, backend="jax") + + # Apply reducers to each backend's array + cpu_result = reducer(cpu_array, axis=axis, **kwargs) + jax_result = reducer(jax_array, axis=axis, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where axis=None might result in different structures + if axis is None: + # If one result is a scalar and the other is a list with one element + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +@pytest.mark.parametrize("arr", DOUBLE_JAGGED) +@pytest.mark.parametrize("axis", DOUBLE_JAGGED_AXES) +def test_double_jagged_arrays(reducer, kwargs, arr, axis): + """Test reducers on double jagged arrays with different axes.""" + + # Create arrays with different backends + cpu_array = ak.Array(arr, backend="cpu") + jax_array = ak.Array(arr, backend="jax") + + # Apply reducers to each backend's array + cpu_result = reducer(cpu_array, axis=axis, **kwargs) + jax_result = reducer(jax_array, axis=axis, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where axis=None might result in different structures + if axis is None: + # If one result is a scalar and the other is a list with one element + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +# Additional edge cases +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +def test_all_empty_arrays(reducer, kwargs): + """Test with arrays that are entirely empty.""" + + all_empty_data = [[], [], []] + cpu_array = ak.Array(all_empty_data, backend="cpu") + jax_array = ak.Array(all_empty_data, backend="jax") + + cpu_result = reducer(cpu_array, axis=1, **kwargs) + jax_result = reducer(jax_array, axis=1, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where one might be a scalar and the other a list + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +# Test with boolean values +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +def test_boolean_arrays(reducer, kwargs): + """Test with boolean arrays.""" + + bool_data = [[True, False], [], [True, True, False], [False]] + cpu_array = ak.Array(bool_data, backend="cpu") + jax_array = ak.Array(bool_data, backend="jax") + + cpu_result = reducer(cpu_array, axis=1, **kwargs) + jax_result = reducer(jax_array, axis=1, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where one might be a scalar and the other a list + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list)