Skip to content

Commit 0624001

Browse files
authored
Expose is_supported_dtype to the public interface (#150)
Also take this opportunity to clean up a naming inconsistency; NumPy types are "dtypes", core types are "types".
1 parent 6dd4320 commit 0624001

File tree

9 files changed

+33
-19
lines changed

9 files changed

+33
-19
lines changed

Diff for: cunumeric/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ._array.util import maybe_convert_to_np_ndarray
3232
from ._module import *
3333
from ._ufunc import *
34+
from ._utils.array import is_supported_dtype
3435
from ._utils.coverage import clone_module
3536

3637
clone_module(_np, globals(), maybe_convert_to_np_ndarray)

Diff for: cunumeric/_array/array.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131

3232
from .. import _ufunc
33-
from .._utils.array import calculate_volume, to_core_dtype
33+
from .._utils.array import calculate_volume, to_core_type
3434
from .._utils.coverage import FALLBACK_WARNING, clone_class, is_implemented
3535
from .._utils.linalg import dot_modes
3636
from .._utils.structure import deep_apply
@@ -128,7 +128,7 @@ def __init__(
128128
for inp in inputs
129129
if isinstance(inp, ndarray)
130130
]
131-
core_dtype = to_core_dtype(dtype)
131+
core_dtype = to_core_type(dtype)
132132
self._thunk = runtime.create_empty_thunk(
133133
sanitized_shape, core_dtype, inputs
134134
)
@@ -660,7 +660,7 @@ def __contains__(self, item: Any) -> ndarray:
660660
args = (np.array(item, dtype=self.dtype),)
661661
if args[0].size != 1:
662662
raise ValueError("contains needs scalar item")
663-
core_dtype = to_core_dtype(self.dtype)
663+
core_dtype = to_core_type(self.dtype)
664664
return perform_unary_reduction(
665665
UnaryRedCode.CONTAINS,
666666
self,
@@ -1975,7 +1975,7 @@ def clip(
19751975
return convert_to_cunumeric_ndarray(
19761976
self.__array__().clip(args[0], args[1])
19771977
)
1978-
core_dtype = to_core_dtype(self.dtype)
1978+
core_dtype = to_core_type(self.dtype)
19791979
extra_args = (Scalar(min, core_dtype), Scalar(max, core_dtype))
19801980
return perform_unary_op(
19811981
UnaryOpCode.CLIP, self, out=out, extra_args=extra_args
@@ -2971,7 +2971,7 @@ def var(
29712971
# FIXME(wonchanl): the following code blocks on mu to convert
29722972
# it to a Scalar object. We need to get rid of this blocking by
29732973
# allowing the extra arguments to be Legate stores
2974-
args=(Scalar(mu.__array__(), to_core_dtype(self.dtype)),),
2974+
args=(Scalar(mu.__array__(), to_core_type(self.dtype)),),
29752975
)
29762976
else:
29772977
# TODO(https://github.com/nv-legate/cunumeric/issues/591)

Diff for: cunumeric/_thunk/deferred.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
normalize_axis_tuple,
5050
)
5151

52-
from .._utils.array import is_advanced_indexing, to_core_dtype
52+
from .._utils.array import is_advanced_indexing, to_core_type
5353
from ..config import (
5454
BinaryOpCode,
5555
BitGeneratorDistribution,
@@ -1701,7 +1701,7 @@ def select(
17011701
c_arr = c._broadcast(self.shape)
17021702
task.add_input(c_arr)
17031703
task.add_alignment(c_arr, out_arr)
1704-
task.add_scalar_arg(default, to_core_dtype(default.dtype))
1704+
task.add_scalar_arg(default, to_core_type(default.dtype))
17051705
task.execute()
17061706

17071707
# Create or extract a diagonal from a matrix

Diff for: cunumeric/_utils/array.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,24 @@
4040
}
4141

4242

43-
def is_supported_type(dtype: str | np.dtype[Any]) -> bool:
43+
def is_supported_dtype(dtype: str | np.dtype[Any]) -> bool:
44+
"""
45+
Whether a NumPy dtype is supported by cuNumeric
46+
47+
Parameters
48+
----------
49+
dtype : data-type
50+
The dtype to query
51+
52+
Returns
53+
-------
54+
res : bool
55+
True if `dtype` is a supported dtype
56+
"""
4457
return np.dtype(dtype) in SUPPORTED_DTYPES
4558

4659

47-
def to_core_dtype(dtype: str | np.dtype[Any]) -> ty.Type:
60+
def to_core_type(dtype: str | np.dtype[Any]) -> ty.Type:
4861
core_dtype = SUPPORTED_DTYPES.get(np.dtype(dtype))
4962
if core_dtype is None:
5063
raise TypeError(f"cuNumeric does not support dtype={dtype}")

Diff for: cunumeric/runtime.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from legate.core import LEGATE_MAX_DIM, Scalar, TaskTarget, get_legate_runtime
2424
from legate.settings import settings as legate_settings
2525

26-
from ._utils.array import calculate_volume, is_supported_type, to_core_dtype
26+
from ._utils.array import calculate_volume, is_supported_dtype, to_core_type
2727
from ._utils.stack import find_last_user_stacklevel
2828
from .config import (
2929
BitGeneratorOperation,
@@ -60,7 +60,7 @@ def thunk_from_scalar(
6060
from ._thunk.deferred import DeferredArray
6161

6262
store = legate_runtime.create_store_from_scalar(
63-
Scalar(bytes, to_core_dtype(dtype)),
63+
Scalar(bytes, to_core_type(dtype)),
6464
shape=shape,
6565
)
6666
return DeferredArray(store)
@@ -377,7 +377,7 @@ def find_or_create_array_thunk(
377377
from ._thunk.deferred import DeferredArray
378378

379379
assert isinstance(array, np.ndarray)
380-
if not is_supported_type(array.dtype):
380+
if not is_supported_dtype(array.dtype):
381381
raise TypeError(f"cuNumeric does not support dtype={array.dtype}")
382382

383383
# We have to be really careful here to handle the case of
@@ -429,7 +429,7 @@ def find_or_create_array_thunk(
429429
# This is not a scalar so make a field.
430430
# We won't try to cache these bigger arrays.
431431
store = legate_runtime.create_store_from_buffer(
432-
to_core_dtype(array.dtype),
432+
to_core_type(array.dtype),
433433
array.shape,
434434
array.copy() if transfer == TransferType.MAKE_COPY else array,
435435
# This argument should really be called "donate"

Diff for: tests/integration/test_argsort.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_structured_array_order(self):
9898
# if self.deferred is None:
9999
# if self.parent is None:
100100
#
101-
# > assert self.runtime.is_supported_type(self.array.dtype)
101+
# > assert self.runtime.is_supported_dtype(self.array.dtype)
102102
# E
103103
# AssertionError
104104
#

Diff for: tests/integration/test_prod.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def test_dtype_complex(self, dtype):
222222
# allclose hits assertion error:
223223
# File "/legate/cunumeric/cunumeric/eager.py", line 293,
224224
# in to_deferred_array
225-
# assert self.runtime.is_supported_type(self.array.dtype)
225+
# assert self.runtime.is_supported_dtype(self.array.dtype)
226226
# AssertionError
227227
assert allclose(out_np, out_num)
228228

Diff for: tests/integration/test_searchsorted.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_val_none(self):
8383
# cuNumeric raises AssertionError
8484
# if self.deferred is None:
8585
# if self.parent is None:
86-
# > assert self.runtime.is_supported_type
86+
# > assert self.runtime.is_supported_dtype
8787
# (self.array.dtype)
8888
# E AssertionError
8989
# cunumeric/cunumeric/eager.py:to_deferred_array()

Diff for: tests/unit/cunumeric/test_utils_array.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ class Test_is_supported_dtype:
7575
@pytest.mark.parametrize("value", ["foo", 10, 10.2, (), set()])
7676
def test_type_bad(self, value) -> None:
7777
with pytest.raises(TypeError):
78-
m.to_core_dtype(value)
78+
m.to_core_type(value)
7979

8080
@pytest.mark.parametrize("value", EXPECTED_SUPPORTED_DTYPES)
8181
def test_supported(self, value) -> None:
82-
m.to_core_dtype(value)
82+
m.to_core_type(value)
8383

8484
# This is just a representative sample, not exhasutive
8585
@pytest.mark.parametrize("value", [np.float128, np.datetime64, [], {}])
8686
def test_unsupported(self, value) -> None:
8787
with pytest.raises(TypeError):
88-
m.to_core_dtype(value)
88+
m.to_core_type(value)
8989

9090

9191
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)