|
9 | 9 | import optype.numpy._compat as _x |
10 | 10 | from optype._core._utils import set_module |
11 | 11 |
|
12 | | -from ._dtype import DType |
13 | | -from ._shape import AtLeast0D |
14 | | - |
15 | 12 |
|
16 | 13 | if sys.version_info >= (3, 13): |
17 | 14 | from typing import Protocol, Self, TypeAliasType, TypeVar, runtime_checkable |
|
33 | 30 | "Array1D", |
34 | 31 | "Array1D", |
35 | 32 | "Array2D", |
| 33 | + "Array3D", |
36 | 34 | "ArrayND", |
37 | 35 | "CanArray", |
| 36 | + "CanArray1D", |
| 37 | + "CanArray2D", |
| 38 | + "CanArray3D", |
38 | 39 | "CanArrayFinalize", |
39 | 40 | "CanArrayND", |
40 | 41 | "CanArrayWrap", |
|
43 | 44 | ] |
44 | 45 |
|
45 | 46 |
|
46 | | -_NDT = TypeVar("_NDT", bound=AtLeast0D, default=AtLeast0D) |
47 | | -_NDT_co = TypeVar("_NDT_co", bound=AtLeast0D, default=AtLeast0D, covariant=True) |
48 | | -_DTT = TypeVar("_DTT", bound=DType, default=DType) |
49 | | -_DTT_co = TypeVar("_DTT_co", bound=DType, default=DType, covariant=True) |
| 47 | +_NDT = TypeVar("_NDT", bound=tuple[int, ...], default=tuple[int, ...]) |
| 48 | +_NDT_co = TypeVar( |
| 49 | + "_NDT_co", |
| 50 | + bound=tuple[int, ...], |
| 51 | + default=tuple[int, ...], |
| 52 | + covariant=True, |
| 53 | +) |
| 54 | +_DTT = TypeVar("_DTT", bound=np.dtype[np.generic], default=np.dtype[np.generic]) |
| 55 | +_DTT_co = TypeVar( |
| 56 | + "_DTT_co", |
| 57 | + bound=np.dtype[np.generic], |
| 58 | + default=np.dtype[np.generic], |
| 59 | + covariant=True, |
| 60 | +) |
50 | 61 | _SCT = TypeVar("_SCT", bound=np.generic, default=np.generic) |
51 | 62 | _SCT_co = TypeVar("_SCT_co", bound=np.generic, default=np.generic, covariant=True) |
52 | 63 |
|
@@ -115,6 +126,45 @@ class CanArray(Protocol[_NDT, _DTT_co]): |
115 | 126 | def __array__(self, /) -> np.ndarray[_NDT, _DTT_co]: ... |
116 | 127 |
|
117 | 128 |
|
| 129 | +@runtime_checkable |
| 130 | +@set_module("optype.numpy") |
| 131 | +class CanArrayND(Protocol[_SCT_co]): |
| 132 | + """ |
| 133 | + Similar to `optype.numpy.CanArray`, but must be sized (i.e. excludes scalars), |
| 134 | + and is parameterized by only the scalar type (instead of the shape and dtype). |
| 135 | + """ |
| 136 | + |
| 137 | + def __len__(self, /) -> int: ... |
| 138 | + def __array__(self, /) -> np.ndarray[tuple[int, ...], np.dtype[_SCT_co]]: ... |
| 139 | + |
| 140 | + |
| 141 | +@runtime_checkable |
| 142 | +@set_module("optype.numpy") |
| 143 | +class CanArray1D(Protocol[_SCT_co]): |
| 144 | + """The 1-d variant of `optype.numpy.CanArrayND`.""" |
| 145 | + |
| 146 | + def __len__(self, /) -> int: ... |
| 147 | + def __array__(self, /) -> np.ndarray[tuple[int], np.dtype[_SCT_co]]: ... |
| 148 | + |
| 149 | + |
| 150 | +@runtime_checkable |
| 151 | +@set_module("optype.numpy") |
| 152 | +class CanArray2D(Protocol[_SCT_co]): |
| 153 | + """The 2-d variant of `optype.numpy.CanArrayND`.""" |
| 154 | + |
| 155 | + def __len__(self, /) -> int: ... |
| 156 | + def __array__(self, /) -> np.ndarray[tuple[int, int], np.dtype[_SCT_co]]: ... |
| 157 | + |
| 158 | + |
| 159 | +@runtime_checkable |
| 160 | +@set_module("optype.numpy") |
| 161 | +class CanArray3D(Protocol[_SCT_co]): |
| 162 | + """The 2-d variant of `optype.numpy.CanArrayND`.""" |
| 163 | + |
| 164 | + def __len__(self, /) -> int: ... |
| 165 | + def __array__(self, /) -> np.ndarray[tuple[int, int, int], np.dtype[_SCT_co]]: ... |
| 166 | + |
| 167 | + |
118 | 168 | # this is almost always a `ndarray`, but setting a `bound` might break in some |
119 | 169 | # edge cases |
120 | 170 | _T_contra = TypeVar("_T_contra", contravariant=True, default=object) |
@@ -169,15 +219,3 @@ def __array_interface__(self, /) -> _ArrayInterfaceT_co: ... |
169 | 219 | class HasArrayPriority(Protocol): |
170 | 220 | @property |
171 | 221 | def __array_priority__(self, /) -> float: ... |
172 | | - |
173 | | - |
174 | | -@runtime_checkable |
175 | | -@set_module("optype.numpy") |
176 | | -class CanArrayND(Protocol[_SCT_co]): |
177 | | - """ |
178 | | - Similar to `optype.numpy.CanArray`, but must be sized (i.e. excludes scalars), |
179 | | - and is parameterized by only the scalar type (instead of the shape and dtype). |
180 | | - """ |
181 | | - |
182 | | - def __len__(self, /) -> int: ... |
183 | | - def __array__(self, /) -> np.ndarray[tuple[int, ...], np.dtype[_SCT_co]]: ... |
|
0 commit comments