Skip to content

Commit 71b29d2

Browse files
return NotImplemted for unsupported operations in __array_ufunc__ and implement matmul (#84)
2 parents 6acfa57 + 080f7ca commit 71b29d2

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

test/test_value_array.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,12 @@ def test_unique() -> None:
293293
unit = tu.MHz
294294
v_arr = xs * unit
295295
assert np.array_equal(v_arr.unique(), np.unique(xs) * unit)
296+
297+
298+
def test_matmul() -> None:
299+
a = np.random.random((3, 4))
300+
b = np.random.random((4, 3)) * tu.ns
301+
c: tu.TimeArray = a @ b # type: ignore[assignment]
302+
d: tu.TimeArray = b @ a # type: ignore[assignment]
303+
assert c.allclose((a @ b[tu.us]) * tu.us)
304+
assert d.allclose((b[tu.s] @ a) * tu.s)

tunits/core/cython/with_unit_value_array.pyx

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ class ValueArray(WithUnit):
116116
return self ** 2
117117
if ufunc == np.reciprocal:
118118
return self.__rtruediv__(1)
119+
if ufunc == np.matmul:
120+
if isinstance(inputs[0], ValueArray):
121+
return inputs[0].__matmul__(inputs[1])
122+
return inputs[1].__rmatmul__(inputs[0])
119123

120124
if ufunc in [
121125
np.greater,
@@ -133,7 +137,7 @@ class ValueArray(WithUnit):
133137
if self._is_dimensionless():
134138
return getattr(ufunc, method)(*(np.asarray(x) for x in inputs), **kwargs)
135139

136-
raise NotImplemented
140+
return NotImplemented
137141

138142
@property
139143
def dtype(WithUnit self) -> np.dtype:
@@ -160,4 +164,11 @@ class ValueArray(WithUnit):
160164
def to_proto(self, msg: Optional['tunits_pb2.ValueArray'] = None) -> 'tunits_pb2.ValueArray':
161165
ret = _ndarray_to_proto(self.value, msg)
162166
ret.units.extend(_units_to_proto(self.display_units))
163-
return ret
167+
return ret
168+
169+
def __matmul__(WithUnit self, other: np.ndarray):
170+
return self.__with_value(self.value @ other)
171+
172+
173+
def __rmatmul__(WithUnit self, other: np.ndarray):
174+
return self.__with_value(other @ self.value)

0 commit comments

Comments
 (0)