From c5d85e36c0abce5821a68f184d12a9ae62d5780c Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 19:09:03 -0500 Subject: [PATCH 01/12] add tests and use outlength --- src/awkward/_connect/jax/reducers.py | 35 ++-- tests/test_X_jax_reducers.py | 232 +++++++++++++++++++++++++++ 2 files changed, 248 insertions(+), 19 deletions(-) create mode 100644 tests/test_X_jax_reducers.py diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 4d9b0344c6..9bd09c3144 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -37,18 +37,18 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: raise NotImplementedError -def segment_argmin(data, segment_ids): +def segment_argmin(data, segment_ids, num_segments): """ Applies a segmented argmin-style reduction. Parameters: data: jax.numpy.ndarray — the values to reduce. segment_ids: same shape as data — indicates segment groupings. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — indices of min within each segment. """ - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 indices = jax.numpy.arange(data.shape[0]) # Find the minimum value in each segment @@ -88,24 +88,24 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_argmin(array.data, parents.data) + result = segment_argmin(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=array.dtype) return ak.contents.NumpyArray(result, backend=array.backend) -def segment_argmax(data, segment_ids): +def segment_argmax(data, segment_ids, num_segments): """ Applies a segmented argmax-style reduction. Parameters: data: jax.numpy.ndarray — the values to reduce. segment_ids: same shape as data — indicates segment groupings. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — indices of max within each segment. """ - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 indices = jax.numpy.arange(data.shape[0]) # Find the maximum value in each segment @@ -145,7 +145,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_argmax(array.data, parents.data) + result = segment_argmax(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=array.dtype) return ak.contents.NumpyArray(result, backend=array.backend) @@ -176,7 +176,7 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) result = jax.numpy.ones_like(array.data, dtype=array.dtype) - result = jax.ops.segment_sum(result, parents.data) + result = jax.ops.segment_sum(result, parents.data, outlength) if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( @@ -186,21 +186,18 @@ def apply( return ak.contents.NumpyArray(result, backend=array.backend) -def segment_count_nonzero(data, segment_ids, num_segments=None): +def segment_count_nonzero(data, segment_ids, num_segments): """ Counts the number of non-zero elements in `data` per segment. Parameters: data: jax.numpy.ndarray — input values to count. segment_ids: jax.numpy.ndarray — same shape as data, segment assignment. - num_segments: int (optional) — total number of segments. + num_segments: int — total number of segments. Returns: jax.numpy.ndarray — count of non-zero values per segment. """ - if num_segments is None: - num_segments = int(jax.numpy.max(segment_ids).item()) + 1 - # Create a binary mask where non-zero entries become 1 nonzero_mask = jax.numpy.where(data != 0, 1, 0) @@ -232,7 +229,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = segment_count_nonzero(array.data, parents.data) + result = segment_count_nonzero(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=self.preferred_dtype) return ak.contents.NumpyArray(result, backend=array.backend) @@ -261,7 +258,7 @@ def apply( if array.dtype.kind == "M": raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}") - result = jax.ops.segment_sum(array.data, parents.data) + result = jax.ops.segment_sum(array.data, parents.data, outlength) if array.dtype.kind == "m": return ak.contents.NumpyArray( @@ -295,7 +292,7 @@ def apply( assert isinstance(array, ak.contents.NumpyArray) # See issue https://github.com/google/jax/issues/9296 result = jax.numpy.exp( - jax.ops.segment_sum(jax.numpy.log(array.data), parents.data) + jax.ops.segment_sum(jax.numpy.log(array.data), parents.data, outlength) ) if np.issubdtype(array.dtype, np.complexfloating): @@ -360,7 +357,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_min(array.data, parents.data) + result = jax.ops.segment_min(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=bool) return ak.contents.NumpyArray(result, backend=array.backend) @@ -413,7 +410,7 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_min(array.data, parents.data) + result = jax.ops.segment_min(array.data, parents.data, outlength) result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype)) if np.issubdtype(array.dtype, np.complexfloating): @@ -474,9 +471,9 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_max(array.data, parents.data) - + result = jax.ops.segment_max(array.data, parents.data, outlength) result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype)) + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( diff --git a/tests/test_X_jax_reducers.py b/tests/test_X_jax_reducers.py new file mode 100644 index 0000000000..a6e8bb5a82 --- /dev/null +++ b/tests/test_X_jax_reducers.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import awkward as ak + +ak.jax.register_and_check() + +# Define all reducers to test +REDUCERS = [ + (ak.argmin, {}), + (ak.argmax, {}), + (ak.min, {}), + (ak.max, {}), + (ak.sum, {}), + (ak.prod, {"mask_identity": True}), # mask_identity for prod to handle empty arrays + (ak.any, {}), + (ak.all, {}), + (ak.count, {}), + (ak.count_nonzero, {}), +] + +# Define test arrays (single jagged) +SINGLE_JAGGED = [ + # Normal array + [[1, 2, 3], [4, 5], [6, 7, 8, 9]], + # Array with first empty + [[], [1, 2], [3, 4, 5]], + # Array with middle empty + [[1, 2], [], [3, 4, 5]], + # Array with last empty + [[1, 2], [3, 4, 5], []], + # Array with multiple empty elements + [[], [1, 2], [], [3, 4], []], + # Array with negative numbers + [[-1, -2], [-3], [4, 5, -6]], +] + +# Define test arrays (double jagged) +DOUBLE_JAGGED = [ + # Normal double jagged array + [[[1, 2], [3]], [[4, 5, 6]], [[7], [8, 9]]], + # Double jagged with empty at first level + [[], [[1, 2], [3, 4]], [[5, 6]]], + # Double jagged with empty at second level + [[[1, 2], []], [[3, 4], [5]], [[6]]], + # Double jagged with various empty elements + [[[]], [[], [1, 2]], [[], [], [3, 4]]], +] + +# Define axes to test +AXES = [1, None] # axis=1 for first dimension, None for flattened reduction +DOUBLE_JAGGED_AXES = [1, 2, None] # axis=1 and axis=2 for double jagged + +RTOL = 1e-5 # Relative tolerance for floating point comparison +ATOL = 1e-8 # Absolute tolerance for floating point comparison + + +def compare_results(cpu_list, jax_list): + """Compare results with tolerance for numeric values.""" + if isinstance(cpu_list, (int, float)) and isinstance(jax_list, (int, float)): + # Direct numeric comparison with tolerance + np.testing.assert_allclose(cpu_list, jax_list, rtol=RTOL, atol=ATOL) + elif isinstance(cpu_list, list) and isinstance(jax_list, list): + # Lists should have the same length + assert len(cpu_list) == len(jax_list), ( + f"Lists have different lengths: {len(cpu_list)} vs {len(jax_list)}" + ) + + # Compare each element + for cpu_item, jax_item in zip(cpu_list, jax_list): + compare_results(cpu_item, jax_item) + else: + # For non-numeric types, use exact equality + assert cpu_list == jax_list + + +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +@pytest.mark.parametrize("arr", SINGLE_JAGGED) +@pytest.mark.parametrize("axis", AXES) +def test_single_jagged_arrays(reducer, kwargs, arr, axis): + """Test reducers on single jagged arrays with different axes.""" + # Skip argmin and argmax tests + if reducer in [ak.argmin, ak.argmax]: + pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") + + # Create arrays with different backends + cpu_array = ak.Array(arr, backend="cpu") + jax_array = ak.Array(arr, backend="jax") + + # Apply reducers to each backend's array + cpu_result = reducer(cpu_array, axis=axis, **kwargs) + jax_result = reducer(jax_array, axis=axis, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where axis=None might result in different structures + if axis is None: + # If one result is a scalar and the other is a list with one element + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +@pytest.mark.parametrize("arr", DOUBLE_JAGGED) +@pytest.mark.parametrize("axis", DOUBLE_JAGGED_AXES) +def test_double_jagged_arrays(reducer, kwargs, arr, axis): + """Test reducers on double jagged arrays with different axes.""" + # Skip argmin and argmax tests + if reducer in [ak.argmin, ak.argmax]: + pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") + + # Create arrays with different backends + cpu_array = ak.Array(arr, backend="cpu") + jax_array = ak.Array(arr, backend="jax") + + # Apply reducers to each backend's array + cpu_result = reducer(cpu_array, axis=axis, **kwargs) + jax_result = reducer(jax_array, axis=axis, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where axis=None might result in different structures + if axis is None: + # If one result is a scalar and the other is a list with one element + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +# Additional edge cases +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +def test_all_empty_arrays(reducer, kwargs): + """Test with arrays that are entirely empty.""" + # Skip argmin and argmax tests + if reducer in [ak.argmin, ak.argmax]: + pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") + + all_empty_data = [[], [], []] + cpu_array = ak.Array(all_empty_data, backend="cpu") + jax_array = ak.Array(all_empty_data, backend="jax") + + cpu_result = reducer(cpu_array, axis=1, **kwargs) + jax_result = reducer(jax_array, axis=1, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where one might be a scalar and the other a list + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) + + +# Test with boolean values +@pytest.mark.parametrize("reducer,kwargs", REDUCERS) +def test_boolean_arrays(reducer, kwargs): + """Test with boolean arrays.""" + # Skip argmin and argmax tests + if reducer in [ak.argmin, ak.argmax]: + pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") + + bool_data = [[True, False], [], [True, True, False], [False]] + cpu_array = ak.Array(bool_data, backend="cpu") + jax_array = ak.Array(bool_data, backend="jax") + + cpu_result = reducer(cpu_array, axis=1, **kwargs) + jax_result = reducer(jax_array, axis=1, **kwargs) + + # Convert to lists for comparison + cpu_list = ak.to_list(cpu_result) + jax_list = ak.to_list(jax_result) + + # Handle case where one might be a scalar and the other a list + if ( + not isinstance(cpu_list, list) + and isinstance(jax_list, list) + and len(jax_list) == 1 + ): + jax_list = jax_list[0] + elif ( + isinstance(cpu_list, list) + and not isinstance(jax_list, list) + and len(cpu_list) == 1 + ): + cpu_list = cpu_list[0] + + # Compare with tolerance for numeric values + compare_results(cpu_list, jax_list) From 1f74feb472bd77825ac447a21d0503d28583b505 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 19:16:38 -0500 Subject: [PATCH 02/12] rename test and skip if jax is not installed --- src/awkward/_connect/jax/reducers.py | 2 +- tests/{test_X_jax_reducers.py => test_3464_jax_reducers.py} | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) rename tests/{test_X_jax_reducers.py => test_3464_jax_reducers.py} (98%) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 9bd09c3144..6587e4298d 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -327,7 +327,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.ops.segment_max(array.data, parents.data) + result = jax.ops.segment_max(array.data, parents.data, outlength) result = jax.numpy.asarray(result, dtype=bool) return ak.contents.NumpyArray(result, backend=array.backend) diff --git a/tests/test_X_jax_reducers.py b/tests/test_3464_jax_reducers.py similarity index 98% rename from tests/test_X_jax_reducers.py rename to tests/test_3464_jax_reducers.py index a6e8bb5a82..279f884ca8 100644 --- a/tests/test_X_jax_reducers.py +++ b/tests/test_3464_jax_reducers.py @@ -1,3 +1,5 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + from __future__ import annotations import numpy as np @@ -5,6 +7,7 @@ import awkward as ak +jax = pytest.importorskip("jax") ak.jax.register_and_check() # Define all reducers to test From 4215ef9b74de5468773a40ffd5b1705afe9ddaf8 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 21:02:36 -0500 Subject: [PATCH 03/12] count and argmax should be returning integers --- src/awkward/_connect/jax/reducers.py | 5 +++-- src/awkward/_reducers.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 6587e4298d..db5bf6e4e2 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -154,7 +154,7 @@ def apply( @overloads(_reducers.Count) class Count(JAXReducer): name: Final = "count" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False @classmethod @@ -177,6 +177,7 @@ def apply( assert isinstance(array, ak.contents.NumpyArray) result = jax.numpy.ones_like(array.data, dtype=array.dtype) result = jax.ops.segment_sum(result, parents.data, outlength) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( @@ -208,7 +209,7 @@ def segment_count_nonzero(data, segment_ids, num_segments): @overloads(_reducers.CountNonzero) class CountNonzero(JAXReducer): name: Final = "count_nonzero" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False @classmethod diff --git a/src/awkward/_reducers.py b/src/awkward/_reducers.py index 3ad690571a..3942fe98f8 100644 --- a/src/awkward/_reducers.py +++ b/src/awkward/_reducers.py @@ -280,7 +280,7 @@ def apply( class CountNonzero(KernelReducer): name: Final = "count_nonzero" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False def apply( From 7e681af10e2ba93ca70c3414ba4cc81c5375c86f Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 21:15:52 -0500 Subject: [PATCH 04/12] maybe make ak.any cpu and jax behavior match? --- src/awkward/_connect/jax/reducers.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index db5bf6e4e2..fe2d38586b 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -154,7 +154,7 @@ def apply( @overloads(_reducers.Count) class Count(JAXReducer): name: Final = "count" - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 needs_position: Final = False @classmethod @@ -209,7 +209,7 @@ def segment_count_nonzero(data, segment_ids, num_segments): @overloads(_reducers.CountNonzero) class CountNonzero(JAXReducer): name: Final = "count_nonzero" - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 needs_position: Final = False @classmethod @@ -319,6 +319,25 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: def _return_dtype(cls, given_dtype): return np.bool_ + @staticmethod + def _max_initial(initial, type): + if initial is None: + if type in ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ): + return np.iinfo(type).min + else: + return -np.inf + + return initial + def apply( self, array: ak.contents.NumpyArray, @@ -329,6 +348,8 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) result = jax.ops.segment_max(array.data, parents.data, outlength) + if array.dtype is not np.dtype(bool): + result = result > self._max_initial(None, array.dtype) result = jax.numpy.asarray(result, dtype=bool) return ak.contents.NumpyArray(result, backend=array.backend) From 51a0945c1df099953c38371d1597b5ade0c6024d Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 21:35:14 -0500 Subject: [PATCH 05/12] skip product checks with negative values --- tests/test_3464_jax_reducers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_3464_jax_reducers.py b/tests/test_3464_jax_reducers.py index 279f884ca8..c6e5ec2562 100644 --- a/tests/test_3464_jax_reducers.py +++ b/tests/test_3464_jax_reducers.py @@ -92,6 +92,9 @@ def test_single_jagged_arrays(reducer, kwargs, arr, axis): cpu_array = ak.Array(arr, backend="cpu") jax_array = ak.Array(arr, backend="jax") + if reducer is ak.prod and ak.any(cpu_array < 0): + pytest.skip("Jax prod does not support negative values") + # Apply reducers to each backend's array cpu_result = reducer(cpu_array, axis=axis, **kwargs) jax_result = reducer(jax_array, axis=axis, **kwargs) @@ -133,6 +136,9 @@ def test_double_jagged_arrays(reducer, kwargs, arr, axis): cpu_array = ak.Array(arr, backend="cpu") jax_array = ak.Array(arr, backend="jax") + if reducer is ak.prod and ak.any(cpu_array < 0): + pytest.skip("Jax prod does not support negative values") + # Apply reducers to each backend's array cpu_result = reducer(cpu_array, axis=axis, **kwargs) jax_result = reducer(jax_array, axis=axis, **kwargs) From 9874810d9f1686e3f4a3402ca522be74fe601334 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 22:12:20 -0500 Subject: [PATCH 06/12] segment product that deals with negative and zero values too --- src/awkward/_connect/jax/reducers.py | 65 ++++++++++++++++++++++++++-- tests/test_3464_jax_reducers.py | 12 ++--- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index fe2d38586b..0ba2712056 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -271,6 +271,67 @@ def apply( return ak.contents.NumpyArray(result, backend=array.backend) +def segment_prod_with_negatives(data, segment_ids, num_segments): + """ + Computes the product of elements in each segment, handling negatives and booleans. + Parameters: + data: jax.numpy.ndarray — input values to reduce. + segment_ids: jax.numpy.ndarray — same shape as data, segment assignment. + num_segments: int — total number of segments. + Returns: + jax.numpy.ndarray — product of values per segment. + """ + # Convert boolean arrays to integers if needed + if data.dtype == jax.numpy.bool_: + # For booleans, the product is just whether ALL values are True + # We can use segment_min for this (since True=1, False=0, and prod is 0 if ANY is False) + return jax.ops.segment_min( + data.astype(jax.numpy.int32), segment_ids, num_segments + ) + + # Extract signs + signs = jax.numpy.sign(data) + abs_data = jax.numpy.abs(data) + + # Compute product of absolute values in log space + # Handle zeros separately to avoid log(0) + zeros_mask = abs_data == 0 + has_zeros = ( + jax.ops.segment_sum( + zeros_mask.astype(jax.numpy.int32), segment_ids, num_segments + ) + > 0 + ) + + # For non-zero values, use log-sum-exp + safe_abs_data = jax.numpy.where( + zeros_mask, 1.0, abs_data + ) # Replace zeros with ones for log + log_abs = jax.numpy.log(safe_abs_data) + summed_logs = jax.ops.segment_sum( + jax.numpy.where(zeros_mask, 0.0, log_abs), segment_ids, num_segments + ) + abs_products = jax.numpy.exp(summed_logs) + + # If any segment has a zero, its product is zero + abs_products = jax.numpy.where(has_zeros, 0.0, abs_products) + + # Calculate product of signs separately + sign_products = ( + jax.ops.segment_sum( + (signs < 0).astype(jax.numpy.int32), segment_ids, num_segments + ) + % 2 + ) + sign_products = 1 - 2 * sign_products # Convert to +1/-1 + + # Zeros should have sign 0, not -1 or 1 + sign_products = jax.numpy.where(has_zeros, 0.0, sign_products) + + # Combine signs with absolute products + return sign_products * abs_products + + @overloads(_reducers.Prod) class Prod(JAXReducer): name: Final = "prod" @@ -292,9 +353,7 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) # See issue https://github.com/google/jax/issues/9296 - result = jax.numpy.exp( - jax.ops.segment_sum(jax.numpy.log(array.data), parents.data, outlength) - ) + result = segment_prod_with_negatives(array.data, parents.data, outlength) if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( diff --git a/tests/test_3464_jax_reducers.py b/tests/test_3464_jax_reducers.py index c6e5ec2562..0c86394b82 100644 --- a/tests/test_3464_jax_reducers.py +++ b/tests/test_3464_jax_reducers.py @@ -38,6 +38,8 @@ [[], [1, 2], [], [3, 4], []], # Array with negative numbers [[-1, -2], [-3], [4, 5, -6]], + # Array with zeros + [[0, 0], [1, 0], [2, 3, 4]], ] # Define test arrays (double jagged) @@ -50,6 +52,10 @@ [[[1, 2], []], [[3, 4], [5]], [[6]]], # Double jagged with various empty elements [[[]], [[], [1, 2]], [[], [], [3, 4]]], + # Double jagged with negative numbers + [[[-1, -2], [-3]], [[4, 5, -6]], [[7], [-8, 9]]], + # Double jagged with zeros + [[[0, 0], [1]], [[2, 3]], [[4], [5, 6]]], ] # Define axes to test @@ -92,9 +98,6 @@ def test_single_jagged_arrays(reducer, kwargs, arr, axis): cpu_array = ak.Array(arr, backend="cpu") jax_array = ak.Array(arr, backend="jax") - if reducer is ak.prod and ak.any(cpu_array < 0): - pytest.skip("Jax prod does not support negative values") - # Apply reducers to each backend's array cpu_result = reducer(cpu_array, axis=axis, **kwargs) jax_result = reducer(jax_array, axis=axis, **kwargs) @@ -136,9 +139,6 @@ def test_double_jagged_arrays(reducer, kwargs, arr, axis): cpu_array = ak.Array(arr, backend="cpu") jax_array = ak.Array(arr, backend="jax") - if reducer is ak.prod and ak.any(cpu_array < 0): - pytest.skip("Jax prod does not support negative values") - # Apply reducers to each backend's array cpu_result = reducer(cpu_array, axis=axis, **kwargs) jax_result = reducer(jax_array, axis=axis, **kwargs) From 28432d684d1ead502dc900e5897fe8215fe09a70 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 13 Apr 2025 22:26:31 -0500 Subject: [PATCH 07/12] simplify the function --- src/awkward/_connect/jax/reducers.py | 57 ++++++++++------------------ 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 0ba2712056..2b0688b7a9 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -281,55 +281,38 @@ def segment_prod_with_negatives(data, segment_ids, num_segments): Returns: jax.numpy.ndarray — product of values per segment. """ - # Convert boolean arrays to integers if needed + # Handle boolean arrays if data.dtype == jax.numpy.bool_: - # For booleans, the product is just whether ALL values are True - # We can use segment_min for this (since True=1, False=0, and prod is 0 if ANY is False) return jax.ops.segment_min( data.astype(jax.numpy.int32), segment_ids, num_segments ) - # Extract signs - signs = jax.numpy.sign(data) - abs_data = jax.numpy.abs(data) - - # Compute product of absolute values in log space - # Handle zeros separately to avoid log(0) - zeros_mask = abs_data == 0 + # For numeric arrays, handle negative values and zeros + # Track zeros to set product to zero if any segment has zeros + is_zero = data == 0 has_zeros = ( - jax.ops.segment_sum( - zeros_mask.astype(jax.numpy.int32), segment_ids, num_segments - ) + jax.ops.segment_sum(is_zero.astype(jax.numpy.int32), segment_ids, num_segments) > 0 ) - # For non-zero values, use log-sum-exp - safe_abs_data = jax.numpy.where( - zeros_mask, 1.0, abs_data - ) # Replace zeros with ones for log - log_abs = jax.numpy.log(safe_abs_data) - summed_logs = jax.ops.segment_sum( - jax.numpy.where(zeros_mask, 0.0, log_abs), segment_ids, num_segments + # Track signs to determine final sign of product + is_negative = data < 0 + neg_count = jax.ops.segment_sum( + is_negative.astype(jax.numpy.int32), segment_ids, num_segments ) - abs_products = jax.numpy.exp(summed_logs) - - # If any segment has a zero, its product is zero - abs_products = jax.numpy.where(has_zeros, 0.0, abs_products) - - # Calculate product of signs separately - sign_products = ( - jax.ops.segment_sum( - (signs < 0).astype(jax.numpy.int32), segment_ids, num_segments - ) - % 2 + sign_products = 1 - 2 * ( + neg_count % 2 + ) # +1 for even negatives, -1 for odd negatives + + # Calculate product of absolute values in log space + log_abs = jax.numpy.log(jax.numpy.where(is_zero, 1.0, jax.numpy.abs(data))) + log_products = jax.ops.segment_sum( + jax.numpy.where(is_zero, 0.0, log_abs), segment_ids, num_segments ) - sign_products = 1 - 2 * sign_products # Convert to +1/-1 - - # Zeros should have sign 0, not -1 or 1 - sign_products = jax.numpy.where(has_zeros, 0.0, sign_products) + abs_products = jax.numpy.exp(log_products) - # Combine signs with absolute products - return sign_products * abs_products + # Apply zeros and signs + return jax.numpy.where(has_zeros, 0.0, sign_products * abs_products) @overloads(_reducers.Prod) From 64239296260b2c2a0ed65cb75e50eb114facdd12 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Mon, 14 Apr 2025 02:25:40 -0500 Subject: [PATCH 08/12] my first attempt at apply_positional_corrections for jax --- src/awkward/_connect/jax/reducers.py | 85 +++++++++++++++++++++++++--- tests/test_3464_jax_reducers.py | 12 ---- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 2b0688b7a9..e1af81dfdc 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -37,6 +37,69 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: raise NotImplementedError +def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts): + if outlength == 0 or parents.size == 0: + return toptr + sub_toptr = toptr[:outlength] + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + valid = sub_toptr != identity + safe_sub_toptr = jax.numpy.where(valid, sub_toptr, 0) + safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) + parent_indices = parents[safe_sub_toptr_int] + adjustments = starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] + updated = jax.numpy.where(valid, sub_toptr - adjustments, sub_toptr) + return jax.numpy.concatenate([updated, toptr[outlength:]]) + + +def awkward_JAXArray_reduce_adjust_starts_shifts_64( + toptr, outlength, parents, starts, shifts +): + if outlength == 0 or parents.size == 0: + return toptr + sub_toptr = toptr[:outlength] + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + valid = sub_toptr != identity + safe_sub_toptr = jax.numpy.where(valid, sub_toptr, 0) + safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) + parent_indices = parents[safe_sub_toptr_int] + delta = ( + shifts[safe_sub_toptr_int] + - starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] + ) + updated = jax.numpy.where(valid, sub_toptr + delta, sub_toptr) + return jax.numpy.concatenate([updated, toptr[outlength:]]) + + +def apply_positional_corrections( + reduced: ak.contents.NumpyArray, + parents: ak.index.Index, + starts: ak.index.Index, + shifts: ak.index.Index | None, +) -> ak._nplikes.ArrayLike: + if shifts is None: + assert ( + parents.nplike is reduced.backend.nplike + and starts.nplike is reduced.backend.nplike + ) + return awkward_JAXArray_reduce_adjust_starts_64( + reduced.data, reduced.length, parents.data, starts.data + ) + + else: + assert ( + parents.nplike is reduced.backend.nplike + and starts.nplike is reduced.backend.nplike + and shifts.nplike is reduced.backend.nplike + ) + return awkward_JAXArray_reduce_adjust_starts_shifts_64( + reduced.data, + reduced.length, + parents.data, + starts.data, + shifts.data, + ) + + def segment_argmin(data, segment_ids, num_segments): """ Applies a segmented argmin-style reduction. @@ -68,7 +131,7 @@ def segment_argmin(data, segment_ids, num_segments): class ArgMin(JAXReducer): name: Final = "argmin" needs_position: Final = True - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -89,9 +152,12 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) result = segment_argmin(array.data, parents.data, outlength) - result = jax.numpy.asarray(result, dtype=array.dtype) - - return ak.contents.NumpyArray(result, backend=array.backend) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) + result_array = ak.contents.NumpyArray(result, backend=array.backend) + corrected_data = apply_positional_corrections( + result_array, parents, starts, shifts + ) + return ak.contents.NumpyArray(corrected_data, backend=array.backend) def segment_argmax(data, segment_ids, num_segments): @@ -125,7 +191,7 @@ def segment_argmax(data, segment_ids, num_segments): class ArgMax(JAXReducer): name: Final = "argmax" needs_position: Final = True - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -146,9 +212,12 @@ def apply( ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) result = segment_argmax(array.data, parents.data, outlength) - result = jax.numpy.asarray(result, dtype=array.dtype) - - return ak.contents.NumpyArray(result, backend=array.backend) + result = jax.numpy.asarray(result, dtype=self.preferred_dtype) + result_array = ak.contents.NumpyArray(result, backend=array.backend) + corrected_data = apply_positional_corrections( + result_array, parents, starts, shifts + ) + return ak.contents.NumpyArray(corrected_data, backend=array.backend) @overloads(_reducers.Count) diff --git a/tests/test_3464_jax_reducers.py b/tests/test_3464_jax_reducers.py index 0c86394b82..354cc5f707 100644 --- a/tests/test_3464_jax_reducers.py +++ b/tests/test_3464_jax_reducers.py @@ -90,9 +90,6 @@ def compare_results(cpu_list, jax_list): @pytest.mark.parametrize("axis", AXES) def test_single_jagged_arrays(reducer, kwargs, arr, axis): """Test reducers on single jagged arrays with different axes.""" - # Skip argmin and argmax tests - if reducer in [ak.argmin, ak.argmax]: - pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") # Create arrays with different backends cpu_array = ak.Array(arr, backend="cpu") @@ -131,9 +128,6 @@ def test_single_jagged_arrays(reducer, kwargs, arr, axis): @pytest.mark.parametrize("axis", DOUBLE_JAGGED_AXES) def test_double_jagged_arrays(reducer, kwargs, arr, axis): """Test reducers on double jagged arrays with different axes.""" - # Skip argmin and argmax tests - if reducer in [ak.argmin, ak.argmax]: - pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") # Create arrays with different backends cpu_array = ak.Array(arr, backend="cpu") @@ -171,9 +165,6 @@ def test_double_jagged_arrays(reducer, kwargs, arr, axis): @pytest.mark.parametrize("reducer,kwargs", REDUCERS) def test_all_empty_arrays(reducer, kwargs): """Test with arrays that are entirely empty.""" - # Skip argmin and argmax tests - if reducer in [ak.argmin, ak.argmax]: - pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") all_empty_data = [[], [], []] cpu_array = ak.Array(all_empty_data, backend="cpu") @@ -208,9 +199,6 @@ def test_all_empty_arrays(reducer, kwargs): @pytest.mark.parametrize("reducer,kwargs", REDUCERS) def test_boolean_arrays(reducer, kwargs): """Test with boolean arrays.""" - # Skip argmin and argmax tests - if reducer in [ak.argmin, ak.argmax]: - pytest.skip(f"Skipping {reducer.__name__} as it's not fully supported") bool_data = [[True, False], [], [True, True, False], [False]] cpu_array = ak.Array(bool_data, backend="cpu") From bfae90171d8d0a93754f01289a8cf1ac5cfbd2ee Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Mon, 14 Apr 2025 08:48:43 -0500 Subject: [PATCH 09/12] .at is better than concatenate --- src/awkward/_connect/jax/reducers.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index e1af81dfdc..12c53be7f1 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -40,15 +40,16 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts): if outlength == 0 or parents.size == 0: return toptr - sub_toptr = toptr[:outlength] + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) - valid = sub_toptr != identity - safe_sub_toptr = jax.numpy.where(valid, sub_toptr, 0) + valid = toptr[:outlength] != identity + safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) parent_indices = parents[safe_sub_toptr_int] adjustments = starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] - updated = jax.numpy.where(valid, sub_toptr - adjustments, sub_toptr) - return jax.numpy.concatenate([updated, toptr[outlength:]]) + updated = jax.numpy.where(valid, toptr[:outlength] - adjustments, toptr[:outlength]) + + return toptr.at[:outlength].set(updated) def awkward_JAXArray_reduce_adjust_starts_shifts_64( @@ -56,18 +57,19 @@ def awkward_JAXArray_reduce_adjust_starts_shifts_64( ): if outlength == 0 or parents.size == 0: return toptr - sub_toptr = toptr[:outlength] + identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) - valid = sub_toptr != identity - safe_sub_toptr = jax.numpy.where(valid, sub_toptr, 0) + valid = toptr[:outlength] != identity + safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) parent_indices = parents[safe_sub_toptr_int] delta = ( shifts[safe_sub_toptr_int] - starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] ) - updated = jax.numpy.where(valid, sub_toptr + delta, sub_toptr) - return jax.numpy.concatenate([updated, toptr[outlength:]]) + updated = jax.numpy.where(valid, toptr[:outlength] + delta, toptr[:outlength]) + + return toptr.at[:outlength].set(updated) def apply_positional_corrections( From 32cf4e9df8744b1a4c511da651de27db9f3152c6 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Mon, 14 Apr 2025 09:08:55 -0500 Subject: [PATCH 10/12] fix cases like ak.any([[0, 0], [1, 0], [2, 3, 4]]) maybe? --- src/awkward/_connect/jax/reducers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 12c53be7f1..4d608c56fa 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -462,6 +462,7 @@ def apply( assert isinstance(array, ak.contents.NumpyArray) result = jax.ops.segment_max(array.data, parents.data, outlength) if array.dtype is not np.dtype(bool): + result = result.at[result == 0].set(self._max_initial(None, array.dtype)) result = result > self._max_initial(None, array.dtype) result = jax.numpy.asarray(result, dtype=bool) From c35c7d46e707a1fa402f70c90f44aa12c1655ceb Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Mon, 14 Apr 2025 09:22:37 -0500 Subject: [PATCH 11/12] fix sum and count with bools? --- src/awkward/_connect/jax/reducers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 4d608c56fa..2a356e18e0 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -246,7 +246,7 @@ def apply( outlength: ShapeItem, ) -> ak.contents.NumpyArray: assert isinstance(array, ak.contents.NumpyArray) - result = jax.numpy.ones_like(array.data, dtype=array.dtype) + result = jax.numpy.ones_like(array.data, dtype=self.preferred_dtype) result = jax.ops.segment_sum(result, parents.data, outlength) result = jax.numpy.asarray(result, dtype=self.preferred_dtype) @@ -330,7 +330,11 @@ def apply( if array.dtype.kind == "M": raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}") - result = jax.ops.segment_sum(array.data, parents.data, outlength) + if array.dtype.kind == "b": + input_array = array.data.astype(np.int64) + else: + input_array = array.data + result = jax.ops.segment_sum(input_array, parents.data, outlength) if array.dtype.kind == "m": return ak.contents.NumpyArray( From bdfa95345d64f275b7bfbdf7eba21e1d3de31ce0 Mon Sep 17 00:00:00 2001 From: Iason Krommydas Date: Sun, 20 Apr 2025 20:44:04 -0500 Subject: [PATCH 12/12] make jax reducers return the same dtype as the numpy ones at the cost of differentiability --- src/awkward/_connect/jax/reducers.py | 43 ++++++++++---------- src/awkward/_reducers.py | 2 +- tests/test_1490_jax_reducers_combinations.py | 35 ---------------- tests/test_2638_mean_and_count_grads.py | 9 ---- 4 files changed, 23 insertions(+), 66 deletions(-) diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 2a356e18e0..927370ec15 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -37,36 +37,31 @@ def from_kernel_reducer(cls, reducer: Reducer) -> Self: raise NotImplementedError -def awkward_JAXArray_reduce_adjust_starts_64(toptr, outlength, parents, starts): +def awkward_JAXNumpyArray_reduce_adjust_starts_64(toptr, outlength, parents, starts): if outlength == 0 or parents.size == 0: return toptr - identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + identity = jax.numpy.iinfo(jax.numpy.int64).max valid = toptr[:outlength] != identity safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) - safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) - parent_indices = parents[safe_sub_toptr_int] - adjustments = starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] + parent_indices = parents[safe_sub_toptr] + adjustments = starts[parent_indices] updated = jax.numpy.where(valid, toptr[:outlength] - adjustments, toptr[:outlength]) return toptr.at[:outlength].set(updated) -def awkward_JAXArray_reduce_adjust_starts_shifts_64( +def awkward_JAXNumpyArray_reduce_adjust_starts_shifts_64( toptr, outlength, parents, starts, shifts ): if outlength == 0 or parents.size == 0: return toptr - identity = jax.numpy.astype(jax.numpy.iinfo(jax.numpy.int64).max, toptr.dtype) + identity = jax.numpy.iinfo(jax.numpy.int64).max valid = toptr[:outlength] != identity safe_sub_toptr = jax.numpy.where(valid, toptr[:outlength], 0) - safe_sub_toptr_int = jax.numpy.astype(safe_sub_toptr, jax.numpy.int64) - parent_indices = parents[safe_sub_toptr_int] - delta = ( - shifts[safe_sub_toptr_int] - - starts[jax.numpy.astype(parent_indices, jax.numpy.int64)] - ) + parent_indices = parents[safe_sub_toptr] + delta = shifts[safe_sub_toptr] - starts[parent_indices] updated = jax.numpy.where(valid, toptr[:outlength] + delta, toptr[:outlength]) return toptr.at[:outlength].set(updated) @@ -83,7 +78,7 @@ def apply_positional_corrections( parents.nplike is reduced.backend.nplike and starts.nplike is reduced.backend.nplike ) - return awkward_JAXArray_reduce_adjust_starts_64( + return awkward_JAXNumpyArray_reduce_adjust_starts_64( reduced.data, reduced.length, parents.data, starts.data ) @@ -93,7 +88,7 @@ def apply_positional_corrections( and starts.nplike is reduced.backend.nplike and shifts.nplike is reduced.backend.nplike ) - return awkward_JAXArray_reduce_adjust_starts_shifts_64( + return awkward_JAXNumpyArray_reduce_adjust_starts_shifts_64( reduced.data, reduced.length, parents.data, @@ -133,7 +128,7 @@ def segment_argmin(data, segment_ids, num_segments): class ArgMin(JAXReducer): name: Final = "argmin" needs_position: Final = True - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -193,7 +188,7 @@ def segment_argmax(data, segment_ids, num_segments): class ArgMax(JAXReducer): name: Final = "argmax" needs_position: Final = True - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 @classmethod def from_kernel_reducer(cls, reducer: Reducer) -> Self: @@ -225,7 +220,7 @@ def apply( @overloads(_reducers.Count) class Count(JAXReducer): name: Final = "count" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False @classmethod @@ -280,7 +275,7 @@ def segment_count_nonzero(data, segment_ids, num_segments): @overloads(_reducers.CountNonzero) class CountNonzero(JAXReducer): name: Final = "count_nonzero" - preferred_dtype: Final = np.float64 + preferred_dtype: Final = np.int64 needs_position: Final = False @classmethod @@ -387,13 +382,19 @@ def segment_prod_with_negatives(data, segment_ids, num_segments): abs_products = jax.numpy.exp(log_products) # Apply zeros and signs - return jax.numpy.where(has_zeros, 0.0, sign_products * abs_products) + product = jax.numpy.where(has_zeros, 0.0, sign_products * abs_products) + # floating point accuracy doesn't let us directly cast to integers + if np.issubdtype(data.dtype, np.integer): + result = jax.numpy.round(product).astype(data.dtype) + else: + result = product + return result @overloads(_reducers.Prod) class Prod(JAXReducer): name: Final = "prod" - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 needs_position: Final = False @classmethod diff --git a/src/awkward/_reducers.py b/src/awkward/_reducers.py index 3942fe98f8..f0d5e9ea38 100644 --- a/src/awkward/_reducers.py +++ b/src/awkward/_reducers.py @@ -437,7 +437,7 @@ def apply( class Prod(KernelReducer): name: Final = "prod" - preferred_dtype: Final = np.int64 + preferred_dtype: Final = np.float64 needs_position: Final = False def apply( diff --git a/tests/test_1490_jax_reducers_combinations.py b/tests/test_1490_jax_reducers_combinations.py index 02745b04b5..d415200bc4 100644 --- a/tests/test_1490_jax_reducers_combinations.py +++ b/tests/test_1490_jax_reducers_combinations.py @@ -71,41 +71,6 @@ def func_jax_with_axis(x): ) -@pytest.mark.parametrize("axis", [0, 1, None]) -@pytest.mark.parametrize("func_ak", [ak.argmin, ak.argmax, ak.count_nonzero]) -def test_int_output_reducer(func_ak, axis): - func_jax = getattr(jax.numpy, func_ak.__name__) - - def func_ak_with_axis(x): - return func_ak(x, axis=axis) - - def func_jax_with_axis(x): - return func_jax(x, axis=axis) - - value_jvp, jvp_grad = jax.jvp( - func_ak_with_axis, (test_regulararray,), (test_regulararray_tangent,) - ) - value_jvp_jax, jvp_grad_jax = jax.jvp( - func_jax_with_axis, (test_regulararray_jax,), (test_regulararray_tangent_jax,) - ) - - value_vjp, vjp_func = jax.vjp(func_ak_with_axis, test_regulararray) - value_vjp_jax, vjp_func_jax = jax.vjp(func_jax_with_axis, test_regulararray_jax) - - numpy.testing.assert_allclose( - ak.to_list(value_jvp), value_jvp_jax.tolist(), rtol=1e-9, atol=np.inf - ) - numpy.testing.assert_allclose( - ak.to_list(value_vjp), value_vjp_jax.tolist(), rtol=1e-9, atol=np.inf - ) - numpy.testing.assert_allclose( - ak.to_list(vjp_func(value_vjp)[0]), - (vjp_func_jax(value_vjp_jax)[0]).tolist(), - rtol=1e-9, - atol=np.inf, - ) - - @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("func_ak", [ak.sort]) def test_sort(func_ak, axis): diff --git a/tests/test_2638_mean_and_count_grads.py b/tests/test_2638_mean_and_count_grads.py index 12c7934e86..4b33267623 100644 --- a/tests/test_2638_mean_and_count_grads.py +++ b/tests/test_2638_mean_and_count_grads.py @@ -16,17 +16,8 @@ def test(): val_mean, grad_mean = jax.value_and_grad(ak.mean, argnums=0)(array) _, grad_sum = jax.value_and_grad(ak.sum, argnums=0)(array) - val_count, grad_count = jax.value_and_grad(ak.count, argnums=0)(array) assert val_mean == 3 assert ak.all( grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax") ) - - # mean is treated as scaled sum - assert ak.all(grad_mean == grad_sum / val_count) - - assert val_count == 5 - assert ak.all( - grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax") - )