Skip to content

Commit 6b8172e

Browse files
authored
fix: address missing complex type in applicable function signatures
PR-URL: #905 Ref: #862
1 parent 8b0852b commit 6b8172e

File tree

3 files changed

+41
-35
lines changed

3 files changed

+41
-35
lines changed

src/array_api_stubs/_2022_12/array_object.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,15 @@ def __abs__(self: array, /) -> array:
146146
Added complex data type support.
147147
"""
148148

149-
def __add__(self: array, other: Union[int, float, array], /) -> array:
149+
def __add__(self: array, other: Union[int, float, complex, array], /) -> array:
150150
"""
151151
Calculates the sum for each element of an array instance with the respective element of the array ``other``.
152152
153153
Parameters
154154
----------
155155
self: array
156156
array instance (augend array). Should have a numeric data type.
157-
other: Union[int, float, array]
157+
other: Union[int, float, complex, array]
158158
addend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
159159
160160
Returns
@@ -374,15 +374,15 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
374374
ROCM = 10
375375
"""
376376

377-
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
377+
def __eq__(self: array, other: Union[int, float, complex, bool, array], /) -> array:
378378
r"""
379379
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
380380
381381
Parameters
382382
----------
383383
self: array
384384
array instance. May have any data type.
385-
other: Union[int, float, bool, array]
385+
other: Union[int, float, complex, bool, array]
386386
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type.
387387
388388
Returns
@@ -393,6 +393,9 @@ def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
393393
394394
.. note::
395395
Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`.
396+
397+
.. versionchanged:: 2022.12
398+
Added complex data type support.
396399
"""
397400

398401
def __float__(self: array, /) -> float:
@@ -746,7 +749,7 @@ def __mod__(self: array, other: Union[int, float, array], /) -> array:
746749
Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.remainder`.
747750
"""
748751

749-
def __mul__(self: array, other: Union[int, float, array], /) -> array:
752+
def __mul__(self: array, other: Union[int, float, complex, array], /) -> array:
750753
r"""
751754
Calculates the product for each element of an array instance with the respective element of the array ``other``.
752755
@@ -757,7 +760,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
757760
----------
758761
self: array
759762
array instance. Should have a numeric data type.
760-
other: Union[int, float, array]
763+
other: Union[int, float, complex, array]
761764
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
762765
763766
Returns
@@ -775,15 +778,15 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
775778
Added complex data type support.
776779
"""
777780

778-
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
781+
def __ne__(self: array, other: Union[int, float, complex, bool, array], /) -> array:
779782
"""
780783
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
781784
782785
Parameters
783786
----------
784787
self: array
785788
array instance. May have any data type.
786-
other: Union[int, float, bool, array]
789+
other: Union[int, float, complex, bool, array]
787790
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type.
788791
789792
Returns
@@ -852,9 +855,6 @@ def __or__(self: array, other: Union[int, bool, array], /) -> array:
852855
853856
.. note::
854857
Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_or`.
855-
856-
.. versionchanged:: 2022.12
857-
Added complex data type support.
858858
"""
859859

860860
def __pos__(self: array, /) -> array:
@@ -876,7 +876,7 @@ def __pos__(self: array, /) -> array:
876876
Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.positive`.
877877
"""
878878

879-
def __pow__(self: array, other: Union[int, float, array], /) -> array:
879+
def __pow__(self: array, other: Union[int, float, complex, array], /) -> array:
880880
r"""
881881
Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``.
882882
@@ -889,7 +889,7 @@ def __pow__(self: array, other: Union[int, float, array], /) -> array:
889889
----------
890890
self: array
891891
array instance whose elements correspond to the exponentiation base. Should have a numeric data type.
892-
other: Union[int, float, array]
892+
other: Union[int, float, complex, array]
893893
other array whose elements correspond to the exponentiation exponent. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
894894
895895
Returns
@@ -933,7 +933,7 @@ def __setitem__(
933933
key: Union[
934934
int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array
935935
],
936-
value: Union[int, float, bool, array],
936+
value: Union[int, float, complex, bool, array],
937937
/,
938938
) -> None:
939939
"""
@@ -947,7 +947,7 @@ def __setitem__(
947947
array instance.
948948
key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]
949949
index key.
950-
value: Union[int, float, bool, array]
950+
value: Union[int, float, complex, bool, array]
951951
value(s) to set. Must be compatible with ``self[key]`` (see :ref:`broadcasting`).
952952
953953
@@ -960,7 +960,7 @@ def __setitem__(
960960
When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined.
961961
"""
962962

