
Conversation

@shwina (Contributor) commented Dec 11, 2025

Description

This PR adds support for ops that can capture state. This can be used to implement algorithms with side effects:

from cuda.compute import select
import numpy as np
import cupy as cp
import numba.cuda as numba_cuda


# Create device state for counting rejected items
reject_counter = cp.zeros(1, dtype=np.int32)

# Define condition that references state as closure
def count_rejects(x):
    if x > 50:
        return True
    else:
        numba_cuda.atomic.add(reject_counter, 0, 1)
        return False

d_in = cp.arange(100, dtype=np.int32)
d_out = cp.empty_like(d_in)
d_num_selected = cp.empty(2, dtype=np.uint64)

select(
    d_in,
    d_out,
    d_num_selected,
    count_rejects,
    len(d_in),
)

print(reject_counter)

prints:

[51]
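That count is easy to sanity-check on the host: of the inputs 0 through 99, exactly the 51 values 0 through 50 fail the x > 50 predicate. A plain-NumPy restatement, no GPU required:

```python
import numpy as np

# the device example rejects every element failing x > 50;
# the same count can be verified on the host
d_in = np.arange(100, dtype=np.int32)
rejected = np.count_nonzero(d_in <= 50)
print(rejected)  # 51
```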

It works by extending numba-CUDA to recognize device-array-like objects and teaching it to capture them from the enclosing scope using lower_constant.

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

shwina requested a review from a team as a code owner December 11, 2025 21:08
github-project-automation bot moved this to Todo in CCCL Dec 11, 2025
cccl-authenticator-app bot moved this from Todo to In Review in CCCL Dec 11, 2025
shwina changed the title from "Add stateful ops redux" to "Add support for ops that capture device arrays from enclosing scope" Dec 11, 2025
@bdice (Contributor) left a comment


I think you'll want some expert review here - but I commend you on the progress so far! This is a super cool feature. Closures over device arrays in numba-cuda would be a huge win.

    return hash(self._identity)
except TypeError:
    # if we can't hash the identity, use its id
    return id(self._identity)
Contributor

I recommend hashing the resulting id value.

Suggested change
-    return id(self._identity)
+    return hash(id(self._identity))

Contributor

In CPython these two things give identical values, although the spec does not require them to be:

Python 3.13.3 | packaged by conda-forge | (main, Apr 14 2025, 20:44:03) [GCC 13.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 9.4.0 -- An enhanced Interactive Python. Type '?' for help.
Tip: You can use LaTeX or Unicode completion, `\alpha<tab>` will insert the α symbol.

In [1]: o = list()

In [2]: id(o), hash(id(o))
Out[2]: (138960802540160, 138960802540160)

In [3]: %timeit hash(id(o))
36.7 ns ± 0.17 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [4]: %timeit id(o)
21.7 ns ± 0.909 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)



# =============================================================================
# Step 5: Register lower_constant for our type
Contributor

These step numbers don't align with the 1, 2, 3 at the top of the file. Maybe these comments aren't essential?


def __init__(self, dtype, ndim, layout, readonly=False, aligned=True):
    type_name = "device_array"
    if readonly:
        type_name = "readonly " + type_name


For the parent class, "unaligned" also appears in the name if aligned is False. Is there a guarantee that things of DeviceArrayType are aligned?

Contributor Author

Hmm, maybe this is a question for @leofang ? Do CAI (and DLPack) make any guarantees about alignment?

Member

I don't think either one guarantees or constrains alignment.

Contributor Author

I've defaulted to aligned=False and added "unaligned" to the type name.

_original_generic = _original_typeof_impl.dispatch(object)


def _patched_generic(val, c):
@gmarkall commented Dec 12, 2025

Numba-CUDA already has a mechanism for typing CAI objects, but it lived in the CUDA dispatcher: https://github.com/NVIDIA/numba-cuda/blob/94321646a046ea5dd1b64e8dc078e27121552e23/numba_cuda/numba/cuda/dispatcher.py#L1626-L1643

This was because typeof came from upstream Numba, but I think we can push this check into typeof in Numba-CUDA now that it contains its own typeof implementation.

# =============================================================================


def _typeof_device_array(val, c):


When typing an argument, this will be called with c.purpose == Purpose.constant (from numba.cuda.typing.typeof.Purpose). Perhaps we don't want to affect typing of arguments - in which case we should return None when c.purpose == Purpose.argument.


Also - I haven't read through this implementation, but it's probably faster than the implementation in Dispatcher.typeof_pyval, so we should probably try upstreaming it into typeof_pyval to see if it makes kernel launch a bit faster for CAI arrays (cc @cpcloud)

Contributor Author

Fixed

Comment on lines 131 to 146
@register_default(DeviceArrayType)
class DeviceArrayModel(StructModel):
    """Data model for DeviceArrayType - same as regular Array."""

    def __init__(self, dmm, fe_type):
        ndim = fe_type.ndim
        members = [
            ("meminfo", types.MemInfoPointer(fe_type.dtype)),
            ("parent", types.pyobject),
            ("nitems", types.intp),
            ("itemsize", types.intp),
            ("data", types.CPointer(fe_type.dtype)),
            ("shape", types.UniTuple(types.intp, ndim)),
            ("strides", types.UniTuple(types.intp, ndim)),
        ]
        super().__init__(dmm, fe_type, members)


This looks the same as the ArrayModel, so I think we could just register it:

Suggested change
-@register_default(DeviceArrayType)
-class DeviceArrayModel(StructModel):
-    """Data model for DeviceArrayType - same as regular Array."""
-
-    def __init__(self, dmm, fe_type):
-        ndim = fe_type.ndim
-        members = [
-            ("meminfo", types.MemInfoPointer(fe_type.dtype)),
-            ("parent", types.pyobject),
-            ("nitems", types.intp),
-            ("itemsize", types.intp),
-            ("data", types.CPointer(fe_type.dtype)),
-            ("shape", types.UniTuple(types.intp, ndim)),
-            ("strides", types.UniTuple(types.intp, ndim)),
-        ]
-        super().__init__(dmm, fe_type, members)
+register_default(DeviceArrayType)(ArrayModel)

Contributor Author

Done

shwina requested a review from a team as a code owner December 12, 2025 13:57
shwina requested a review from NaderAlAwar December 12, 2025 13:57
@shwina (Contributor Author) commented Dec 12, 2025

/ok to test 12f1899

@shwina (Contributor Author) commented Dec 12, 2025

/ok to test 8d6faa0



interface = pyval.__cuda_array_interface__

# hold on to the device-array-like object
Contributor Author

@gmarkall - added your code snippet for keeping references to the arrays here.

@github-actions (bot)

😬 CI Workflow Results

🟥 Finished in 1h 06m: Pass: 50%/48 | Total: 8h 30m | Max: 42m 18s

See results here.

Comment on lines +82 to +102
itemsize = np.dtype(interface["typestr"]).itemsize
# Check C-contiguous
c_strides = []
stride = itemsize
for i in range(ndim - 1, -1, -1):
    c_strides.insert(0, stride)
    stride *= shape[i]

if tuple(strides) == tuple(c_strides):
    layout = "C"
else:
    # Check F-contiguous
    f_strides = []
    stride = itemsize
    for i in range(ndim):
        f_strides.append(stride)
        stride *= shape[i]
    if tuple(strides) == tuple(f_strides):
        layout = "F"
    else:
        layout = "A"
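For reference, the hunk's logic can be restated as a standalone, host-testable function (the name infer_layout is mine; the behavior mirrors the code above):

```python
import numpy as np


def infer_layout(shape, strides, itemsize):
    """Classify byte strides as C-contiguous ("C"), F-contiguous ("F"),
    or arbitrary ("A"). Standalone restatement of the hunk above."""
    ndim = len(shape)
    # build the strides a C-contiguous array of this shape would have
    c_strides, stride = [], itemsize
    for i in range(ndim - 1, -1, -1):
        c_strides.insert(0, stride)
        stride *= shape[i]
    if tuple(strides) == tuple(c_strides):
        return "C"
    # otherwise build and compare F-contiguous strides
    f_strides, stride = [], itemsize
    for i in range(ndim):
        f_strides.append(stride)
        stride *= shape[i]
    return "F" if tuple(strides) == tuple(f_strides) else "A"


a = np.zeros((3, 4), dtype=np.int32)
print(infer_layout(a.shape, a.strides, a.itemsize))      # C
print(infer_layout(a.T.shape, a.T.strides, a.itemsize))  # F
```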
Contributor

Perhaps one can do a quick rejection test to avoid reconstructing the entire c_strides or f_strides.
If neither the last nor the first element of the strides sequence equals 1, we can set layout = "A".
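The suggested fast path might look like this (hypothetical helper; note that CAI strides are expressed in bytes, so the cheap comparison is against the itemsize rather than 1 - a C-contiguous array's last stride, and an F-contiguous array's first stride, must equal the itemsize):

```python
def quick_reject(strides, itemsize):
    """Cheap pre-check before the full stride reconstruction: if neither
    end of the strides sequence equals the itemsize, the array can be
    neither C- nor F-contiguous. Returns "A" on rejection, None when a
    full check is still needed. Illustrative sketch (ignores size-0/1
    dims, where contiguity is looser)."""
    if strides and strides[0] != itemsize and strides[-1] != itemsize:
        return "A"
    return None


print(quick_reject((16, 4), 4))  # None - might be C-contiguous, check fully
print(quick_reject((8, 24), 4))  # A - can be neither
```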


Labels

None yet

Projects

Status: In Review

Development

Successfully merging this pull request may close these issues.

5 participants