diff --git a/test/test_value_array.py b/test/test_value_array.py index a69708e..390e2cb 100644 --- a/test/test_value_array.py +++ b/test/test_value_array.py @@ -293,3 +293,12 @@ def test_unique() -> None: unit = tu.MHz v_arr = xs * unit assert np.array_equal(v_arr.unique(), np.unique(xs) * unit) + + +def test_matmul() -> None: + a = np.random.random((3, 4)) + b = np.random.random((4, 3)) * tu.ns + c: tu.TimeArray = a @ b # type: ignore[assignment] + d: tu.TimeArray = b @ a # type: ignore[assignment] + assert c.allclose((a @ b[tu.us]) * tu.us) + assert d.allclose((b[tu.s] @ a) * tu.s) diff --git a/tunits/core/cython/with_unit_value_array.pyx b/tunits/core/cython/with_unit_value_array.pyx index 2e7fcb1..c97fb0f 100644 --- a/tunits/core/cython/with_unit_value_array.pyx +++ b/tunits/core/cython/with_unit_value_array.pyx @@ -116,6 +116,10 @@ class ValueArray(WithUnit): return self ** 2 if ufunc == np.reciprocal: return self.__rtruediv__(1) + if ufunc == np.matmul: + if isinstance(inputs[0], ValueArray): + return inputs[0].__matmul__(inputs[1]) + return inputs[1].__rmatmul__(inputs[0]) if ufunc in [ np.greater, @@ -133,7 +137,7 @@ class ValueArray(WithUnit): if self._is_dimensionless(): return getattr(ufunc, method)(*(np.asarray(x) for x in inputs), **kwargs) - raise NotImplemented + return NotImplemented @property def dtype(WithUnit self) -> np.dtype: @@ -160,4 +164,11 @@ class ValueArray(WithUnit): def to_proto(self, msg: Optional['tunits_pb2.ValueArray'] = None) -> 'tunits_pb2.ValueArray': ret = _ndarray_to_proto(self.value, msg) ret.units.extend(_units_to_proto(self.display_units)) - return ret \ No newline at end of file + return ret + + def __matmul__(WithUnit self, other: np.ndarray): + return self.__with_value(self.value @ other) + + + def __rmatmul__(WithUnit self, other: np.ndarray): + return self.__with_value(other @ self.value)