diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1e73f1720f0..0bc5890405d 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -31,7 +31,7 @@ from sympy.core.mul import Mul from sympy.core.numbers import Float, Integer from sympy.core.symbol import Symbol -from sympy.core.sympify import sympify +from sympy.core.sympify import sympify, SympifyError logger: Logger = get_logger(__name__) @@ -1342,7 +1342,13 @@ def _parse_expression_str(self, expression_str: str) -> None: Currently only linear functions are supported. """ - expression = sympify(sanitize_name(expression_str)) + try: + expression = sympify(sanitize_name(expression_str)) + except SympifyError as e: + raise UserInputError( + f"Unable to parse derived parameter expression: " + f"{expression_str}. Error: {e}" + ) from e if isinstance(expression, (Float, Integer)): raise UserInputError( "Derived parameters must have at least one parameter in " diff --git a/ax/core/tests/test_objective.py b/ax/core/tests/test_objective.py index 915d3a025fa..fff514ad93d 100644 --- a/ax/core/tests/test_objective.py +++ b/ax/core/tests/test_objective.py @@ -412,6 +412,14 @@ def test_SpecialCharMetricNames(self) -> None: parsed = parse_objective_expression(names[0]) self.assertNotEqual(str(parsed), names[0]) + def test_parse_objective_expression_sympify_error(self) -> None: + """Test that unparseable expressions raise UserInputError.""" + with self.assertRaisesRegex( + UserInputError, + "Unable to parse objective expression", + ): + parse_objective_expression("m1 +* m2") + def test_UniqueId(self) -> None: """Test _unique_id used for sorting.""" obj = Objective(expression="m1", metric_name_to_signature={"m1": "m1"}) diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index dc40742aa8f..216421f2d5c 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -1128,6 +1128,17 @@ def test_invalid_inputs(self) -> None: name="x", parameter_type=ParameterType.FLOAT, expression_str="y ** 2" ) + # test unparseable expression + with self.assertRaisesRegex( + UserInputError, + "Unable to parse derived parameter expression", + ): + DerivedParameter( + name="x", + parameter_type=ParameterType.FLOAT, + expression_str="a +* b", + ) + def test_eq(self) -> None: param2 = DerivedParameter( name="x", parameter_type=ParameterType.FLOAT, expression_str="2.0 * a + 1.0" diff --git a/ax/utils/common/sympy.py b/ax/utils/common/sympy.py index 4a870f6a321..bae79b5ee27 100644 --- a/ax/utils/common/sympy.py +++ b/ax/utils/common/sympy.py @@ -100,7 +100,12 @@ def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]: raise UserInputError("Objective expression string must not be empty.") sanitized = sanitize_name(expression_str, sanitize_parens=True) - parsed = sympify(sanitized) + try: + parsed = sympify(sanitized) + except SympifyError as e: + raise UserInputError( + f"Unable to parse objective expression: {expression_str}. Error: {e}" + ) from e if isinstance(parsed, tuple): if any(not isinstance(p, Expr) for p in parsed):