From b4e94b04e38d55414fcc9af2b337aee418f93f04 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 25 Sep 2024 12:22:18 +0100 Subject: [PATCH] Add new 2023.12 elemwise functions: `clip`, `copysign`, `hypot`, `maximum`, `minimum`, `signbit`. (#583) --- api_status.md | 2 +- cubed/__init__.py | 12 +++++ cubed/array_api/__init__.py | 12 +++++ cubed/array_api/elementwise_functions.py | 64 ++++++++++++++++++++++++ cubed/tests/test_array_api.py | 20 ++++++++ 5 files changed, 109 insertions(+), 1 deletion(-) diff --git a/api_status.md b/api_status.md index de6551ba..2955ea0e 100644 --- a/api_status.md +++ b/api_status.md @@ -41,7 +41,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | Data Types | `bool`, `int8`, ... | :white_check_mark: | | | | Elementwise Functions | `add` | :white_check_mark: | | Example of a binary function | | | `negative` | :white_check_mark: | | Example of a unary function | -| | _others_ | :white_check_mark: | | Except 2023.12 functions in [#438](https://github.com/cubed-dev/cubed/issues/438) | +| | _others_ | :white_check_mark: | | | | Indexing | Single-axis | :white_check_mark: | | | | | Multi-axis | :white_check_mark: | | | | | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) | diff --git a/cubed/__init__.py b/cubed/__init__.py index 790d54d4..3cd2e6b4 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -153,7 +153,9 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, + copysign, cos, cosh, divide, @@ -164,6 +166,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -179,6 +182,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -188,6 +193,7 @@ remainder, round, sign, + signbit, sin, sinh, sqrt, @@ -215,7 +221,9 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "conj", + "copysign", "cos", "cosh", "divide", @@ -226,6 +234,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", @@ -241,6 +250,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", @@ -250,6 +261,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "sqrt", diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index 993c9bf3..ea0a8c2b 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -101,7 +101,9 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, + copysign, cos, cosh, divide, @@ -112,6 +114,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -127,6 +130,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -136,6 +141,7 @@ remainder, round, sign, + signbit, sin, sinh, sqrt, @@ -163,7 +169,9 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "conj", + "copysign", "cos", "cosh", "divide", @@ -174,6 +182,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", @@ -189,6 +198,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", @@ -198,6 +209,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "sqrt", diff --git a/cubed/array_api/elementwise_functions.py b/cubed/array_api/elementwise_functions.py index 7d8f0086..217a6788 100644 --- a/cubed/array_api/elementwise_functions.py +++ b/cubed/array_api/elementwise_functions.py @@ -1,3 +1,5 @@ +from cubed.array_api.array_object import Array +from cubed.array_api.creation_functions import asarray from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import ( _boolean_dtypes, @@ -131,12 +133,50 @@ def ceil(x, /): return elemwise(nxp.ceil, x, dtype=x.dtype) +def clip(x, /, min=None, max=None): + if ( + x.dtype not in _real_numeric_dtypes + or isinstance(min, Array) + and min.dtype not in _real_numeric_dtypes + or isinstance(max, Array) + and max.dtype not in _real_numeric_dtypes + ): + raise TypeError("Only real numeric dtypes are allowed in clip") + if not isinstance(min, (int, float, Array, type(None))): + raise TypeError("min must be an None, int, float, or an array") + if not isinstance(max, (int, float, Array, type(None))): + raise TypeError("max must be an None, int, float, or an array") + + if min is max is None: + return x + elif min is not None and max is None: + min = asarray(min, spec=x.spec) + return elemwise(nxp.clip, x, min, dtype=x.dtype) + elif min is None and max is not None: + + def clip_max(x_, max_): + return nxp.clip(x_, max=max_) + + max = asarray(max, spec=x.spec) + return elemwise(clip_max, x, max, dtype=x.dtype) + else: # min is not None and max is not None + min = asarray(min, spec=x.spec) + max = asarray(max, spec=x.spec) + return elemwise(nxp.clip, x, min, max, dtype=x.dtype) + + def conj(x, /): if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in conj") return elemwise(nxp.conj, x, dtype=x.dtype) +def copysign(x1, x2, /): + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in copysign") + return elemwise(nxp.copysign, x1, x2, dtype=result_type(x1, x2)) + + def cos(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cos") @@ -194,6 +234,12 @@ def greater_equal(x1, x2, /): return elemwise(nxp.greater_equal, x1, x2, dtype=nxp.bool) +def hypot(x1, x2, /): + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in hypot") + return elemwise(nxp.hypot, x1, x2, dtype=result_type(x1, x2)) + + def imag(x, /): if x.dtype == complex64: dtype = float32 @@ -284,6 +330,18 @@ def logical_xor(x1, x2, /): return elemwise(nxp.logical_xor, x1, x2, dtype=nxp.bool) +def maximum(x1, x2, /): + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in maximum") + return elemwise(nxp.maximum, x1, x2, dtype=result_type(x1, x2)) + + +def minimum(x1, x2, /): + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in minimum") + return elemwise(nxp.minimum, x1, x2, dtype=result_type(x1, x2)) + + def multiply(x1, x2, /): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in multiply") @@ -340,6 +398,12 @@ def sign(x, /): return elemwise(nxp.sign, x, dtype=x.dtype) +def signbit(x, /): + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in signbit") + return elemwise(nxp.signbit, x, dtype=nxp.bool) + + def sin(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sin") diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 21c10d4a..b7764caa 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -194,6 +194,26 @@ def test_add_different_chunks_fail(spec, executor): assert_array_equal(c.compute(executor=executor), np.ones((10,)) + np.ones((10,))) +@pytest.mark.parametrize( + "min, max", + [ + (None, None), + (4, None), + (None, 7), + (4, 7), + (0, 10), + ], +) +def test_clip(spec, min, max): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + npa = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + b = xp.clip(a, min, max) + if min is max is None: + assert b is a + else: + assert_array_equal(b.compute(), np.clip(npa, min, max)) + + def test_equal(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) b = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)