Skip to content

Commit f940332

Browse files
Merge branch 'ccrouzet/gh-733-expand-int-indexing' into 'main'
Fix Indexing of `wp.atomic_*` With Unsigned Integers (GH-733) See merge request omniverse/warp!1315
2 parents 4cef86a + 07b18e5 commit f940332

5 files changed

Lines changed: 357 additions & 118 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
- Fix `UsdRenderer.render_points()` erroring out when passed 4 points or less
103103
([GH-708](https://github.com/NVIDIA/warp/issues/708)).
104104
- Fix garbage collection issues with JAX FFI callbacks ([GH-711](https://github.com/NVIDIA/warp/pull/711)).
105+
- Fix `wp.atomic_*()` built-ins not working with some types
106+
([GH-733](https://github.com/NVIDIA/warp/issues/733)).
105107

106108
## [1.7.1] - 2025-04-30
107109

warp/builtins.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5499,46 +5499,78 @@ def load_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
54995499
)
55005500

55015501

5502+
SUPPORTED_ATOMIC_TYPES = (
5503+
warp.int32,
5504+
warp.int64,
5505+
warp.uint32,
5506+
warp.uint64,
5507+
warp.float32,
5508+
warp.float64,
5509+
)
5510+
5511+
55025512
def atomic_op_constraint(arg_types: Mapping[str, Any]):
55035513
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
55045514
return all(types_equal(idx_types[0], t) for t in idx_types[1:]) and arg_types["arr"].ndim == len(idx_types)
55055515

55065516

5507-
def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5508-
if arg_types is None:
5509-
return Any
5517+
def create_atomic_op_value_func(op: str):
5518+
def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5519+
if arg_types is None:
5520+
return Any
55105521

5511-
arr_type = arg_types["arr"]
5512-
value_type = arg_types["value"]
5513-
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
5522+
arr_type = arg_types["arr"]
5523+
value_type = arg_types["value"]
5524+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
55145525

5515-
if not is_array(arr_type):
5516-
raise RuntimeError("atomic() first argument must be an array")
5526+
if not is_array(arr_type):
5527+
raise RuntimeError(f"atomic_{op}() first argument must be an array")
55175528

5518-
idx_count = len(idx_types)
5529+
idx_count = len(idx_types)
55195530

5520-
if idx_count < arr_type.ndim:
5521-
raise RuntimeError(
5522-
"Num indices < num dimensions for atomic, this is a codegen error, should have generated a view instead"
5523-
)
5531+
if idx_count < arr_type.ndim:
5532+
raise RuntimeError(
5533+
f"Num indices < num dimensions for atomic_{op}(), this is a codegen error, should have generated a view instead"
5534+
)
55245535

5525-
if idx_count > arr_type.ndim:
5526-
raise RuntimeError(
5527-
f"Num indices > num dimensions for atomic, received {idx_count}, but array only has {arr_type.ndim}"
5528-
)
5536+
if idx_count > arr_type.ndim:
5537+
raise RuntimeError(
5538+
f"Num indices > num dimensions for atomic_{op}(), received {idx_count}, but array only has {arr_type.ndim}"
5539+
)
55295540

5530-
# check index types
5531-
for t in idx_types:
5532-
if not type_is_int(t):
5533-
raise RuntimeError(f"atomic() index arguments must be of integer type, got index of type {type_repr(t)}")
5541+
# check index types
5542+
for t in idx_types:
5543+
if not type_is_int(t):
5544+
raise RuntimeError(
5545+
f"atomic_{op}() index arguments must be of integer type, got index of type {type_repr(t)}"
5546+
)
55345547

5535-
# check value type
5536-
if not types_equal(arr_type.dtype, value_type):
5537-
raise RuntimeError(
5538-
f"atomic() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
5539-
)
5548+
# check value type
5549+
if not types_equal(arr_type.dtype, value_type):
5550+
raise RuntimeError(
5551+
f"atomic_{op}() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
5552+
)
55405553

5541-
return arr_type.dtype
5554+
scalar_type = getattr(arr_type.dtype, "_wp_scalar_type_", arr_type.dtype)
5555+
if op in ("add", "sub"):
5556+
supported_atomic_types = (*SUPPORTED_ATOMIC_TYPES, warp.float16)
5557+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
5558+
raise RuntimeError(
5559+
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float16, float32, or float64 "
5560+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
5561+
)
5562+
elif op in ("min", "max"):
5563+
if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
5564+
raise RuntimeError(
5565+
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
5566+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
5567+
)
5568+
else:
5569+
raise NotImplementedError
5570+
5571+
return arr_type.dtype
5572+
5573+
return fn
55425574

55435575

55445576
def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -5563,7 +5595,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55635595
hidden=hidden,
55645596
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
55655597
constraint=atomic_op_constraint,
5566-
value_func=atomic_op_value_func,
5598+
value_func=create_atomic_op_value_func("add"),
55675599
dispatch_func=atomic_op_dispatch_func,
55685600
doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
55695601
group="Utility",
@@ -5574,7 +5606,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55745606
hidden=hidden,
55755607
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
55765608
constraint=atomic_op_constraint,
5577-
value_func=atomic_op_value_func,
5609+
value_func=create_atomic_op_value_func("add"),
55785610
dispatch_func=atomic_op_dispatch_func,
55795611
doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
55805612
group="Utility",
@@ -5585,7 +5617,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55855617
hidden=hidden,
55865618
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
55875619
constraint=atomic_op_constraint,
5588-
value_func=atomic_op_value_func,
5620+
value_func=create_atomic_op_value_func("add"),
55895621
dispatch_func=atomic_op_dispatch_func,
55905622
doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
55915623
group="Utility",
@@ -5596,7 +5628,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55965628
hidden=hidden,
55975629
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
55985630
constraint=atomic_op_constraint,
5599-
value_func=atomic_op_value_func,
5631+
value_func=create_atomic_op_value_func("add"),
56005632
dispatch_func=atomic_op_dispatch_func,
56015633
doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
56025634
group="Utility",
@@ -5608,7 +5640,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56085640
hidden=hidden,
56095641
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
56105642
constraint=atomic_op_constraint,
5611-
value_func=atomic_op_value_func,
5643+
value_func=create_atomic_op_value_func("sub"),
56125644
dispatch_func=atomic_op_dispatch_func,
56135645
doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
56145646
group="Utility",
@@ -5619,7 +5651,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56195651
hidden=hidden,
56205652
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
56215653
constraint=atomic_op_constraint,
5622-
value_func=atomic_op_value_func,
5654+
value_func=create_atomic_op_value_func("sub"),
56235655
dispatch_func=atomic_op_dispatch_func,
56245656
doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
56255657
group="Utility",
@@ -5630,7 +5662,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56305662
hidden=hidden,
56315663
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
56325664
constraint=atomic_op_constraint,
5633-
value_func=atomic_op_value_func,
5665+
value_func=create_atomic_op_value_func("sub"),
56345666
dispatch_func=atomic_op_dispatch_func,
56355667
doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
56365668
group="Utility",
@@ -5641,7 +5673,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56415673
hidden=hidden,
56425674
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
56435675
constraint=atomic_op_constraint,
5644-
value_func=atomic_op_value_func,
5676+
value_func=create_atomic_op_value_func("sub"),
56455677
dispatch_func=atomic_op_dispatch_func,
56465678
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
56475679
group="Utility",
@@ -5653,7 +5685,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56535685
hidden=hidden,
56545686
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
56555687
constraint=atomic_op_constraint,
5656-
value_func=atomic_op_value_func,
5688+
value_func=create_atomic_op_value_func("min"),
56575689
dispatch_func=atomic_op_dispatch_func,
56585690
doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
56595691
@@ -5666,7 +5698,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56665698
hidden=hidden,
56675699
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
56685700
constraint=atomic_op_constraint,
5669-
value_func=atomic_op_value_func,
5701+
value_func=create_atomic_op_value_func("min"),
56705702
dispatch_func=atomic_op_dispatch_func,
56715703
doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
56725704
@@ -5679,7 +5711,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56795711
hidden=hidden,
56805712
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
56815713
constraint=atomic_op_constraint,
5682-
value_func=atomic_op_value_func,
5714+
value_func=create_atomic_op_value_func("min"),
56835715
dispatch_func=atomic_op_dispatch_func,
56845716
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
56855717
@@ -5692,7 +5724,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56925724
hidden=hidden,
56935725
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
56945726
constraint=atomic_op_constraint,
5695-
value_func=atomic_op_value_func,
5727+
value_func=create_atomic_op_value_func("min"),
56965728
dispatch_func=atomic_op_dispatch_func,
56975729
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
56985730
@@ -5706,7 +5738,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57065738
hidden=hidden,
57075739
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
57085740
constraint=atomic_op_constraint,
5709-
value_func=atomic_op_value_func,
5741+
value_func=create_atomic_op_value_func("max"),
57105742
dispatch_func=atomic_op_dispatch_func,
57115743
doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
57125744
@@ -5719,7 +5751,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57195751
hidden=hidden,
57205752
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
57215753
constraint=atomic_op_constraint,
5722-
value_func=atomic_op_value_func,
5754+
value_func=create_atomic_op_value_func("max"),
57235755
dispatch_func=atomic_op_dispatch_func,
57245756
doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
57255757
@@ -5732,7 +5764,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57325764
hidden=hidden,
57335765
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
57345766
constraint=atomic_op_constraint,
5735-
value_func=atomic_op_value_func,
5767+
value_func=create_atomic_op_value_func("max"),
57365768
dispatch_func=atomic_op_dispatch_func,
57375769
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
57385770
@@ -5745,7 +5777,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57455777
hidden=hidden,
57465778
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
57475779
constraint=atomic_op_constraint,
5748-
value_func=atomic_op_value_func,
5780+
value_func=create_atomic_op_value_func("max"),
57495781
dispatch_func=atomic_op_dispatch_func,
57505782
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
57515783

0 commit comments

Comments
 (0)