Skip to content

Commit 95e698b

Browse files
fix __array_ufunc_ and support for scalers (#66)
2 parents 934ceb7 + 222d284 commit 95e698b

File tree

5 files changed

+33
-52
lines changed

5 files changed

+33
-52
lines changed

test/test_value_array.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,21 @@ def test_invalid_unit_raises_error() -> None:
268268
def test_format() -> None:
269269
u = tu.GHz * np.random.random(10)
270270
assert f'{u}' == str(u)
271+
272+
273+
def test_ufunc() -> None:
274+
x = np.float64(0.42)
275+
y = tu.GHz * np.arange(4)[1:]
276+
277+
assert np.multiply(x, y).allclose(np.multiply(y, x)) # type: ignore[union-attr]
278+
assert np.divide(x, y).allclose(np.divide(np.int64(1), np.divide(y, x))) # type: ignore[union-attr]
279+
280+
with pytest.raises(UnitMismatchError):
281+
_ = np.add(x, y)
282+
with pytest.raises(UnitMismatchError):
283+
_ = np.add(y, x)
284+
285+
with pytest.raises(UnitMismatchError):
286+
_ = np.subtract(x, y)
287+
with pytest.raises(UnitMismatchError):
288+
_ = np.subtract(y, x)

test_perf/test_value_performance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_perf_repr(a: Value) -> str:
7474
return repr(a)
7575

7676

77-
@perf_goal(avg_nanos=600)
77+
@perf_goal(avg_nanos=800)
7878
def test_perf_parse_atom() -> Value:
7979
return Value(1, 'kilogram')
8080

tunits/core/cython/derived_unit_data.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ __SI_DERIVED_UNITS = [
7878
DerivedUnitData('lx', 'lux', 'lm/m^2', use_prefixes=True),
7979
DerivedUnitData('Bq', 'becquerel', 'Hz', use_prefixes=True),
8080
DerivedUnitData('l', 'liter', 'm^3', exp10=-3, use_prefixes=True),
81-
DerivedUnitData('phi0', 'magnetic_flux_quantum', 'J*s/C', value=2.067833831170082e-15, use_prefixes=True),
81+
DerivedUnitData(
82+
'phi0', 'magnetic_flux_quantum', 'J*s/C', value=2.067833831170082e-15, use_prefixes=True
83+
),
8284
DerivedUnitData('eV', 'electron_volt', 'N*m', value=1.602176634e-19, use_prefixes=False),
8385
]
8486

@@ -117,7 +119,7 @@ __OTHER_UNITS = [
117119
DerivedUnitData('psi', 'pounds_per_square_inch', 'Pa', 6894.75729317),
118120
DerivedUnitData('bar', 'barometric_pressure', 'Pa', 1e5),
119121
DerivedUnitData('atm', 'atmospheric_pressure', 'Pa', 101325.0),
120-
DerivedUnitData('torr', 'mm_of_mercury', 'Pa', 101325.0/760),
122+
DerivedUnitData('torr', 'mm_of_mercury', 'Pa', 101325.0 / 760),
121123
]
122124

123125
# Units that aren't technically exact, but close enough for our purposes.

tunits/core/cython/with_unit.pyx

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ cdef class WithUnit:
199199
if unit_val is None:
200200
raise ValueError("Bad WithUnit scaling value: " + repr(value))
201201

202-
if isinstance(value, (int, float, np.ndarray)):
202+
if isinstance(value, (numbers.Number, np.ndarray)):
203203
self.value = unit_val.value * value
204204
self.conv = unit_val.conv
205205
self.base_units = unit_val.base_units
@@ -259,7 +259,7 @@ cdef class WithUnit:
259259
def __mul__(left, b) -> 'WithUnit':
260260
cdef WithUnit right
261261
try:
262-
if isinstance(b, (int, float)):
262+
if isinstance(b, numbers.Number):
263263
return left.__with_value(left.value * b)
264264
right = _in_WithUnit(b)
265265
if left._is_dimensionless() and right._is_dimensionless():
@@ -279,7 +279,7 @@ cdef class WithUnit:
279279
return NotImplemented
280280
def __rmul__(self, b):
281281
try:
282-
if isinstance(b, (int, float)):
282+
if isinstance(b, numbers.Number):
283283
return self.__with_value(self.value * b)
284284
return self * b
285285
except NotTUnitsLikeError:
@@ -288,7 +288,7 @@ cdef class WithUnit:
288288
def __truediv__(a, b):
289289
cdef WithUnit left, right
290290
try:
291-
if isinstance(b, (int, float)):
291+
if isinstance(b, numbers.Number):
292292
return a.__with_value(a.value / b)
293293
left = _in_WithUnit(a)
294294
right = _in_WithUnit(b)
@@ -322,7 +322,7 @@ cdef class WithUnit:
322322
cdef WithUnit left, right
323323
cdef double c
324324
try:
325-
if isinstance(b, (int, float)):
325+
if isinstance(b, numbers.Number):
326326
return a.value//b, _in_WithUnit(a.value % b)
327327
left = _in_WithUnit(a)
328328
right = _in_WithUnit(b)
@@ -388,12 +388,6 @@ cdef class WithUnit:
388388
def imag(WithUnit self):
389389
return self.__with_value(self.value.imag)
390390

391-
# def round(WithUnit self, unit):
392-
# try:
393-
# return self.in_units_of(unit, True)
394-
# except NotTUnitsLikeError:
395-
# return NotImplemented
396-
397391
def __int__(self):
398392
if self.base_units.unit_count != 0:
399393
raise UnitMismatchError(
@@ -673,22 +667,6 @@ cdef class WithUnit:
673667
def conjugate(self) -> 'WithUnit':
674668
return self.__with_value(self.value.conjugate())
675669

676-
# def floor(self, u):
677-
# cdef WithUnit converted
678-
# try:
679-
# converted = self.in_units_of(u, False)
680-
# return converted.__with_value(floor(converted.value))
681-
# except NotTUnitsLikeError:
682-
# return NotImplemented
683-
684-
# def ceil(self, u):
685-
# cdef WithUnit converted
686-
# try:
687-
# converted = self.in_units_of(u, False)
688-
# return converted.__with_value(ceil(converted.value))
689-
# except NotTUnitsLikeError:
690-
# return NotImplemented
691-
692670
def sign(self) -> int | np.ndarray:
693671
return np.sign(self.value_in_base_units())
694672

tunits/core/cython/with_unit_value_array.pyx

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,32 +95,15 @@ class ValueArray(WithUnit):
9595

9696
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
9797
if method == "__call__":
98+
is_value = isinstance(inputs[0], ValueArray)
9899
if ufunc == np.add:
99-
return inputs[0] + inputs[1]
100+
return inputs[0].__add__(inputs[1]) if is_value else inputs[1].__radd__(inputs[0])
100101
if ufunc == np.subtract:
101-
return inputs[0] - inputs[1]
102+
return inputs[0].__sub__(inputs[1]) if is_value else inputs[1].__rsub__(inputs[0])
102103
if ufunc == np.multiply:
103-
if not isinstance(inputs[0], np.ndarray) and not isinstance(inputs[1], np.ndarray):
104-
return inputs[0] * inputs[1]
105-
elif isinstance(inputs[0], np.ndarray):
106-
return inputs[1] * inputs[0]
107-
elif isinstance(inputs[1], np.ndarray):
108-
return inputs[0] * inputs[1]
109-
else:
110-
raise NotImplementedError(
111-
f"multiply not implemented for types {type(inputs[0])}, {type(inputs[1])}"
112-
)
104+
return inputs[0].__mul__(inputs[1]) if is_value else inputs[1].__rmul__(inputs[0])
113105
if ufunc == np.divide:
114-
if not isinstance(inputs[0], np.ndarray) and not isinstance(inputs[1], np.ndarray):
115-
return inputs[0] / inputs[1]
116-
elif isinstance(inputs[0], np.ndarray):
117-
return inputs[1].__rtruediv__(inputs[0])
118-
elif isinstance(inputs[1], np.ndarray):
119-
return inputs[0] / inputs[1]
120-
else:
121-
raise NotImplementedError(
122-
f"divide not implemented for types {type(inputs[0])}, {type(inputs[1])}"
123-
)
106+
return inputs[0].__truediv__(inputs[1]) if is_value else inputs[1].__rtruediv__(inputs[0])
124107
if ufunc == np.power:
125108
return inputs[0] ** inputs[1]
126109
if ufunc in [np.positive, np.negative, np.abs, np.fabs, np.conj]:

0 commit comments

Comments
 (0)