Skip to content

Commit 4680b01

Browse files
[SymForce] Add symforce.set_epsilon_to_symbol
It is moderately inconvenient to set epsilon to a symbol as is. This is because if one wants to set epsilont to a symbol, one might run ``` python import symforce.symbolic as sf import symforce symforce.set_epsilon(sf.Symbol("epsilon")) ``` The trouble with this, however, is that the file defining `Symbol` also defined `atan2`, `asin_safe`, and `acos_safe`, which themselves use the default epsilon. But since `symforce.set_epsilon` can only be called before the value of epsilon is used (a restriction we impose because default function arguments are bound at function definition time, not late bound), the above snippet produces an error. Now, we only import `symforce.symbolic` to have access to `Symbol`, and don't actually need those default epsilon using functions. A user could instead just import `sympy` or `symengine` directly, but that would be inconvenient and confusing (especially for such a common use case). So instead, I'm adding the function `symforce.set_epsilon_to_symbol` (along with `symforce.set_epsilon_to_number` and `symforce.set_epsilon_to_zero`). This function will handle the import of `Symbol` for the user, saving the user the trouble. The other `symforce.set_epsilon_to_XXX` functions aren't strictly necessary for anything, but in discussion with Aaron and Hayk, we thought this api seemed nicest. The default value for `symforce.set_epsilon_to_number` is supposed to be `sf.numeric_epsilon`, which posed a small problem (the same exact problem as `sf.Symbol`). To address this, I moved the definition of `numeric_epsilon` to `symforce/__init__.py`, then reexport it in `symforce.symbolic`. An alternative was to simply redefine `atan2`, `asin_safe`, and `acos_safe` whenever a new value of epsilon is set, allowing a user to set epsilon after importing `symforce.symbolic`, but in conversation it was deciding this was not something we wanted to allow. Topic: set_epsilon_to_symbol GitOrigin-RevId: 24d1a2389427316d21bd85322e43032ec7f7ae2e
1 parent dd817c0 commit 4680b01

File tree

3 files changed

+198
-9
lines changed

3 files changed

+198
-9
lines changed

symforce/__init__.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from types import ModuleType
1212
import typing as T
1313
import os
14+
import sys
1415
import warnings
1516

1617
# -------------------------------------------------------------------------------------------------
@@ -50,8 +51,6 @@ def set_log_level(log_level: str) -> None:
5051
logger.setLevel(getattr(logging, log_level.upper()))
5152

5253
# Only do this if already imported, in case users don't want to use any C++ binaries
53-
import sys
54-
5554
if "cc_sym" in sys.modules:
5655
import cc_sym
5756

@@ -80,8 +79,6 @@ def _find_symengine() -> ModuleType:
8079
8180
Returns the imported symengine module
8281
"""
83-
import sys
84-
8582
if "symengine" in sys.modules:
8683
return sys.modules["symengine"]
8784

@@ -235,6 +232,9 @@ def set_backend(name: str) -> None:
235232
# Default epsilon
236233
# --------------------------------------------------------------------------------
237234

235+
# Should match C++ default epsilon in epsilon.h
236+
numeric_epsilon = 10 * sys.float_info.epsilon
237+
238238

239239
class AlreadyUsedEpsilon(Exception):
240240
"""
@@ -248,7 +248,7 @@ class AlreadyUsedEpsilon(Exception):
248248
_have_used_epsilon = False
249249

250250

