From 56a239a625f5d481a04fcc7033ea604cb26f1512 Mon Sep 17 00:00:00 2001 From: Bernie Beckerman Date: Wed, 29 Apr 2026 12:10:38 -0700 Subject: [PATCH] Validate all sympy-parsed inputs early with clear error messages Summary: 1. **`ax/core/parameter.py` (`DerivedParameter._parse_expression_str`)**: Wrapped bare `sympify()` call in try/except to convert `SympifyError` to `UserInputError` with a descriptive message. 2. **`ax/utils/common/sympy.py` (`parse_objective_expression`)**: Wrapped bare `sympify()` call in try/except to convert `SympifyError` to `UserInputError`. 3. **`ax_core_instantiation_utils.py` (`_make_objectives`)**: Migrated from `MultiObjective(objectives=[...])` to the new expression-based `Objective(expression=..., metric_name_to_signature=...)` API. This is a behavioral change: objectives are now constructed using a single `Objective` with a comma-separated expression string instead of wrapping individual `Objective` instances in a `MultiObjective`. The corresponding test in `base_utils_test.py` is updated to match. Differential Revision: D100058027 --- ax/core/parameter.py | 10 ++++++++-- ax/core/tests/test_objective.py | 8 ++++++++ ax/core/tests/test_parameter.py | 11 +++++++++++ ax/utils/common/sympy.py | 7 ++++++- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1e73f1720f0..0bc5890405d 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -31,7 +31,7 @@ from sympy.core.mul import Mul from sympy.core.numbers import Float, Integer from sympy.core.symbol import Symbol -from sympy.core.sympify import sympify +from sympy.core.sympify import sympify, SympifyError logger: Logger = get_logger(__name__) @@ -1342,7 +1342,13 @@ def _parse_expression_str(self, expression_str: str) -> None: Currently only linear functions are supported. """ - expression = sympify(sanitize_name(expression_str)) + try: + expression = sympify(sanitize_name(expression_str)) + except SympifyError as e: + raise UserInputError( + f"Unable to parse derived parameter expression: " + f"{expression_str}. Error: {e}" + ) from e if isinstance(expression, (Float, Integer)): raise UserInputError( "Derived parameters must have at least one parameter in " diff --git a/ax/core/tests/test_objective.py b/ax/core/tests/test_objective.py index 915d3a025fa..fff514ad93d 100644 --- a/ax/core/tests/test_objective.py +++ b/ax/core/tests/test_objective.py @@ -412,6 +412,14 @@ def test_SpecialCharMetricNames(self) -> None: parsed = parse_objective_expression(names[0]) self.assertNotEqual(str(parsed), names[0]) + def test_parse_objective_expression_sympify_error(self) -> None: + """Test that unparseable expressions raise UserInputError.""" + with self.assertRaisesRegex( + UserInputError, + "Unable to parse objective expression", + ): + parse_objective_expression("m1 +* m2") + def test_UniqueId(self) -> None: """Test _unique_id used for sorting.""" obj = Objective(expression="m1", metric_name_to_signature={"m1": "m1"}) diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index dc40742aa8f..216421f2d5c 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -1128,6 +1128,17 @@ def test_invalid_inputs(self) -> None: name="x", parameter_type=ParameterType.FLOAT, expression_str="y ** 2" ) + # test unparseable expression + with self.assertRaisesRegex( + UserInputError, + "Unable to parse derived parameter expression", + ): + DerivedParameter( + name="x", + parameter_type=ParameterType.FLOAT, + expression_str="a +* b", + ) + def test_eq(self) -> None: param2 = DerivedParameter( name="x", parameter_type=ParameterType.FLOAT, expression_str="2.0 * a + 1.0" diff --git a/ax/utils/common/sympy.py b/ax/utils/common/sympy.py index 4a870f6a321..bae79b5ee27 100644 --- a/ax/utils/common/sympy.py +++ b/ax/utils/common/sympy.py @@ -100,7 +100,12 @@ def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]: raise UserInputError("Objective expression string must not be empty.") sanitized = sanitize_name(expression_str, sanitize_parens=True) - parsed = sympify(sanitized) + try: + parsed = sympify(sanitized) + except SympifyError as e: + raise UserInputError( + f"Unable to parse objective expression: {expression_str}. Error: {e}" + ) from e if isinstance(parsed, tuple): if any(not isinstance(p, Expr) for p in parsed):