963-
def __sub__(self: array, other: Union[int, float, array], /) -> array:
963+
def __sub__(self: array, other: Union[int, float, complex, array], /) -> array:
964964
"""
965965
Calculates the difference for each element of an array instance with the respective element of the array ``other``.
966966
@@ -970,7 +970,7 @@ def __sub__(self: array, other: Union[int, float, array], /) -> array:
970970
----------
971971
self: array
972972
array instance (minuend array). Should have a numeric data type.
973-
other: Union[int, float, array]
973+
other: Union[int, float, complex, array]
974974
subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
975975
976976
Returns
@@ -988,7 +988,7 @@ def __sub__(self: array, other: Union[int, float, array], /) -> array:
988988
Added complex data type support.
989989
"""
990990

991-
def __truediv__(self: array, other: Union[int, float, array], /) -> array:
991+
def __truediv__(self: array, other: Union[int, float, complex, array], /) -> array:
992992
r"""
993993
Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
994994
@@ -1001,7 +1001,7 @@ def __truediv__(self: array, other: Union[int, float, array], /) -> array:
10011001
----------
10021002
self: array
10031003
array instance. Should have a numeric data type.
1004-
other: Union[int, float, array]
1004+
other: Union[int, float, complex, array]
10051005
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
10061006
10071007
Returns

src/array_api_stubs/_2023_12/array_object.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,15 @@ def __abs__(self: array, /) -> array:
148148
Added complex data type support.
149149
"""
150150

151-
def __add__(self: array, other: Union[int, float, array], /) -> array:
151+
def __add__(self: array, other: Union[int, float, complex, array], /) -> array:
152152
"""
153153
Calculates the sum for each element of an array instance with the respective element of the array ``other``.
154154
155155
Parameters
156156
----------
157157
self: array
158158
array instance (augend array). Should have a numeric data type.
159-
other: Union[int, float, array]
159+
other: Union[int, float, complex, array]
160160
addend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
161161
162162
Returns
@@ -494,15 +494,15 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
494494
ONE_API = 14
495495
"""
496496

497-
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
497+
def __eq__(self: array, other: Union[int, float, complex, bool, array], /) -> array:
498498
r"""
499499
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
500500
501501
Parameters
502502
----------
503503
self: array
504504
array instance. May have any data type.
505-
other: Union[int, float, bool, array]
505+
other: Union[int, float, complex, bool, array]
506506
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type.
507507
508508
Returns
@@ -513,6 +513,9 @@ def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
513513
514514
.. note::
515515
Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`.
516+
517+
.. versionchanged:: 2022.12
518+
Added complex data type support.
516519
"""
517520