251-
def set_epsilon(new_epsilon: T.Any) -> None:
251+
def _set_epsilon(new_epsilon: T.Any) -> None:
252252
"""
253253
Set the default epsilon for SymForce
254254
@@ -266,3 +266,46 @@ def set_epsilon(new_epsilon: T.Any) -> None:
266266

267267
global _epsilon # pylint: disable=global-statement
268268
_epsilon = new_epsilon
269+
270+
271+
def set_epsilon_to_symbol(name: str = "epsilon") -> None:
272+
"""
273+
Set the default epsilon for Symforce to a Symbol.
274+
275+
This must be called before `symforce.symbolic` or other symbolic libraries have been imported.
276+
See `symforce.symbolic.epsilon` for more information.
277+
278+
Args:
279+
name: The name of the symbol for the new default epsilon to use
280+
"""
281+
if get_symbolic_api() == "sympy":
282+
import sympy
283+
elif get_symbolic_api() == "symengine":
284+
sympy = _find_symengine()
285+
else:
286+
raise InvalidSymbolicApiError(get_symbolic_api())
287+
288+
_set_epsilon(sympy.Symbol(name))
289+
290+
291+
def set_epsilon_to_number(value: T.Any = numeric_epsilon) -> None:
292+
"""
293+
Set the default epsilon for Symforce to a number.
294+
295+
This must be called before `symforce.symbolic` or other symbolic libraries have been imported.
296+
See `symforce.symbolic.epsilon` for more information.
297+
298+
Args:
299+
value: The new default epsilon to use
300+
"""
301+
_set_epsilon(value)
302+
303+
304+
def set_epsilon_to_zero() -> None:
305+
"""
306+
Set the default epsilon for Symforce to zero.
307+
308+
This must be called before `symforce.symbolic` or other symbolic libraries have been imported.
309+
See `symforce.symbolic.epsilon` for more information.
310+
"""
311+
_set_epsilon(0.0)

symforce/internal/symbolic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
# pylint: disable=ungrouped-imports
2929

3030
import contextlib
31-
import sys
32-
3331
import symforce
3432
from symforce import typing as T
3533
from symforce import logger
@@ -255,8 +253,7 @@
255253
# --------------------------------------------------------------------------------
256254

257255

258-
# Should match C++ default epsilon in epsilon.h
259-
numeric_epsilon = 10 * sys.float_info.epsilon
256+
from symforce import numeric_epsilon
260257

261258

262259
def epsilon() -> T.Any:

test/symforce_set_epsilon_test.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
import sys
7+
8+
from symforce.test_util import TestCase
9+
from symforce import typing as T
10+
11+
12+
def clear_symforce() -> None:
13+
"""
14+
Removes symforce modules from sys.modules. Ensures that all symforce
15+
code will be (re-)executed upon import.
16+
"""
17+
module_names = list(sys.modules.keys())
18+
for module in module_names:
19+
if module.startswith("symforce"):
20+
del sys.modules[module]
21+
22+
23+
class SymforceSetEpsilonTest(TestCase):
24+
"""
25+
Test symforce.set_epsilon functions.
26+
"""
27+
28+
saved_modules: T.List[T.Tuple[str, T.Any]] = []
29+
30+
@classmethod
31+
def setUpClass(cls) -> None:
32+
# NOTE(brad): This is necessary because while the test classes are run sequentially, it
33+
# seems they are all created before any are run. That means they can save references to
34+
# modules which might get deleted and reloaded by clear_symforce.
35+
#
36+
# Since some code relies on the original module still being available in sys.modules, we
37+
# need to save and restore the original modules to keep from breaking other tests.
38+
cls.saved_modules = []
39+
for module in sys.modules:
40+
if module.startswith("symforce"):
41+
cls.saved_modules.append((module, sys.modules[module]))
42+
43+
@classmethod
44+
def tearDownClass(cls) -> None:
45+
clear_symforce()
46+
47+
for key, module in cls.saved_modules:
48+
sys.modules[key] = module
49+
50+
def test_set_epsilon_to_zero(self) -> None:
51+
"""
52+
Assumes symforce.set_epsilon_to_number works.
53+
54+
Tests:
55+
symforce.set_epsilon_to_zero
56+
"""
57+
clear_symforce()
58+
with self.subTest(msg="Test set_epsilon_to_zero()"):
59+
import symforce
60+
61+
symforce.set_epsilon_to_number(4.4)
62+
symforce.set_epsilon_to_zero()
63+
import symforce.symbolic as sf
64+
65+
self.assertEqual(0.0, sf.epsilon())
66+
67+
clear_symforce()
68+
with self.subTest(msg="Test function properly raises AlreadyUsedEpsilon exception"):
69+
import symforce
70+
71+
with self.assertRaises(symforce.AlreadyUsedEpsilon):
72+
import symforce.symbolic as sf
73+
74+
sf.epsilon()
75+
symforce.set_epsilon_to_zero()
76+
77+
def test_set_epsilon_to_symbol(self) -> None:
78+
"""
79+
Assumes symforce.set_epsilon_to_number works.
80+
81+
Tests:
82+
symforce.set_epsilon_to_symbol
83+
"""
84+
clear_symforce()
85+
with self.subTest(msg="Test set_epsilon_to_symbol()"):
86+
import symforce
87+
88+
symforce.set_epsilon_to_number(4.4)
89+
symforce.set_epsilon_to_symbol()
90+
import symforce.symbolic as sf
91+
92+
self.assertEqual(sf.Symbol("epsilon"), sf.epsilon())
93+
94+
clear_symforce()
95+
with self.subTest(msg="Test set_epsilon_to_symbol(name)"):
96+
import symforce
97+
98+
symforce.set_epsilon_to_number(4.4)
99+
symforce.set_epsilon_to_symbol(name="alpha")
100+
import symforce.symbolic as sf
101+
102+
self.assertEqual(sf.Symbol("alpha"), sf.epsilon())
103+
104+
clear_symforce()
105+
with self.subTest(msg="Test function properly raises AlreadyUsedEpsilon exception"):
106+
import symforce
107+
108+
with self.assertRaises(symforce.AlreadyUsedEpsilon):
109+
import symforce.symbolic as sf
110+
111+
sf.epsilon()
112+
symforce.set_epsilon_to_number()
113+
114+
def test_set_epsilon_to_number(self) -> None:
115+
"""
116+
Tests:
117+
symforce.set_epsilon_to_number
118+
"""
119+
clear_symforce()
120+
with self.subTest(msg="Test set_epsilon_to_number()"):
121+
import symforce
122+
123+
symforce.set_epsilon_to_number()
124+
import symforce.symbolic as sf
125+
126+
self.assertEqual(sf.numeric_epsilon, sf.epsilon())
127+
128+
clear_symforce()
129+
with self.subTest(msg="Test set_epsilon_to_number(value)"):
130+
import symforce
131+
132+
symforce.set_epsilon_to_number(value=4.4)
133+
import symforce.symbolic as sf
134+
135+
self.assertEqual(4.4, sf.epsilon())
136+
137+
clear_symforce()
138+
with self.subTest(msg="Test function properly raises AlreadyUsedEpsilon exception"):
139+
import symforce
140+
141+
with self.assertRaises(symforce.AlreadyUsedEpsilon):
142+
import symforce.symbolic as sf
143+
144+
sf.epsilon()
145+
symforce.set_epsilon_to_symbol()
146+
147+
148+
if __name__ == "__main__":
149+
TestCase.main()

0 commit comments

Comments
 (0)