Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions scimath/units/assertion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
""" Utilities providing assertions to support unit tests involving UnitScalars
and UnitArrays.
"""
from nose.tools import assert_false, assert_true

from scimath.units.compare_units import unit_arrays_almost_equal, \
unit_scalars_almost_equal


def assert_unit_scalar_almost_equal(val1, val2, rtol=1.e-9, msg=None):
if msg is None:
msg = "{} and {} are not almost equal with precision {}"
msg = msg.format(val1, val2, rtol)

assert_true(unit_scalars_almost_equal(val1, val2, rtol=rtol), msg=msg)


def assert_unit_scalar_not_almost_equal(val1, val2, rtol=1.e-9, msg=None):
if msg is None:
msg = "{} and {} unexpectedly almost equal with precision {}"
msg = msg.format(val1, val2, rtol)

assert_false(unit_scalars_almost_equal(val1, val2, rtol=rtol), msg=msg)


def assert_unit_array_almost_equal(uarr1, uarr2, rtol=1e-9, msg=None):
if msg is None:
msg = "{} and {} are not almost equal with precision {}"
msg = msg.format(uarr1, uarr2, rtol)

assert_true(unit_arrays_almost_equal(uarr1, uarr2, rtol=rtol), msg=msg)


def assert_unit_array_not_almost_equal(uarr1, uarr2, rtol=1e-9, msg=None):
if msg is None:
msg = "{} and {} are almost equal with precision {}"
msg = msg.format(uarr1, uarr2, rtol)

assert_false(unit_arrays_almost_equal(uarr1, uarr2, rtol=rtol), msg=msg)
81 changes: 81 additions & 0 deletions scimath/units/compare_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
""" Utilities around unit comparisons.
"""
import numpy as np

from scimath.units.api import convert, UnitArray, UnitScalar
from scimath.units.unit import InvalidConversion


def unit_scalars_almost_equal(x1, x2, rtol=1e-9):
""" Returns whether 2 UnitScalars are almost equal.

More precisely, what is tested is whether abs(a1-a2) < rtol*abs(a2), where
a1=float(x1) and a2=float(x2) after conversion to x1's units.

Parameters
----------
x1 : UnitScalar
First unit scalar to compare.

x2 : UnitScalar
Second unit scalar to compare.

rtol : float
Relative precision of the comparison.
"""
if not isinstance(x1, UnitScalar):
msg = "x1 is supposed to be a UnitScalar but a {} was passed."
msg = msg.format(type(x1))
raise ValueError(msg)

if not isinstance(x2, UnitScalar):
msg = "x2 is supposed to be a UnitScalar but a {} was passed."
msg = msg.format(type(x2))
raise ValueError(msg)

a1 = float(x1)
try:
a2 = convert(float(x2), from_unit=x2.units, to_unit=x1.units)
except InvalidConversion:
return False
return np.abs(a1 - a2) < np.abs(rtol * a2)


def unit_arrays_almost_equal(uarr1, uarr2, rtol=1e-9):
""" Returns whether 2 UnitArrays are almost equal (must be the same shape).

More precisely, what is tested is whether abs(a1-a2) < rtol*abs(a2) for all
values in the arrays, once uarr2 has been converted to uarr1's units.

Parameters
----------
uarr1 : UnitArray
First unit array to compare.

uarr2 : UnitArray
Second unit array to compare.

rtol : float
Relative precision of the comparison.
"""
if not isinstance(uarr1, UnitArray):
msg = "uarr1 is supposed to be a UnitArray but a {} was passed."
msg = msg.format(type(uarr1))
raise ValueError(msg)

if not isinstance(uarr2, UnitArray):
msg = "uarr2 is supposed to be a UnitArray but a {} was passed."
msg = msg.format(type(uarr2))
raise ValueError(msg)

if uarr1.shape != uarr2.shape:
return False

a1 = np.array(uarr1)
try:
a2 = convert(np.array(uarr2), from_unit=uarr2.units,
to_unit=uarr1.units)
except InvalidConversion:
return False

return np.all(np.abs(a1 - a2) < np.abs(rtol * a2))
Empty file.
56 changes: 56 additions & 0 deletions scimath/units/tests/test_assertion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest import TestCase

from scimath.units.api import UnitArray, UnitScalar
from scimath.units.assertion_utils import assert_unit_array_almost_equal, \
assert_unit_scalar_almost_equal


class TestAssertUnitScalarEqual(TestCase):
def test_same_unit_scalar(self):
assert_unit_scalar_almost_equal(UnitScalar(1, units="s"),
UnitScalar(1, units="s"))

def test_equivalent_unit_scalar(self):
assert_unit_scalar_almost_equal(UnitScalar(1, units="m"),
UnitScalar(100, units="cm"))

def test_not_close(self):
with self.assertRaises(AssertionError):
assert_unit_scalar_almost_equal(UnitScalar(1, units="m"),
UnitScalar(1.1, units="m"))

def test_not_close_custom_msg(self):
a1 = UnitScalar(1, units="m")
a2 = UnitScalar(1.1, units="m")
with self.assertRaises(AssertionError):
assert_unit_scalar_almost_equal(a1, a2, rtol=1e-2, msg="BLAH")

def test_unit_scalar_non_default_rtol(self):
assert_unit_scalar_almost_equal(UnitScalar(1, units="m"),
UnitScalar(1.01, units="m"), rtol=1e-1)


class TestAssertUnitArrayEqual(TestCase):
def test_same_unit_array(self):
assert_unit_array_almost_equal(UnitArray([1, 2], units="s"),
UnitArray([1, 2], units="s"))

