Skip to content

Commit 66d0d97

Browse files
support string comparison (#61)
2 parents 5cea66f + 5b91fa2 commit 66d0d97

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

test/cython/test_with_unit.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import pytest
2020

21-
from tunits.core import raw_WithUnit, raw_UnitArray, WithUnit
21+
from tunits.core import raw_WithUnit, raw_UnitArray, WithUnit, NotTUnitsLikeError
2222

2323
from tunits import UnitMismatchError, ValueArray, Value
2424
from test.test_utils import frac, conv, val
@@ -85,7 +85,6 @@ def test_abs() -> None:
8585

8686
def test_equality() -> None:
8787
equivalence_groups: list[list[Any]] = [
88-
[""],
8988
["other types"],
9089
[list],
9190
[None],
@@ -488,21 +487,11 @@ def test_get_item() -> None:
488487
v = val(2, conv(numer=3, denom=5, exp10=7), mps, kph)
489488

490489
# Wrong kinds of index (unit array, slice).
491-
with pytest.raises(TypeError):
490+
with pytest.raises(UnitMismatchError):
492491
_ = u[mps]
493-
with pytest.raises(TypeError):
492+
with pytest.raises(NotTUnitsLikeError):
494493
_ = u[1:2]
495494

496-
# Safety against dimensionless unit ambiguity.
497-
_ = u[u]
498-
with pytest.raises(TypeError):
499-
_ = u[1.0]
500-
with pytest.raises(TypeError):
501-
_ = u[1.0]
502-
with pytest.raises(TypeError):
503-
_ = u[1]
504-
with pytest.raises(TypeError):
505-
_ = u[2 * v / v]
506495
assert u[v / v] == 10
507496

508497
# Wrong unit.

test/test_value.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,7 @@ def test_division() -> None:
177177
def test_get_item() -> None:
178178
from tunits.units import ns, s
179179

180-
with pytest.raises(TypeError):
181-
_ = (ns / s)[2 * s / ns]
182-
with pytest.raises(TypeError):
183-
_ = (ns / s)[Value(3, '')]
180+
assert (ns / s)[Value(3, '')] == pytest.approx(1 / 3 * 1e-9)
184181
assert Value(1, '')[Value(1, '')] == 1
185182
assert Value(1, '')[ns / s] == 10**9
186183

tunits/core/cython/with_unit.pyx

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ def _in_WithUnit(obj) -> raw_WithUnit:
6767
"""
6868
if isinstance(obj, WithUnit):
6969
return obj
70-
return raw_WithUnit(obj, identity_conversion(), _EmptyUnit, _EmptyUnit, Value, ValueArray)
70+
try:
71+
return raw_WithUnit(obj, identity_conversion(), _EmptyUnit, _EmptyUnit, Value, ValueArray)
72+
except NotTUnitsLikeError as e:
73+
if isinstance(obj, str):
74+
return _try_interpret_as_with_unit(obj)
75+
raise e
76+
7177

7278
cdef _is_dimensionless_zero(WithUnit u):
7379
return (u._is_dimensionless() and
@@ -575,6 +581,20 @@ cdef class WithUnit:
575581
return self.__with_value(self.value[key])
576582
except NotTUnitsLikeError:
577583
return NotImplemented
584+
except TypeError:
585+
try:
586+
unit_val = _try_interpret_as_with_unit(str(key), True)
587+
except:
588+
raise NotTUnitsLikeError("Bad unit key: " + repr(key))
589+
if unit_val is None:
590+
raise NotTUnitsLikeError("Bad unit key: " + repr(key))
591+
if self.base_units != unit_val.base_units:
592+
raise UnitMismatchError("'%s' doesn't match '%s'." %
593+
(self, key))
594+
return (self.value
595+
* conversion_to_double(conversion_div(self.conv, unit_val.conv))
596+
/ unit_val.value)
597+
578598

579599
def __iter__(self):
580600
# Hack: We want calls to 'iter' to see that __iter__ exists and try to
@@ -688,6 +708,10 @@ cdef class WithUnit:
688708
def _from_json_dict_(cls, **kwargs):
689709
return cls(kwargs["value"], kwargs["unit"])
690710

711+
def _resolved_value_(self) -> WithUnit:
712+
"""Follows the cirq ResolvableValue protocol."""
713+
return self
714+
691715
def __getstate__(self):
692716
return {
693717
'value': self.value,

0 commit comments

Comments
 (0)