518521
def __float__(self: array, /) -> float:
@@ -893,7 +896,7 @@ def __mod__(self: array, other: Union[int, float, array], /) -> array:
893896
Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.remainder`.
894897
"""
895898

896-
def __mul__(self: array, other: Union[int, float, array], /) -> array:
899+
def __mul__(self: array, other: Union[int, float, complex, array], /) -> array:
897900
r"""
898901
Calculates the product for each element of an array instance with the respective element of the array ``other``.
899902
@@ -904,7 +907,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
904907
----------
905908
self: array
906909
array instance. Should have a numeric data type.
907-
other: Union[int, float, array]
910+
other: Union[int, float, complex, array]
908911
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
909912
910913
Returns
@@ -922,15 +925,15 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
922925
Added complex data type support.
923926
"""
924927

925-
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
928+
def __ne__(self: array, other: Union[int, float, complex, bool, array], /) -> array:
926929
"""
927930
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
928931
929932
Parameters
930933
----------
931934
self: array
932935
array instance. May have any data type.
933-
other: Union[int, float, bool, array]
936+
other: Union[int, float, complex, bool, array]
934937
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type.
935938
936939
Returns
@@ -1024,7 +1027,7 @@ def __pos__(self: array, /) -> array:
10241027
Added complex data type support.
10251028
"""
10261029

1027-
def __pow__(self: array, other: Union[int, float, array], /) -> array:
1030+
def __pow__(self: array, other: Union[int, float, complex, array], /) -> array:
10281031
r"""
10291032
Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``.
10301033
@@ -1037,7 +1040,7 @@ def __pow__(self: array, other: Union[int, float, array], /) -> array:
10371040
----------
10381041
self: array
10391042
array instance whose elements correspond to the exponentiation base. Should have a numeric data type.
1040-
other: Union[int, float, array]
1043+
other: Union[int, float, complex, array]
10411044
other array whose elements correspond to the exponentiation exponent. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
10421045
10431046
Returns
@@ -1081,7 +1084,7 @@ def __setitem__(
10811084
key: Union[
10821085
int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array
10831086
],
1084-
value: Union[int, float, bool, array],
1087+
value: Union[int, float, complex, bool, array],
10851088
/,
10861089
) -> None:
10871090
"""
@@ -1095,7 +1098,7 @@ def __setitem__(
10951098
array instance.
10961099
key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]
10971100
index key.
1098-
value: Union[int, float, bool, array]
1101+
value: Union[int, float, complex, bool, array]
10991102
value(s) to set. Must be compatible with ``self[key]`` (see :ref:`broadcasting`).
11001103
11011104
@@ -1108,7 +1111,7 @@ def __setitem__(
11081111
When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined.
11091112
"""
11101113

1111-
def __sub__(self: array, other: Union[int, float, array], /) -> array:
1114+
def __sub__(self: array, other: Union[int, float, complex, array], /) -> array:
11121115
"""
11131116
Calculates the difference for each element of an array instance with the respective element of the array ``other``.
11141117
@@ -1118,7 +1121,7 @@ def __sub__(self: array, other: Union[int, float, array], /) -> array:
11181121
----------
11191122
self: array
11201123
array instance (minuend array). Should have a numeric data type.
1121-
other: Union[int, float, array]
1124+
other: Union[int, float, complex, array]
11221125
subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
11231126
11241127
Returns
@@ -1136,7 +1139,7 @@ def __sub__(self: array, other: Union[int, float, array], /) -> array:
11361139
Added complex data type support.
11371140
"""
11381141

1139-
def __truediv__(self: array, other: Union[int, float, array], /) -> array:
1142+
def __truediv__(self: array, other: Union[int, float, complex, array], /) -> array:
11401143
r"""
11411144
Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
11421145
@@ -1149,7 +1152,7 @@ def __truediv__(self: array, other: Union[int, float, array], /) -> array:
11491152
----------
11501153
self: array
11511154
array instance. Should have a numeric data type.
1152-
other: Union[int, float, array]
1155+
other: Union[int, float, complex, array]
11531156
other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type.
11541157
11551158
Returns

src/array_api_stubs/_draft/array_object.py

+3
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@ def __eq__(self: array, other: Union[int, float, complex, bool, array], /) -> ar
515515
516516
- Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`.
517517
- Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent.
518+
519+
.. versionchanged:: 2022.12
520+
Added complex data type support.
518521
"""
519522

520523
def __float__(self: array, /) -> float:

0 commit comments

Comments
 (0)