Skip to content

Commit 14e4cdb

Browse files
authored
Merge pull request #164 from neutrinoceros/bugfix_yt_GH_874_2
bugfix: fix commutativity in unyt_array operators
2 parents e957e4c + 9b72505 commit 14e4cdb

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

unyt/array.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
UnitOperationError,
124124
UnitConversionError,
125125
UnitsNotReducible,
126+
SymbolNotFoundError,
126127
)
127128
from unyt.equivalencies import equivalence_registry
128129
from unyt._on_demand_imports import _astropy, _pint
@@ -160,7 +161,13 @@ def _sqrt_unit(unit):
160161

161162
@lru_cache(maxsize=128, typed=False)
162163
def _multiply_units(unit1, unit2):
163-
ret = (unit1 * unit2).simplify()
164+
try:
165+
ret = (unit1 * unit2).simplify()
166+
except SymbolNotFoundError:
167+
# Some operators are not natively commutative when operands are
168+
# defined within different unit registries, and conversion
169+
# is defined one way but not the other.
170+
ret = (unit2 * unit1).simplify()
164171
return ret.as_coeff_unit()
165172

166173

@@ -195,7 +202,10 @@ def _square_unit(unit):
195202

196203
@lru_cache(maxsize=128, typed=False)
197204
def _divide_units(unit1, unit2):
198-
ret = (unit1 / unit2).simplify()
205+
try:
206+
ret = (unit1 / unit2).simplify()
207+
except SymbolNotFoundError:
208+
ret = (1 / (unit2 / unit1).simplify()).units
199209
return ret.as_coeff_unit()
200210

201211

unyt/tests/test_units.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import pytest
2929
from sympy import Symbol
3030

31+
from unyt.array import unyt_quantity
3132
from unyt.testing import assert_allclose_units
3233
from unyt.unit_registry import UnitRegistry
3334
from unyt.dimensions import (
@@ -874,3 +875,17 @@ def test_degF():
874875
def test_delta_degF():
875876
a = 1 * Unit("delta_degF")
876877
assert str(a) == "1 Δ°F"
878+
879+
880+
def test_mixed_registry_operations():
881+
882+
reg = UnitRegistry(unit_system="cgs")
883+
reg.add("fake_length", 0.001, length)
884+
a = unyt_quantity(1, units="fake_length", registry=reg)
885+
b = unyt_quantity(1, "cm")
886+
887+
assert_almost_equal(a + b, b + a)
888+
assert_almost_equal(a - b, -(b - a))
889+
assert_almost_equal(a * b, b * a)
890+
assert_almost_equal(b / a, b / a.in_units("km"))
891+
assert_almost_equal(a / b, a / b.in_units("km"))

0 commit comments

Comments
 (0)