Skip to content

Commit

Permalink
Add new 2023.12 elemwise functions: clip, copysign, hypot, `max…
Browse files Browse the repository at this point in the history
…imum`, `minimum`, `signbit`. (#583)
  • Loading branch information
tomwhite authored Sep 25, 2024
1 parent 73bbf5c commit b4e94b0
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 1 deletion.
2 changes: 1 addition & 1 deletion api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| Data Types | `bool`, `int8`, ... | :white_check_mark: | | |
| Elementwise Functions | `add` | :white_check_mark: | | Example of a binary function |
| | `negative` | :white_check_mark: | | Example of a unary function |
| | _others_ | :white_check_mark: | | Except 2023.12 functions in [#438](https://github.com/cubed-dev/cubed/issues/438) |
| | _others_ | :white_check_mark: | | |
| Indexing | Single-axis | :white_check_mark: | | |
| | Multi-axis | :white_check_mark: | | |
| | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) |
Expand Down
12 changes: 12 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@
bitwise_right_shift,
bitwise_xor,
ceil,
clip,
conj,
copysign,
cos,
cosh,
divide,
Expand All @@ -164,6 +166,7 @@
floor_divide,
greater,
greater_equal,
hypot,
imag,
isfinite,
isinf,
Expand All @@ -179,6 +182,8 @@
logical_not,
logical_or,
logical_xor,
maximum,
minimum,
multiply,
negative,
not_equal,
Expand All @@ -188,6 +193,7 @@
remainder,
round,
sign,
signbit,
sin,
sinh,
sqrt,
Expand Down Expand Up @@ -215,7 +221,9 @@
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"clip",
"conj",
"copysign",
"cos",
"cosh",
"divide",
Expand All @@ -226,6 +234,7 @@
"floor_divide",
"greater",
"greater_equal",
"hypot",
"imag",
"isfinite",
"isinf",
Expand All @@ -241,6 +250,8 @@
"logical_not",
"logical_or",
"logical_xor",
"maximum",
"minimum",
"multiply",
"negative",
"not_equal",
Expand All @@ -250,6 +261,7 @@
"remainder",
"round",
"sign",
"signbit",
"sin",
"sinh",
"sqrt",
Expand Down
12 changes: 12 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@
bitwise_right_shift,
bitwise_xor,
ceil,
clip,
conj,
copysign,
cos,
cosh,
divide,
Expand All @@ -112,6 +114,7 @@
floor_divide,
greater,
greater_equal,
hypot,
imag,
isfinite,
isinf,
Expand All @@ -127,6 +130,8 @@
logical_not,
logical_or,
logical_xor,
maximum,
minimum,
multiply,
negative,
not_equal,
Expand All @@ -136,6 +141,7 @@
remainder,
round,
sign,
signbit,
sin,
sinh,
sqrt,
Expand Down Expand Up @@ -163,7 +169,9 @@
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"clip",
"conj",
"copysign",
"cos",
"cosh",
"divide",
Expand All @@ -174,6 +182,7 @@
"floor_divide",
"greater",
"greater_equal",
"hypot",
"imag",
"isfinite",
"isinf",
Expand All @@ -189,6 +198,8 @@
"logical_not",
"logical_or",
"logical_xor",
"maximum",
"minimum",
"multiply",
"negative",
"not_equal",
Expand All @@ -198,6 +209,7 @@
"remainder",
"round",
"sign",
"signbit",
"sin",
"sinh",
"sqrt",
Expand Down
64 changes: 64 additions & 0 deletions cubed/array_api/elementwise_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from cubed.array_api.array_object import Array
from cubed.array_api.creation_functions import asarray
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import (
_boolean_dtypes,
Expand Down Expand Up @@ -131,12 +133,50 @@ def ceil(x, /):
return elemwise(nxp.ceil, x, dtype=x.dtype)


def clip(x, /, min=None, max=None):
if (
x.dtype not in _real_numeric_dtypes
or isinstance(min, Array)
and min.dtype not in _real_numeric_dtypes
or isinstance(max, Array)
and max.dtype not in _real_numeric_dtypes
):
raise TypeError("Only real numeric dtypes are allowed in clip")
if not isinstance(min, (int, float, Array, type(None))):
raise TypeError("min must be an None, int, float, or an array")
if not isinstance(max, (int, float, Array, type(None))):
raise TypeError("max must be an None, int, float, or an array")

if min is max is None:
return x
elif min is not None and max is None:
min = asarray(min, spec=x.spec)
return elemwise(nxp.clip, x, min, dtype=x.dtype)
elif min is None and max is not None:

def clip_max(x_, max_):
return nxp.clip(x_, max=max_)

max = asarray(max, spec=x.spec)
return elemwise(clip_max, x, max, dtype=x.dtype)
else: # min is not None and max is not None
min = asarray(min, spec=x.spec)
max = asarray(max, spec=x.spec)
return elemwise(nxp.clip, x, min, max, dtype=x.dtype)


def conj(x, /):
if x.dtype not in _complex_floating_dtypes:
raise TypeError("Only complex floating-point dtypes are allowed in conj")
return elemwise(nxp.conj, x, dtype=x.dtype)


def copysign(x1, x2, /):
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in copysign")
return elemwise(nxp.copysign, x1, x2, dtype=result_type(x1, x2))


def cos(x, /):
if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in cos")
Expand Down Expand Up @@ -194,6 +234,12 @@ def greater_equal(x1, x2, /):
return elemwise(nxp.greater_equal, x1, x2, dtype=nxp.bool)


def hypot(x1, x2, /):
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in hypot")
return elemwise(nxp.hypot, x1, x2, dtype=result_type(x1, x2))


def imag(x, /):
if x.dtype == complex64:
dtype = float32
Expand Down Expand Up @@ -284,6 +330,18 @@ def logical_xor(x1, x2, /):
return elemwise(nxp.logical_xor, x1, x2, dtype=nxp.bool)


def maximum(x1, x2, /):
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in maximum")
return elemwise(nxp.maximum, x1, x2, dtype=result_type(x1, x2))


def minimum(x1, x2, /):
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in minimum")
return elemwise(nxp.minimum, x1, x2, dtype=result_type(x1, x2))


def multiply(x1, x2, /):
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in multiply")
Expand Down Expand Up @@ -340,6 +398,12 @@ def sign(x, /):
return elemwise(nxp.sign, x, dtype=x.dtype)


def signbit(x, /):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in signbit")
return elemwise(nxp.signbit, x, dtype=nxp.bool)


def sin(x, /):
if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in sin")
Expand Down
20 changes: 20 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,26 @@ def test_add_different_chunks_fail(spec, executor):
assert_array_equal(c.compute(executor=executor), np.ones((10,)) + np.ones((10,)))


@pytest.mark.parametrize(
"min, max",
[
(None, None),
(4, None),
(None, 7),
(4, 7),
(0, 10),
],
)
def test_clip(spec, min, max):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
npa = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = xp.clip(a, min, max)
if min is max is None:
assert b is a
else:
assert_array_equal(b.compute(), np.clip(npa, min, max))


def test_equal(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
Expand Down

0 comments on commit b4e94b0

Please sign in to comment.