def test_equivalent_unit_array(self):
assert_unit_array_almost_equal(UnitArray([1, 2], units="m"),
UnitArray([100, 200], units="cm"))

def test_not_close(self):
a1 = UnitArray([1.01, 2], units="s")
a2 = UnitArray([1, 2], units="s")
with self.assertRaises(AssertionError):
assert_unit_array_almost_equal(a1, a2)

def test_not_close_custom_msg(self):
a1 = UnitArray([1.01, 2], units="s")
a2 = UnitArray([1, 2], units="s")
with self.assertRaises(AssertionError):
assert_unit_array_almost_equal(a1, a2, msg="BLAH")

def test_unit_scalar_non_default_rtol(self):
assert_unit_array_almost_equal(UnitScalar(1, units="m"),
UnitScalar(1.01, units="m"), rtol=1e-1)
115 changes: 115 additions & 0 deletions scimath/units/tests/test_compare_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from unittest import TestCase

from scimath.units.api import dimensionless, UnitArray, UnitScalar
from scimath.units.compare_units import unit_arrays_almost_equal, \
unit_scalars_almost_equal


class TestUnitScalarAlmostEqual(TestCase):
def test_values_identical(self):
val1 = UnitScalar(1., units="m")
self.assertTrue(unit_scalars_almost_equal(val1, val1))

def test_wrong_arg_type1(self):
val1 = 1
val2 = UnitScalar(1., units="m")
with self.assertRaises(ValueError):
unit_scalars_almost_equal(val1, val2)

def test_wrong_arg_type2(self):
val1 = UnitScalar(1., units="m")
val2 = 1
with self.assertRaises(ValueError):
unit_scalars_almost_equal(val1, val2)

def test_values_not_close(self):
val1 = UnitScalar(1., units="m")
val2 = UnitScalar(1.1, units="m")
self.assertFalse(unit_scalars_almost_equal(val1, val2))

val2 = UnitScalar(1.00001, units="m")
self.assertFalse(unit_scalars_almost_equal(val1, val2))

def test_values_identical_in_diff_units(self):
val1 = UnitScalar(1., units="m")
val2 = UnitScalar(100., units="cm")
self.assertTrue(unit_scalars_almost_equal(val1, val2))

def test_dimensionless(self):
val1 = UnitScalar(1., units=dimensionless)
val2 = UnitScalar(1., units="cm")
self.assertFalse(unit_scalars_almost_equal(val1, val2))

def test_2_dimensionless(self):
val1 = UnitScalar(1., units=dimensionless)
val2 = UnitScalar(1., units="BLAH")
val3 = UnitScalar(100., units="BLAH")
self.assertTrue(unit_scalars_almost_equal(val1, val1))
self.assertTrue(unit_scalars_almost_equal(val1, val2))
self.assertFalse(unit_scalars_almost_equal(val1, val3))

def test_values_close_enough(self):
val1 = UnitScalar(1., units="m")
val2 = val1 + UnitScalar(1.e-5, units="m")
self.assertFalse(unit_scalars_almost_equal(val1, val2))
self.assertTrue(unit_scalars_almost_equal(val1, val2, rtol=1e-4))


class TestUnitArraysAlmostEqual(TestCase):
def test_wrong_argument_type1(self):
val1 = 1
val2 = UnitArray([1.], units="m")
with self.assertRaises(ValueError):
unit_arrays_almost_equal(val1, val2)

def test_wrong_argument_type2(self):
val1 = UnitArray([1.], units="m")
val2 = 1
with self.assertRaises(ValueError):
unit_arrays_almost_equal(val1, val2)

def test_different_shape(self):
val1 = UnitArray([1.], units="m")
val2 = UnitArray([1., 2.], units="m")
self.assertFalse(unit_arrays_almost_equal(val1, val2))

def test_not_close_default_rtol(self):
val1 = UnitArray([1., 2.], units="m")
val2 = UnitArray([1., 2.1], units="m")
self.assertFalse(unit_arrays_almost_equal(val1, val2))

val2 = UnitArray([1., 2.000001], units="m")
self.assertFalse(unit_arrays_almost_equal(val1, val2))

def test_values_identical(self):
val1 = UnitArray([1., 2.], units="m")
self.assertTrue(unit_arrays_almost_equal(val1, val1))

def test_values_identical_in_diff_units(self):
val1 = UnitArray([1., 2.], units="m")
val2 = UnitArray([100., 200.], units="cm")
self.assertTrue(unit_arrays_almost_equal(val1, val2))

def test_dimensionless(self):
val1 = UnitArray([1.], units=dimensionless)
val2 = UnitArray([1.], units="cm")
self.assertFalse(unit_arrays_almost_equal(val1, val2))

def test_2_dimensionless(self):
val1 = UnitArray([1.], units=dimensionless)
val2 = UnitArray([1.], units="BLAH")
val3 = UnitArray([100.], units="BLAH")
self.assertTrue(unit_arrays_almost_equal(val1, val1))
self.assertTrue(unit_arrays_almost_equal(val1, val2))
self.assertFalse(unit_arrays_almost_equal(val1, val3))

def test_values_close_enough(self):
val1 = UnitArray([1., 2.], units="m")
val2 = val1 + UnitArray([1.e-5, 1.e-6], units="m")
self.assertFalse(unit_arrays_almost_equal(val1, val2))
self.assertTrue(unit_arrays_almost_equal(val1, val2, rtol=1e-4))

def test_values_not_close_enough(self):
val1 = UnitArray([1., 2.], units="m")
val3 = val1 + UnitArray([1.e-2, 1.e-6], units="m")
self.assertFalse(unit_arrays_almost_equal(val1, val3, rtol=1e-4))