Skip to content

Commit c458236

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Validate all sympy-parsed inputs early with clear error messages (#5196)
Summary: 1. **`ax/core/parameter.py` (`DerivedParameter._parse_expression_str`)**: Wrapped bare `sympify()` call in try/except to convert `SympifyError` to `UserInputError` with a descriptive message. 2. **`ax/utils/common/sympy.py` (`parse_objective_expression`)**: Wrapped bare `sympify()` call in try/except to convert `SympifyError` to `UserInputError`. 3. **`ax_core_instantiation_utils.py` (`_make_objectives`)**: Migrated from `MultiObjective(objectives=[...])` to the new expression-based `Objective(expression=..., metric_name_to_signature=...)` API. This is a behavioral change: objectives are now constructed using a single `Objective` with a comma-separated expression string instead of wrapping individual `Objective` instances in a `MultiObjective`. The corresponding test in `base_utils_test.py` is updated to match. Differential Revision: D100058027
1 parent dd8b120 commit c458236

4 files changed

Lines changed: 33 additions & 3 deletions

File tree

ax/core/parameter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from sympy.core.mul import Mul
3232
from sympy.core.numbers import Float, Integer
3333
from sympy.core.symbol import Symbol
34-
from sympy.core.sympify import sympify
34+
from sympy.core.sympify import sympify, SympifyError
3535

3636
logger: Logger = get_logger(__name__)
3737

@@ -1342,7 +1342,13 @@ def _parse_expression_str(self, expression_str: str) -> None:
13421342
13431343
Currently only linear functions are supported.
13441344
"""
1345-
expression = sympify(sanitize_name(expression_str))
1345+
try:
1346+
expression = sympify(sanitize_name(expression_str))
1347+
except SympifyError as e:
1348+
raise UserInputError(
1349+
f"Unable to parse derived parameter expression: "
1350+
f"{expression_str}. Error: {e}"
1351+
) from e
13461352
if isinstance(expression, (Float, Integer)):
13471353
raise UserInputError(
13481354
"Derived parameters must have at least one parameter in "

ax/core/tests/test_objective.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,14 @@ def test_SpecialCharMetricNames(self) -> None:
412412
parsed = parse_objective_expression(names[0])
413413
self.assertNotEqual(str(parsed), names[0])
414414

415+
def test_parse_objective_expression_sympify_error(self) -> None:
416+
"""Test that unparseable expressions raise UserInputError."""
417+
with self.assertRaisesRegex(
418+
UserInputError,
419+
"Unable to parse objective expression",
420+
):
421+
parse_objective_expression("m1 +* m2")
422+
415423
def test_UniqueId(self) -> None:
416424
"""Test _unique_id used for sorting."""
417425
obj = Objective(expression="m1", metric_name_to_signature={"m1": "m1"})

ax/core/tests/test_parameter.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,17 @@ def test_invalid_inputs(self) -> None:
11281128
name="x", parameter_type=ParameterType.FLOAT, expression_str="y ** 2"
11291129
)
11301130

1131+
# test unparseable expression
1132+
with self.assertRaisesRegex(
1133+
UserInputError,
1134+
"Unable to parse derived parameter expression",
1135+
):
1136+
DerivedParameter(
1137+
name="x",
1138+
parameter_type=ParameterType.FLOAT,
1139+
expression_str="a +* b",
1140+
)
1141+
11311142
def test_eq(self) -> None:
11321143
param2 = DerivedParameter(
11331144
name="x", parameter_type=ParameterType.FLOAT, expression_str="2.0 * a + 1.0"

ax/utils/common/sympy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,12 @@ def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]:
100100
raise UserInputError("Objective expression string must not be empty.")
101101

102102
sanitized = sanitize_name(expression_str, sanitize_parens=True)
103-
parsed = sympify(sanitized)
103+
try:
104+
parsed = sympify(sanitized)
105+
except SympifyError as e:
106+
raise UserInputError(
107+
f"Unable to parse objective expression: {expression_str}. Error: {e}"
108+
) from e
104109

105110
if isinstance(parsed, tuple):
106111
if any(not isinstance(p, Expr) for p in parsed):

0 commit comments

Comments
 (0)