
fix: temporarily use unary_transform instead of segmented_reduce #3814

Draft
maxymnaumchyk wants to merge 17 commits into scikit-hep:main from maxymnaumchyk:maxymnaumchyk/unary-transform-for-kernels

Conversation

@maxymnaumchyk
Collaborator

No description provided.

@maxymnaumchyk
Collaborator Author

maxymnaumchyk commented Jan 22, 2026

Temporary replacement until NVIDIA/cccl#6171 is fixed. See also a related PR with discussion: #3763.
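
For reference, here is a minimal sketch of the idea: emulating a segmented reduce by running a unary transform over segment ids, where the op looks up each segment's bounds in a captured offsets array. All data and names here are illustrative, and it assumes cuda.compute exposes unary_transform and CountingIterator at the top level and that device functions may capture CuPy arrays as globals (numba-cuda >= 0.23.0):

import cupy as cp
import numpy as np
from cuda.compute import CountingIterator, unary_transform

# illustrative data: two segments of `values`, delimited by `offsets`
values = cp.asarray([1.0, 5.0, 3.0, 2.0, 4.0], dtype=cp.float64)
offsets = cp.asarray([0, 2, 5], dtype=cp.int64)
out = cp.zeros(2, dtype=cp.int64)

def segment_argmax(segment_id: np.int64) -> np.int64:
    # `offsets` and `values` are captured as globals in the device function
    start = offsets[segment_id]
    stop = offsets[segment_id + 1]
    best = start
    for i in range(start + 1, stop):
        if values[i] > values[best]:
            best = i
    return best - start  # position within the segment

# one transform invocation per segment, in place of a segmented_reduce
unary_transform(CountingIterator(np.int64(0)), out, segment_argmax, len(out))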

@codecov

codecov bot commented Jan 22, 2026

Codecov Report

❌ Patch coverage is 0% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.67%. Comparing base (303bcdd) to head (7f11b0d).

Files with missing lines | Patch % | Lines
src/awkward/_connect/cuda/_compute.py | 0.00% | 25 Missing ⚠️

Additional details and impacted files

Files with missing lines | Coverage Δ
src/awkward/_connect/cuda/_compute.py | 0.00% <0.00%> (ø)

... and 2 files with indirect coverage changes


@github-actions

The documentation preview is ready to be viewed at http://preview.awkward-array.org.s3-website.us-east-1.amazonaws.com/PR3814

@maxymnaumchyk
Collaborator Author

Hello @shwina! I'm running into another error related to the NVIDIA/cccl#7121 fix. This time it's in the cuda/compute/numba_utils.py file.

The easiest way to reproduce this:

from __future__ import annotations
import numpy as np

from cuda.compute import (
    TransformIterator,
    CountingIterator,
)

def transform_segments():
    def col_major_index(j: np.int32) -> np.int32:
        return j

    TransformIterator(CountingIterator(np.int32(0)), col_major_index)

transform_segments()

Interestingly, I can't reproduce this error directly on the main branch of cccl. Has it already been fixed? Please check it out.

Here is the full error:
NumbaNotImplementedError                  Traceback (most recent call last)
Cell In[5], line 1
----> 1 ak.argmax(awkward_array, axis = 1)

File ~/awkward/src/awkward/_dispatch.py:41, in named_high_level_function.<locals>.dispatch(*args, **kwargs)
     38 @wraps(func)
     39 def dispatch(*args, **kwargs):
     40     # NOTE: this decorator assumes that the operation is exposed under `ak.`
---> 41     with OperationErrorContext(name, args, kwargs):
     42         gen_or_result = func(*args, **kwargs)
     43         if isgenerator(gen_or_result):

File ~/awkward/src/awkward/_errors.py:80, in ErrorContext.__exit__(self, exception_type, exception_value, traceback)
     78     self._slate.__dict__.clear()
     79     # Handle caught exception
---> 80     raise self.decorate_exception(exception_type, exception_value)
     81 else:
     82     # Step out of the way so that another ErrorContext can become primary.
     83     if self.primary() is self:

File ~/awkward/src/awkward/_dispatch.py:67, in named_high_level_function.<locals>.dispatch(*args, **kwargs)
     65 # Failed to find a custom overload, so resume the original function
     66 try:
---> 67     next(gen_or_result)
     68 except StopIteration as err:
     69     return err.value

File ~/awkward/src/awkward/operations/ak_argmax.py:79, in argmax(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)
     76 yield (array,)
     78 # Implementation
---> 79 return _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)

File ~/awkward/src/awkward/operations/ak_argmax.py:163, in _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)
    159 axis = regularize_axis(axis, none_allowed=True)
    161 reducer = ak._reducers.ArgMax()
--> 163 out = ak._do.reduce(
    164     layout,
    165     reducer,
    166     axis=axis,
    167     mask=mask_identity,
    168     keepdims=keepdims,
    169     behavior=ctx.behavior,
    170 )
    172 wrapped_out = ctx.wrap(
    173     out,
    174     highlevel=highlevel,
    175     allow_other=True,
    176 )
    178 # propagate named axis to output

File ~/awkward/src/awkward/_do.py:326, in reduce(layout, reducer, axis, mask, keepdims, behavior)
    324 parents = ak.index.Index64.zeros(layout.length, layout.backend.nplike)
    325 shifts = None
--> 326 next = layout._reduce_next(
    327     reducer,
    328     negaxis,
    329     starts,
    330     shifts,
    331     parents,
    332     1,
    333     mask,
    334     keepdims,
    335     behavior,
    336 )
    338 return next[0]

File ~/awkward/src/awkward/contents/listoffsetarray.py:1633, in ListOffsetArray._reduce_next(self, reducer, negaxis, starts, shifts, parents, outlength, mask, keepdims, behavior)
   1630 trimmed = self._content[self.offsets[0] : self.offsets[-1]]
   1631 nextstarts = self.offsets[:-1]
-> 1633 outcontent = trimmed._reduce_next(
   1634     reducer,
   1635     negaxis,
   1636     nextstarts,
   1637     shifts,
   1638     nextparents,
   1639     globalstarts_length,
   1640     mask,
   1641     keepdims,
   1642     behavior,
   1643 )
   1645 outoffsets = Index64.empty(outlength + 1, nplike)
   1646 assert outoffsets.nplike is nplike and parents.nplike is nplike

File ~/awkward/src/awkward/contents/numpyarray.py:1165, in NumpyArray._reduce_next(self, reducer, negaxis, starts, shifts, parents, outlength, mask, keepdims, behavior)
   1162 assert self.is_contiguous
   1163 assert self._data.ndim == 1
-> 1165 out = reducer.apply(self, parents, starts, shifts, outlength)
   1167 if mask:
   1168     outmask = ak.index.Index8.empty(outlength, self._backend.nplike)

File ~/awkward/src/awkward/_reducers.py:239, in ArgMax.apply(self, array, parents, starts, shifts, outlength)
    236 else:
    237     assert parents.nplike is array.backend.nplike
    238     array.backend.maybe_kernel_error(
--> 239         array.backend[
    240             "awkward_reduce_argmax",
    241             result.dtype.type,
    242             kernel_array_data.dtype.type,
    243             parents.dtype.type,
    244         ](
    245             result,
    246             kernel_array_data,
    247             parents.data,
    248             parents.length,
    249             outlength,
    250         )
    251     )
    252 result_array = ak.contents.NumpyArray(result, backend=array.backend)
    253 apply_positional_corrections(result_array, parents, starts, shifts)

File ~/awkward/src/awkward/_kernels.py:234, in CudaComputeKernel.__call__(self, *args)
    232 def __call__(self, *args) -> None:
    233     args = maybe_materialize(*args)
--> 234     return self._impl(*args)

File ~/awkward/src/awkward/_connect/cuda/_compute.py:150, in awkward_reduce_argmax(result, input_data, parents_data, parents_length, outlength)
    145 # alternative way
    146 # _result = cp.zeros([outlength])
    147 
    148 # Perform the segmented reduce
    149 segment_ids = CountingIterator(cp.int64(0))
--> 150 unary_transform(segment_ids, _result, segment_reduce_op, outlength)
    152 # TODO: here converts float to int too, fix this?
    153 _result = _result.view(index_dtype).reshape(-1, 2)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:275, in unary_transform(d_in, d_out, op, num_items, stream)
    237 def unary_transform(
    238     d_in: DeviceArrayLike | IteratorBase,
    239     d_out: DeviceArrayLike | IteratorBase,
   (...)    242     stream=None,
    243 ):
    244     """
    245     Performs device-wide unary transform.
    246 
   (...)    273         stream: CUDA stream to use for the operation.
    274     """
--> 275     transformer = make_unary_transform(d_in, d_out, op)
    276     transformer(d_in, d_out, num_items, stream)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:202, in make_unary_transform(d_in, d_out, op)
    180 """
    181 Create a unary transform object that can be called to apply a transformation
    182 to each element of the input according to the unary operation ``op``.
   (...)    199     A callable object that performs the transformation.
    200 """
    201 op_adapter = make_op_adapter(op)
--> 202 return _make_unary_transform_cached(d_in, d_out, op_adapter)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/_caching.py:41, in cache_with_key.<locals>.deco.<locals>.inner(*args, **kwargs)
     39 cache_key = (key(*args, **kwargs), tuple(cc))
     40 if cache_key not in cache:
---> 41     result = func(*args, **kwargs)
     42     cache[cache_key] = result
     43 return cache[cache_key]

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:161, in _make_unary_transform_cached(d_in, d_out, op)
    154 @cache_with_key(_make_unary_transform_cache_key)
    155 def _make_unary_transform_cached(
    156     d_in: DeviceArrayLike | IteratorBase,
    157     d_out: DeviceArrayLike | IteratorBase,
    158     op: OpAdapter,
    159 ):
    160     """Internal cached factory for _UnaryTransform."""
--> 161     return _UnaryTransform(d_in, d_out, op)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:33, in _UnaryTransform.__init__(self, d_in, d_out, op)
     31 in_type = cccl.get_value_type(d_in)
     32 out_type = cccl.get_value_type(d_out)
---> 33 self.op_cccl = op.compile((in_type,), out_type)
     35 self.build_result = cccl.call_build(
     36     _bindings.DeviceUnaryTransform,
     37     self.d_in_cccl,
     38     self.d_out_cccl,
     39     self.op_cccl,
     40 )

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/op.py:94, in _StatelessOp.compile(self, input_types, output_type)
     92 # Try to get signature from annotations first
     93 try:
---> 94     sig = signature_from_annotations(self._func)
     95 except ValueError:
     96     # Infer signature from input/output types
     97     if output_type is None or (
     98         hasattr(output_type, "is_internal") and not output_type.is_internal
     99     ):

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/numba_utils.py:47, in signature_from_annotations(func)
     45 except KeyError:
     46     raise ValueError("Function has incomplete annotations: missing return type")
---> 47 retty = to_numba_type(ret_ann)
     48 if num_args != len(argspec.annotations) - 1:  # -1 for the return type
     49     raise ValueError("One or more arguments are missing type annotations")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/numba_utils.py:23, in to_numba_type(tp)
     21 if value := as_numba_type.lookup.get(tp):
     22     return value
---> 23 return numba.from_dtype(tp)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/np/numpy_support.py:137, in from_dtype(dtype)
    134         subtype = from_dtype(dtype.subdtype[0])
    135         return types.NestedArray(subtype, dtype.shape)
--> 137 raise errors.NumbaNotImplementedError(dtype)

NumbaNotImplementedError: index_dtype

This error occurred while calling

    ak.argmax(
        <Array [[1], [2, 3], [4, 5], ..., [], [9]] type='6 * var * int64'>
        axis = 1
    )

@shwina
Contributor

shwina commented Jan 22, 2026

Fix in NVIDIA/cccl#7321. We'll push out a release today with this fix so that you don't have to work off of main. Thanks!

@maxymnaumchyk
Collaborator Author

And thanks to you too!

@maxymnaumchyk
Collaborator Author

maxymnaumchyk commented Jan 26, 2026

Hello @shwina! I'm getting another error :( Do you know what might cause it? I'm using cuda-cccl==0.4.5, and I didn't see it on the previous version.
[screenshot of the error]

Full error traceback if you need it:
TypingError                               Traceback (most recent call last)
Cell In[6], line 1
----> 1 get_ipython().run_cell_magic('timeit', '', 'result = ak.argmax(awkward_array, axis = 1)\n')

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/IPython/core/interactiveshell.py:2572, in InteractiveShell.run_cell_magic(self, magic_name, line, cell)
   2570 with self.builtin_trap:
   2571     args = (magic_arg_s, cell)
-> 2572     result = fn(*args, **kwargs)
   2574 # The code below prevents the output from being displayed
   2575 # when using magics with decorator @output_can_be_silenced
   2576 # when the last Python token in the expression is a ';'.
   2577 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/IPython/core/magics/execution.py:1222, in ExecutionMagics.timeit(self, line, cell, local_ns)
   1220 for index in range(0, 10):
   1221     number = 10 ** index
-> 1222     time_number = timer.timeit(number)
   1223     if time_number >= 0.2:
   1224         break

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/IPython/core/magics/execution.py:184, in Timer.timeit(self, number)
    182 gc.disable()
    183 try:
--> 184     timing = self.inner(it, self.timer)
    185 finally:
    186     if gcold:

File <magic-timeit>:1, in inner(_it, _timer)

File ~/awkward/src/awkward/_dispatch.py:41, in named_high_level_function.<locals>.dispatch(*args, **kwargs)
     38 @wraps(func)
     39 def dispatch(*args, **kwargs):
     40     # NOTE: this decorator assumes that the operation is exposed under `ak.`
---> 41     with OperationErrorContext(name, args, kwargs):
     42         gen_or_result = func(*args, **kwargs)
     43         if isgenerator(gen_or_result):

File ~/awkward/src/awkward/_errors.py:80, in ErrorContext.__exit__(self, exception_type, exception_value, traceback)
     78     self._slate.__dict__.clear()
     79     # Handle caught exception
---> 80     raise self.decorate_exception(exception_type, exception_value)
     81 else:
     82     # Step out of the way so that another ErrorContext can become primary.
     83     if self.primary() is self:

File ~/awkward/src/awkward/_dispatch.py:67, in named_high_level_function.<locals>.dispatch(*args, **kwargs)
     65 # Failed to find a custom overload, so resume the original function
     66 try:
---> 67     next(gen_or_result)
     68 except StopIteration as err:
     69     return err.value

File ~/awkward/src/awkward/operations/ak_argmax.py:79, in argmax(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)
     76 yield (array,)
     78 # Implementation
---> 79 return _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)

File ~/awkward/src/awkward/operations/ak_argmax.py:163, in _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs)
    159 axis = regularize_axis(axis, none_allowed=True)
    161 reducer = ak._reducers.ArgMax()
--> 163 out = ak._do.reduce(
    164     layout,
    165     reducer,
    166     axis=axis,
    167     mask=mask_identity,
    168     keepdims=keepdims,
    169     behavior=ctx.behavior,
    170 )
    172 wrapped_out = ctx.wrap(
    173     out,
    174     highlevel=highlevel,
    175     allow_other=True,
    176 )
    178 # propagate named axis to output

File ~/awkward/src/awkward/_do.py:326, in reduce(layout, reducer, axis, mask, keepdims, behavior)
    324 parents = ak.index.Index64.zeros(layout.length, layout.backend.nplike)
    325 shifts = None
--> 326 next = layout._reduce_next(
    327     reducer,
    328     negaxis,
    329     starts,
    330     shifts,
    331     parents,
    332     1,
    333     mask,
    334     keepdims,
    335     behavior,
    336 )
    338 return next[0]

File ~/awkward/src/awkward/contents/listoffsetarray.py:1633, in ListOffsetArray._reduce_next(self, reducer, negaxis, starts, shifts, parents, outlength, mask, keepdims, behavior)
   1630 trimmed = self._content[self.offsets[0] : self.offsets[-1]]
   1631 nextstarts = self.offsets[:-1]
-> 1633 outcontent = trimmed._reduce_next(
   1634     reducer,
   1635     negaxis,
   1636     nextstarts,
   1637     shifts,
   1638     nextparents,
   1639     globalstarts_length,
   1640     mask,
   1641     keepdims,
   1642     behavior,
   1643 )
   1645 outoffsets = Index64.empty(outlength + 1, nplike)
   1646 assert outoffsets.nplike is nplike and parents.nplike is nplike

File ~/awkward/src/awkward/contents/numpyarray.py:1165, in NumpyArray._reduce_next(self, reducer, negaxis, starts, shifts, parents, outlength, mask, keepdims, behavior)
   1162 assert self.is_contiguous
   1163 assert self._data.ndim == 1
-> 1165 out = reducer.apply(self, parents, starts, shifts, outlength)
   1167 if mask:
   1168     outmask = ak.index.Index8.empty(outlength, self._backend.nplike)

File ~/awkward/src/awkward/_reducers.py:239, in ArgMax.apply(self, array, parents, starts, shifts, outlength)
    236 else:
    237     assert parents.nplike is array.backend.nplike
    238     array.backend.maybe_kernel_error(
--> 239         array.backend[
    240             "awkward_reduce_argmax",
    241             result.dtype.type,
    242             kernel_array_data.dtype.type,
    243             parents.dtype.type,
    244         ](
    245             result,
    246             kernel_array_data,
    247             parents.data,
    248             parents.length,
    249             outlength,
    250         )
    251     )
    252 result_array = ak.contents.NumpyArray(result, backend=array.backend)
    253 apply_positional_corrections(result_array, parents, starts, shifts)

File ~/awkward/src/awkward/_kernels.py:234, in CudaComputeKernel.__call__(self, *args)
    232 def __call__(self, *args) -> None:
    233     args = maybe_materialize(*args)
--> 234     return self._impl(*args)

File ~/awkward/src/awkward/_connect/cuda/_compute.py:145, in awkward_reduce_argmax(result, input_data, parents_data, parents_length, outlength)
    143 segment_ids = CountingIterator(type_wrapper(0))
    144 # TODO: try using segmented_reduce instead when https://github.com/NVIDIA/cccl/issues/6171 is fixed
--> 145 unary_transform(segment_ids, result, segment_reduce_op, outlength)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:224, in unary_transform(d_in, d_out, op, num_items, stream)
    186 def unary_transform(
    187     d_in: DeviceArrayLike | IteratorBase,
    188     d_out: DeviceArrayLike | IteratorBase,
   (...)    191     stream=None,
    192 ):
    193     """
    194     Performs device-wide unary transform.
    195 
   (...)    222         stream: CUDA stream to use for the operation.
    223     """
--> 224     transformer = make_unary_transform(d_in, d_out, op)
    225     transformer(d_in, d_out, num_items, stream)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/_caching.py:112, in _CacheWithRegisteredKeyFunctions.__call__.<locals>.inner(*args, **kwargs)
    110 cache_key = (user_cache_key, tuple(cc))
    111 if cache_key not in cache:
--> 112     result = func(*args, **kwargs)
    113     cache[cache_key] = result
    114 return cache[cache_key]

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:150, in make_unary_transform(d_in, d_out, op)
    128 """
    129 Create a unary transform object that can be called to apply a transformation
    130 to each element of the input according to the unary operation ``op``.
   (...)    147     A callable object that performs the transformation.
    148 """
    149 op_adapter = make_op_adapter(op)
--> 150 return _UnaryTransform(d_in, d_out, op_adapter)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/algorithms/_transform.py:33, in _UnaryTransform.__init__(self, d_in, d_out, op)
     31 in_type = cccl.get_value_type(d_in)
     32 out_type = cccl.get_value_type(d_out)
---> 33 self.op_cccl = op.compile((in_type,), out_type)
     35 self.build_result = cccl.call_build(
     36     _bindings.DeviceUnaryTransform,
     37     self.d_in_cccl,
     38     self.d_out_cccl,
     39     self.op_cccl,
     40 )

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/op.py:93, in _StatelessOp.compile(self, input_types, output_type)
     90         output_type = get_inferred_return_type(self._func, input_types)
     91     sig = output_type(*input_types)
---> 93 return cccl.to_stateless_cccl_op(self._func, sig)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/_cccl_interop.py:235, in to_stateless_cccl_op(op, sig)
    231 from ._odr_helpers import create_op_void_ptr_wrapper
    233 wrapped_op, wrapper_sig = create_op_void_ptr_wrapper(op, sig)
--> 235 ltoir, _ = cuda.compile(wrapped_op, sig=wrapper_sig, output="ltoir")
    236 return Op(
    237     operator_type=OpKind.STATELESS,
    238     name=wrapped_op.__name__,
   (...)    241     state=None,
    242 )

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/compiler.py:755, in compile(pyfunc, sig, debug, lineinfo, device, fastmath, cc, opt, abi, abi_info, output, forceinline, launch_bounds)
    752 MIN_CC = max(config.CUDA_DEFAULT_PTX_CC, nvrtc.get_lowest_supported_cc())
    753 cc = cc or MIN_CC
--> 755 cres = compile_cuda(
    756     pyfunc,
    757     return_type,
    758     args,
    759     debug=debug,
    760     lineinfo=lineinfo,
    761     fastmath=fastmath,
    762     nvvm_options=nvvm_options,
    763     cc=cc,
    764     forceinline=forceinline,
    765 )
    766 resty = cres.signature.return_type
    768 if resty and not device and resty != types.void:

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/compiler.py:406, in compile_cuda(pyfunc, return_type, args, debug, lineinfo, forceinline, fastmath, nvvm_options, cc, max_registers, lto)
    403 from numba.core.target_extension import target_override
    405 with target_override("cuda"):
--> 406     cres = compiler.compile_extra(
    407         typingctx=typingctx,
    408         targetctx=targetctx,
    409         func=pyfunc,
    410         args=args,
    411         return_type=return_type,
    412         flags=flags,
    413         locals={},
    414         pipeline_class=CUDACompiler,
    415     )
    417 library = cres.library
    418 library.finalize()

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:739, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    715 """Compiler entry point
    716 
    717 Parameter
   (...)    735     compiler pipeline
    736 """
    737 pipeline = pipeline_class(typingctx, targetctx, library,
    738                           args, return_type, flags, locals)
--> 739 return pipeline.compile_extra(func)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py:133, in CompilerBase.compile_extra(self, func)
    131 self.state.lifted = ()
    132 self.state.lifted_from = None
--> 133 return self._compile_bytecode()

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py:201, in CompilerBase._compile_bytecode(self)
    197 """
    198 Populate and run pipeline for bytecode input
    199 """
    200 assert self.state.func_ir is None
--> 201 return self._compile_core()

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py:180, in CompilerBase._compile_core(self)
    178         self.state.status.fail_reason = e
    179         if is_final_pipeline:
--> 180             raise e
    181 else:
    182     raise CompilerError("All available pipelines exhausted")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py:169, in CompilerBase._compile_core(self)
    167 res = None
    168 try:
--> 169     pm.run(self.state)
    170     if self.state.cr is not None:
    171         break

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:367, in PassManager.run(self, state)
    364 msg = "Failed in %s mode pipeline (step: %s)" % \
    365     (self.pipeline_name, pass_desc)
    366 patched_exception = self._patch_error(msg, e)
--> 367 raise patched_exception

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:272, in PassManager._runPass.<locals>.check(func, compiler_state)
    271 def check(func, compiler_state):
--> 272     mangled = func(compiler_state)
    273     if mangled not in (True, False):
    274         msg = ("CompilerPass implementations should return True/False. "
    275                "CompilerPass with name '%s' did not.")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typed_passes.py:98, in BaseNativeLowering.run_pass(self, state)
     94 with targetctx.push_code_library(library):
     95     lower = self.lowering_class(
     96         targetctx, library, fndesc, interp, metadata=metadata
     97     )
---> 98     lower.lower()
     99     if not flags.no_cpython_wrapper:
    100         lower.create_cpython_wrapper(flags.release_gil)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:234, in BaseLower.lower(self)
    232 if self.generator_info is None:
    233     self.genlower = None
--> 234     self.lower_normal_function(self.fndesc)
    235 else:
    236     self.genlower = self.GeneratorLower(self)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:273, in BaseLower.lower_normal_function(self, fndesc)
    271 # Init argument values
    272 self.extract_function_arguments()
--> 273 entry_block_tail = self.lower_function_body()
    275 # Close tail of entry block, do not emit debug metadata else the
    276 # unconditional jump gets associated with the metadata from the function
    277 # body end.
    278 with debuginfo.suspend_emission(self.builder):

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:304, in BaseLower.lower_function_body(self)
    302     self.builder.position_at_end(bb)
    303     self.debug_print(f"# lower block: {offset}")
--> 304     self.lower_block(block)
    305 self.post_lower()
    306 return entry_block_tail

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:322, in BaseLower.lower_block(self, block)
    315     defaulterrcls = partial(LoweringError, loc=self.loc)
    316     with new_error_context(
    317         'lowering "{inst}" at {loc}',
    318         inst=inst,
    319         loc=self.loc,
    320         errcls_=defaulterrcls,
    321     ):
--> 322         self.lower_inst(inst)
    323 self.post_block(block)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:522, in Lower.lower_inst(self, inst)
    520 if isinstance(inst, ir.Assign):
    521     ty = self.typeof(inst.target.name)
--> 522     val = self.lower_assign(ty, inst)
    523     argidx = None
    524     # If this is a store from an arg, like x = arg.x then tell debuginfo
    525     # that this is the arg

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:750, in Lower.lower_assign(self, ty, inst)
    747     return res
    749 elif isinstance(value, ir.Expr):
--> 750     return self.lower_expr(ty, value)
    752 elif isinstance(value, ir.Var):
    753     val = self.loadvar(value.name)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:1422, in Lower.lower_expr(self, resty, expr)
   1419     return res
   1421 elif expr.op == "call":
-> 1422     res = self.lower_call(resty, expr)
   1423     return res
   1425 elif expr.op == "pair_first":

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:1036, in Lower.lower_call(self, resty, expr)
   1033     res = self._lower_call_FunctionType(fnty, expr, signature)
   1035 else:
-> 1036     res = self._lower_call_normal(fnty, expr, signature)
   1038 # If lowering the call returned None, interpret that as returning dummy
   1039 # value if the return type of the function is void, otherwise there is
   1040 # a problem
   1041 if res is None:

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/lowering.py:1392, in Lower._lower_call_normal(self, fnty, expr, signature)
   1389     # Prepend the self reference
   1390     argvals = [the_self] + list(argvals)
-> 1392 res = impl(self.builder, argvals, self.loc)
   1393 return res

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/base.py:1190, in _wrap_impl.__call__(self, builder, args, loc)
   1189 def __call__(self, builder, args, loc=None):
-> 1190     res = self._imp(self._context, builder, self._sig, args, loc=loc)
   1191     self._context.add_linking_libs(getattr(self, 'libs', ()))
   1192     return res

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/base.py:1220, in _wrap_missing_loc.__call__.<locals>.wrapper(*args, **kwargs)
   1218 def wrapper(*args, **kwargs):
   1219     kwargs.pop('loc')     # drop unused loc
-> 1220     return fn(*args, **kwargs)

File <string>:5, in codegen(context, builder, impl_sig, args)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/cuda/compute/_odr_helpers.py:117, in _codegen_void_ptr_wrapper(context, builder, args, arg_specs, func_device, inner_sig)
    114             raise ValueError(f"Invalid arg mode: {spec.mode}")
    116 # Call the inner function
--> 117 cres = context.compile_subroutine(builder, func_device, inner_sig, caching=False)
    118 result = context.call_internal(builder, cres.fndesc, inner_sig, input_vals)
    120 # Store result if needed

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/base.py:866, in BaseContext.compile_subroutine(self, builder, impl, sig, locals, flags, caching)
    864     cached = self.cached_internal_func.get(cache_key)
    865 if cached is None:
--> 866     cres = self._compile_subroutine_no_cache(builder, impl, sig,
    867                                              locals=locals,
    868                                              flags=flags)
    869     self.cached_internal_func[cache_key] = cres
    871 cres = self.cached_internal_func[cache_key]

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba_cuda/numba/cuda/target.py:339, in CUDATargetContext._compile_subroutine_no_cache(self, builder, impl, sig, locals, flags)
    336 flags.no_cpython_wrapper = True
    337 flags.no_cfunc_wrapper = True
--> 339 cres = compiler.compile_internal(
    340     self.typing_context,
    341     self,
    342     library,
    343     impl,
    344     sig.args,
    345     sig.return_type,
    346     flags,
    347     locals=locals,
    348 )
    350 # Allow inlining the function inside callers
    351 self.active_code_library.add_linking_library(cres.library)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:813, in compile_internal(typingctx, targetctx, library, func, args, return_type, flags, locals)
    808 """
    809 For internal use only.
    810 """
    811 pipeline = Compiler(typingctx, targetctx, library,
    812                     args, return_type, flags, locals)
--> 813 return pipeline.compile_extra(func)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:439, in CompilerBase.compile_extra(self, func)
    437 self.state.lifted = ()
    438 self.state.lifted_from = None
--> 439 return self._compile_bytecode()

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:505, in CompilerBase._compile_bytecode(self)
    501 """
    502 Populate and run pipeline for bytecode input
    503 """
    504 assert self.state.func_ir is None
--> 505 return self._compile_core()

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:484, in CompilerBase._compile_core(self)
    482         self.state.status.fail_reason = e
    483         if is_final_pipeline:
--> 484             raise e
    485 else:
    486     raise CompilerError("All available pipelines exhausted")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler.py:473, in CompilerBase._compile_core(self)
    471 res = None
    472 try:
--> 473     pm.run(self.state)
    474     if self.state.cr is not None:
    475         break

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:367, in PassManager.run(self, state)
    364 msg = "Failed in %s mode pipeline (step: %s)" % \
    365     (self.pipeline_name, pass_desc)
    366 patched_exception = self._patch_error(msg, e)
--> 367 raise patched_exception

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/compiler_machinery.py:272, in PassManager._runPass.<locals>.check(func, compiler_state)
    271 def check(func, compiler_state):
--> 272     mangled = func(compiler_state)
    273     if mangled not in (True, False):
    274         msg = ("CompilerPass implementations should return True/False. "
    275                "CompilerPass with name '%s' did not.")

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
    106 """
    107 Type inference and legalization
    108 """
    109 with fallback_context(state, 'Function "%s" failed type inference'
    110                       % (state.func_id.func_name,)):
    111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
    113         state.typingctx,
    114         state.targetctx,
    115         state.func_ir,
    116         state.args,
    117         state.return_type,
    118         state.locals,
    119         raise_errors=self._raise_errors)
    120     state.typemap = typemap
    121     # save errors in case of partial typing

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typed_passes.py:91, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     88 for k, v in locals.items():
     89     infer.seed_type(k, v)
---> 91 infer.build_constraint()
     92 # return errors in case of partial typing
     93 errs = infer.propagate(raise_errors=raise_errors)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typeinfer.py:1027, in TypeInferer.build_constraint(self)
   1025 for blk in self.blocks.values():
   1026     for inst in blk.body:
-> 1027         self.constrain_statement(inst)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typeinfer.py:1389, in TypeInferer.constrain_statement(self, inst)
   1387 def constrain_statement(self, inst):
   1388     if isinstance(inst, ir.Assign):
-> 1389         self.typeof_assign(inst)
   1390     elif isinstance(inst, ir.SetItem):
   1391         self.typeof_setitem(inst)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typeinfer.py:1464, in TypeInferer.typeof_assign(self, inst)
   1461     self.constraints.append(Propagate(dst=inst.target.name,
   1462                                       src=value.name, loc=inst.loc))
   1463 elif isinstance(value, (ir.Global, ir.FreeVar)):
-> 1464     self.typeof_global(inst, inst.target, value)
   1465 elif isinstance(value, ir.Arg):
   1466     self.typeof_arg(inst, inst.target, value)

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typeinfer.py:1564, in TypeInferer.typeof_global(self, inst, target, gvar)
   1562 def typeof_global(self, inst, target, gvar):
   1563     try:
-> 1564         typ = self.resolve_value_type(inst, gvar.value)
   1565     except TypingError as e:
   1566         if (gvar.name == self.func_id.func_name
   1567                 and gvar.name in _temporary_dispatcher_map):
   1568             # Self-recursion case where the dispatcher is not (yet?) known
   1569             # as a global variable

File ~/miniforge3/envs/ak3/lib/python3.13/site-packages/numba/core/typeinfer.py:1485, in TypeInferer.resolve_value_type(self, inst, val)
   1483 except ValueError as e:
   1484     msg = str(e)
-> 1485 raise TypingError(msg, loc=inst.loc)

TypingError: Failed in cuda mode pipeline (step: cuda native lowering)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'start_o': Cannot determine Numba type of <class 'cupy.ndarray'>

File "../../../awkward/src/awkward/_connect/cuda/_compute.py", line 127:
    def segment_reduce_op(segment_id):
        start_idx = start_o[segment_id]
        ^

During: Pass nopython_type_inference
During: lowering "$16call.4 = call $4load_global.0(arg_0, arg_1, func=$4load_global.0, args=[Var(arg_0, <string>:8), Var(arg_1, <string>:8)], kws=(), vararg=None, varkwarg=None, target=None)" at <string> (9)
During: Pass cuda_native_lowering

This error occurred while calling

    ak.argmax(
        <Array [[1], [2, 3], [4, 5], ..., [], [9]] type='6 * var * int64'>
        axis = 1
    )

@shwina
Contributor

shwina commented Jan 26, 2026

@maxymnaumchyk -- thanks, could you tell me how to reproduce what you're seeing?

I tried the following (on your branch):

nox -s prepare
pip install -e . -v
pytest ./tests-cuda/test_3136_cuda_argmin_and_argmax.py

and it seemed to complete without errors.

@shwina
Contributor

shwina commented Jan 26, 2026

I can see I'm clearly missing a step since CI is failing :)

@shwina
Contributor

shwina commented Jan 26, 2026

Oh, I see the problem. CI is pulling in numba-cuda==0.19.*. However, the ability to reference device arrays in device functions was added much more recently: NVIDIA/numba-cuda#666.

Can you try with the constraint numba-cuda>=0.23.0?

Edit: in the meantime, I'll update our own constraints.

@maxymnaumchyk
Collaborator Author

Yes, thanks Ashwin! It was indeed a problem with my package versions. I'll take a deeper look into it tomorrow.

@maxymnaumchyk
Collaborator Author

maxymnaumchyk commented Jan 28, 2026

Hello @shwina. There is currently a version conflict between cudf and numba-cuda:
[screenshot of the conflicting version requirements]

I see there was a commit in cudf updating the requirements, but it's not in a public release yet:
[screenshot of the cudf commit]

Do you perhaps know when the next release is planned? We can also skip the cudf tests for now.

@shwina
Contributor

shwina commented Jan 28, 2026

@maxymnaumchyk for now, can you bypass conda and install cuda-cccl using pip?

It's cheating, but unfortunately RAPIDS won't be relaxing their numba-cuda pins until their next release :(

Alternatively, if you think this is more appropriate, that's fine with me too:

We can also skip the cudf tests for now.

@maxymnaumchyk
Collaborator Author

Right now, this implementation works, but very slowly. Running this script:

import awkward as ak
import cupy as cp
import timeit

awkward_array = ak.Array([[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]], backend = 'cuda')
_ = ak.argmax(awkward_array, axis=1)  # warmup
start_time = timeit.default_timer()
for i in range(10):
    expect = ak.argmax(awkward_array, axis=1)
    cp.cuda.Device().synchronize()
end_time = timeit.default_timer()
print(f"Time taken for ak.argmax: {(end_time - start_time) / 10} seconds")

Shows:

Time taken for ak.argmax: 2.027083551697433 seconds

@shwina
Contributor

shwina commented Feb 2, 2026

The issue here is that the offsets array is recomputed each time the function runs. This leads to the segment_reduce_op function being recompiled each time the algorithm is invoked (because it references a different global array each time).

One fix would be to pass the same offsets every time. That would help cache reuse across reductions on the same array, but it's still not ideal, as you'd get a recompilation for every new array.
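
A minimal sketch of that workaround, assuming a fixed-size persistent buffer (all names here are hypothetical):

import cupy as cp

# One persistent device buffer that the compiled op always captures; each
# array's offsets are copied into it before launching, so the op references
# the same global array on every call and the compiled transform is reused.
_MAX_SEGMENTS = 1 << 20
_offsets_buf = cp.zeros(_MAX_SEGMENTS + 1, dtype=cp.int64)

def load_offsets(offsets):
    # device-to-device copy of this array's offsets into the shared buffer
    _offsets_buf[: offsets.size] = offsets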

We have a better solution for this brewing in cuda.compute at the moment. I'll get back to you here when it's ready!

@maxymnaumchyk
Collaborator Author

maxymnaumchyk commented Feb 2, 2026

Thanks @shwina! That's good to know. Meanwhile, I'll try to figure out how to pass offsets to kernels directly, instead of calculating them from parents.
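
For context, a hedged sketch of what that parents-to-offsets calculation might look like (illustrative data; using cp.searchsorted here is an assumption, not necessarily what the kernels do):

import cupy as cp

# parents[i] is the segment index of element i; for the example array
# [[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]] it looks like this:
parents = cp.asarray([0, 1, 1, 2, 2, 3, 3, 3, 3, 5], dtype=cp.int64)
outlength = 6

# offsets[k] = index of the first element belonging to segment k or later
offsets = cp.searchsorted(parents, cp.arange(outlength + 1), side="left")
print(offsets)  # [ 0  1  3  5  9  9 10]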

@shwina
Contributor

shwina commented Feb 5, 2026

With the latest cuda-cccl (0.5.0) you should see the perf hiccup go away:

import awkward as ak
import cupy as cp
import timeit

awkward_array = ak.Array([[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]], backend = 'cuda')
_ = ak.argmax(awkward_array, axis=1)  # warmup
start_time = timeit.default_timer()
for i in range(10):
    expect = ak.argmax(awkward_array, axis=1)
    cp.cuda.Device().synchronize()
end_time = timeit.default_timer()
print(f"Time taken for ak.argmax: {(end_time - start_time) / 10} seconds")
Time taken for ak.argmax: 0.0011918096977751702 seconds

@maxymnaumchyk
Collaborator Author

Awesome! Do you have a planned release?

@shwina
Contributor

shwina commented Feb 5, 2026

awesome! do you have a planned release?

Yes, should be available on pip/conda now! Thanks!

@maxymnaumchyk
Collaborator Author

maxymnaumchyk commented Feb 6, 2026

Hello @shwina, there is currently a bug(?) in how some_unary_op is handled inside the cuda.compute.make_unary_transform(d_input, d_output, some_unary_op) algorithm. If I call different unary operations that have the same names and input types but different bodies, they still produce the same outputs. What I mean is, if I currently call

import awkward as ak

awkward_array = ak.Array([[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]], backend = 'cuda')
print(ak.argmax(awkward_array))
print(ak.argmin(awkward_array))

argmin will reuse the precompiled segment_reduce_op from argmax and return the same results. How could I fix this? One solution would be to give segment_reduce_op a different name for each awkward kernel, as in the sketch below.
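
A sketch of that naming idea (make_segment_op and its arguments are hypothetical):

def make_segment_op(kernel_name, reduce_body):
    # Wrap the per-kernel reduction in a function whose __name__ embeds the
    # kernel name, so a cache keyed on the op's name cannot mix up kernels.
    def segment_reduce_op(segment_id):
        return reduce_body(segment_id)
    segment_reduce_op.__name__ = f"segment_reduce_op_{kernel_name}"
    segment_reduce_op.__qualname__ = segment_reduce_op.__name__
    return segment_reduce_op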

@shwina
Contributor

shwina commented Feb 6, 2026

Thanks @maxymnaumchyk. Yes it's definitely a bug. I'm looking into it.

@maxymnaumchyk
Collaborator Author

Also, for some reason I can't get it to work as fast as you do. For example, running this script:

import awkward as ak
import cupy as cp
import timeit

awkward_array = ak.Array([[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]], backend = 'cuda')
_ = ak.argmax(awkward_array, axis=1)  # warmup
start_time = timeit.default_timer()
for i in range(10):
    expect = ak.argmax(awkward_array, axis=1)
    cp.cuda.Device().synchronize()
end_time = timeit.default_timer()
print(f"Time taken for ak.argmax: {(end_time - start_time) / 10} seconds")

Returns:

Time taken for ak.argmax: 0.014016678486950696 seconds

That's about 10 times as long as yours. Could this be dependent on my machine?

Here is the result from the Nsight Systems profiler if it helps:

import nvtx
import awkward as ak
import cupy as cp
import timeit


awkward_array = ak.Array([[1], [2, 3], [4, 5], [6, 7, 1, 8], [], [9]], backend = 'cuda')
_ = ak.argmax(awkward_array, axis=1)  # warmup
with nvtx.annotate("running argmax..."):
    ak.argmax(awkward_array, axis = 1)
[Nsight Systems profiler screenshot]

Does the cudaMemcpyAsync there mean that data is being copied between host and device?
