@@ -5499,46 +5499,78 @@ def load_dispatch_func(input_types: Mapping[str, type], return_type: Any, args:
54995499)
55005500
55015501
5502+ SUPPORTED_ATOMIC_TYPES = (
5503+ warp .int32 ,
5504+ warp .int64 ,
5505+ warp .uint32 ,
5506+ warp .uint64 ,
5507+ warp .float32 ,
5508+ warp .float64 ,
5509+ )
5510+
5511+
55025512def atomic_op_constraint (arg_types : Mapping [str , Any ]):
55035513 idx_types = tuple (arg_types [x ] for x in "ijkl" if arg_types .get (x , None ) is not None )
55045514 return all (types_equal (idx_types [0 ], t ) for t in idx_types [1 :]) and arg_types ["arr" ].ndim == len (idx_types )
55055515
55065516
5507- def atomic_op_value_func (arg_types : Mapping [str , type ], arg_values : Mapping [str , Any ]):
5508- if arg_types is None :
5509- return Any
5517+ def create_atomic_op_value_func (op : str ):
5518+ def fn (arg_types : Mapping [str , type ], arg_values : Mapping [str , Any ]):
5519+ if arg_types is None :
5520+ return Any
55105521
5511- arr_type = arg_types ["arr" ]
5512- value_type = arg_types ["value" ]
5513- idx_types = tuple (arg_types [x ] for x in "ijkl" if arg_types .get (x , None ) is not None )
5522+ arr_type = arg_types ["arr" ]
5523+ value_type = arg_types ["value" ]
5524+ idx_types = tuple (arg_types [x ] for x in "ijkl" if arg_types .get (x , None ) is not None )
55145525
5515- if not is_array (arr_type ):
5516- raise RuntimeError ("atomic () first argument must be an array" )
5526+ if not is_array (arr_type ):
5527+ raise RuntimeError (f"atomic_ { op } () first argument must be an array" )
55175528
5518- idx_count = len (idx_types )
5529+ idx_count = len (idx_types )
55195530
5520- if idx_count < arr_type .ndim :
5521- raise RuntimeError (
5522- "Num indices < num dimensions for atomic , this is a codegen error, should have generated a view instead"
5523- )
5531+ if idx_count < arr_type .ndim :
5532+ raise RuntimeError (
5533+ f "Num indices < num dimensions for atomic_ { op } () , this is a codegen error, should have generated a view instead"
5534+ )
55245535
5525- if idx_count > arr_type .ndim :
5526- raise RuntimeError (
5527- f"Num indices > num dimensions for atomic , received { idx_count } , but array only has { arr_type .ndim } "
5528- )
5536+ if idx_count > arr_type .ndim :
5537+ raise RuntimeError (
5538+ f"Num indices > num dimensions for atomic_ { op } () , received { idx_count } , but array only has { arr_type .ndim } "
5539+ )
55295540
5530- # check index types
5531- for t in idx_types :
5532- if not type_is_int (t ):
5533- raise RuntimeError (f"atomic() index arguments must be of integer type, got index of type { type_repr (t )} " )
5541+ # check index types
5542+ for t in idx_types :
5543+ if not type_is_int (t ):
5544+ raise RuntimeError (
5545+ f"atomic_{ op } () index arguments must be of integer type, got index of type { type_repr (t )} "
5546+ )
55345547
5535- # check value type
5536- if not types_equal (arr_type .dtype , value_type ):
5537- raise RuntimeError (
5538- f"atomic () value argument type ({ type_repr (value_type )} ) must be of the same type as the array ({ type_repr (arr_type .dtype )} )"
5539- )
5548+ # check value type
5549+ if not types_equal (arr_type .dtype , value_type ):
5550+ raise RuntimeError (
5551+ f"atomic_ { op } () value argument type ({ type_repr (value_type )} ) must be of the same type as the array ({ type_repr (arr_type .dtype )} )"
5552+ )
55405553
5541- return arr_type .dtype
5554+ scalar_type = getattr (arr_type .dtype , "_wp_scalar_type_" , arr_type .dtype )
5555+ if op in ("add" , "sub" ):
5556+ supported_atomic_types = (* SUPPORTED_ATOMIC_TYPES , warp .float16 )
5557+ if not any (types_equal (scalar_type , x , match_generic = True ) for x in supported_atomic_types ):
5558+ raise RuntimeError (
5559+ f"atomic_{ op } () operations only work on arrays with [u]int32, [u]int64, float16, float32, or float64 "
5560+ f"as the underlying scalar types, but got { type_repr (arr_type .dtype )} (with scalar type { type_repr (scalar_type )} )"
5561+ )
5562+ elif op in ("min" , "max" ):
5563+ if not any (types_equal (scalar_type , x , match_generic = True ) for x in SUPPORTED_ATOMIC_TYPES ):
5564+ raise RuntimeError (
5565+ f"atomic_{ op } () operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
5566+ f"as the underlying scalar types, but got { type_repr (arr_type .dtype )} (with scalar type { type_repr (scalar_type )} )"
5567+ )
5568+ else :
5569+ raise NotImplementedError
5570+
5571+ return arr_type .dtype
5572+
5573+ return fn
55425574
55435575
55445576def atomic_op_dispatch_func (input_types : Mapping [str , type ], return_type : Any , args : Mapping [str , Var ]):
@@ -5563,7 +5595,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55635595 hidden = hidden ,
55645596 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "value" : Any },
55655597 constraint = atomic_op_constraint ,
5566- value_func = atomic_op_value_func ,
5598+ value_func = create_atomic_op_value_func ( "add" ) ,
55675599 dispatch_func = atomic_op_dispatch_func ,
55685600 doc = "Atomically add ``value`` onto ``arr[i]`` and return the old value." ,
55695601 group = "Utility" ,
@@ -5574,7 +5606,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55745606 hidden = hidden ,
55755607 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "value" : Any },
55765608 constraint = atomic_op_constraint ,
5577- value_func = atomic_op_value_func ,
5609+ value_func = create_atomic_op_value_func ( "add" ) ,
55785610 dispatch_func = atomic_op_dispatch_func ,
55795611 doc = "Atomically add ``value`` onto ``arr[i,j]`` and return the old value." ,
55805612 group = "Utility" ,
@@ -5585,7 +5617,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55855617 hidden = hidden ,
55865618 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "value" : Any },
55875619 constraint = atomic_op_constraint ,
5588- value_func = atomic_op_value_func ,
5620+ value_func = create_atomic_op_value_func ( "add" ) ,
55895621 dispatch_func = atomic_op_dispatch_func ,
55905622 doc = "Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value." ,
55915623 group = "Utility" ,
@@ -5596,7 +5628,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
55965628 hidden = hidden ,
55975629 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "l" : Int , "value" : Any },
55985630 constraint = atomic_op_constraint ,
5599- value_func = atomic_op_value_func ,
5631+ value_func = create_atomic_op_value_func ( "add" ) ,
56005632 dispatch_func = atomic_op_dispatch_func ,
56015633 doc = "Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value." ,
56025634 group = "Utility" ,
@@ -5608,7 +5640,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56085640 hidden = hidden ,
56095641 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "value" : Any },
56105642 constraint = atomic_op_constraint ,
5611- value_func = atomic_op_value_func ,
5643+ value_func = create_atomic_op_value_func ( "sub" ) ,
56125644 dispatch_func = atomic_op_dispatch_func ,
56135645 doc = "Atomically subtract ``value`` onto ``arr[i]`` and return the old value." ,
56145646 group = "Utility" ,
@@ -5619,7 +5651,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56195651 hidden = hidden ,
56205652 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "value" : Any },
56215653 constraint = atomic_op_constraint ,
5622- value_func = atomic_op_value_func ,
5654+ value_func = create_atomic_op_value_func ( "sub" ) ,
56235655 dispatch_func = atomic_op_dispatch_func ,
56245656 doc = "Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value." ,
56255657 group = "Utility" ,
@@ -5630,7 +5662,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56305662 hidden = hidden ,
56315663 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "value" : Any },
56325664 constraint = atomic_op_constraint ,
5633- value_func = atomic_op_value_func ,
5665+ value_func = create_atomic_op_value_func ( "sub" ) ,
56345666 dispatch_func = atomic_op_dispatch_func ,
56355667 doc = "Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value." ,
56365668 group = "Utility" ,
@@ -5641,7 +5673,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56415673 hidden = hidden ,
56425674 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "l" : Int , "value" : Any },
56435675 constraint = atomic_op_constraint ,
5644- value_func = atomic_op_value_func ,
5676+ value_func = create_atomic_op_value_func ( "sub" ) ,
56455677 dispatch_func = atomic_op_dispatch_func ,
56465678 doc = "Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value." ,
56475679 group = "Utility" ,
@@ -5653,7 +5685,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56535685 hidden = hidden ,
56545686 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "value" : Any },
56555687 constraint = atomic_op_constraint ,
5656- value_func = atomic_op_value_func ,
5688+ value_func = create_atomic_op_value_func ( "min" ) ,
56575689 dispatch_func = atomic_op_dispatch_func ,
56585690 doc = """Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
56595691
@@ -5666,7 +5698,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56665698 hidden = hidden ,
56675699 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "value" : Any },
56685700 constraint = atomic_op_constraint ,
5669- value_func = atomic_op_value_func ,
5701+ value_func = create_atomic_op_value_func ( "min" ) ,
56705702 dispatch_func = atomic_op_dispatch_func ,
56715703 doc = """Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
56725704
@@ -5679,7 +5711,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56795711 hidden = hidden ,
56805712 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "value" : Any },
56815713 constraint = atomic_op_constraint ,
5682- value_func = atomic_op_value_func ,
5714+ value_func = create_atomic_op_value_func ( "min" ) ,
56835715 dispatch_func = atomic_op_dispatch_func ,
56845716 doc = """Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
56855717
@@ -5692,7 +5724,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
56925724 hidden = hidden ,
56935725 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "l" : Int , "value" : Any },
56945726 constraint = atomic_op_constraint ,
5695- value_func = atomic_op_value_func ,
5727+ value_func = create_atomic_op_value_func ( "min" ) ,
56965728 dispatch_func = atomic_op_dispatch_func ,
56975729 doc = """Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
56985730
@@ -5706,7 +5738,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57065738 hidden = hidden ,
57075739 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "value" : Any },
57085740 constraint = atomic_op_constraint ,
5709- value_func = atomic_op_value_func ,
5741+ value_func = create_atomic_op_value_func ( "max" ) ,
57105742 dispatch_func = atomic_op_dispatch_func ,
57115743 doc = """Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
57125744
@@ -5719,7 +5751,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57195751 hidden = hidden ,
57205752 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "value" : Any },
57215753 constraint = atomic_op_constraint ,
5722- value_func = atomic_op_value_func ,
5754+ value_func = create_atomic_op_value_func ( "max" ) ,
57235755 dispatch_func = atomic_op_dispatch_func ,
57245756 doc = """Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
57255757
@@ -5732,7 +5764,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57325764 hidden = hidden ,
57335765 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "value" : Any },
57345766 constraint = atomic_op_constraint ,
5735- value_func = atomic_op_value_func ,
5767+ value_func = create_atomic_op_value_func ( "max" ) ,
57365768 dispatch_func = atomic_op_dispatch_func ,
57375769 doc = """Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
57385770
@@ -5745,7 +5777,7 @@ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, a
57455777 hidden = hidden ,
57465778 input_types = {"arr" : array_type (dtype = Any ), "i" : Int , "j" : Int , "k" : Int , "l" : Int , "value" : Any },
57475779 constraint = atomic_op_constraint ,
5748- value_func = atomic_op_value_func ,
5780+ value_func = create_atomic_op_value_func ( "max" ) ,
57495781 dispatch_func = atomic_op_dispatch_func ,
57505782 doc = """Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
57515783
0 commit comments