Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sympy.core.mul import Mul
from sympy.core.numbers import Float, Integer
from sympy.core.symbol import Symbol
from sympy.core.sympify import sympify
from sympy.core.sympify import sympify, SympifyError

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -1342,7 +1342,13 @@ def _parse_expression_str(self, expression_str: str) -> None:

Currently only linear functions are supported.
"""
expression = sympify(sanitize_name(expression_str))
try:
expression = sympify(sanitize_name(expression_str))
except SympifyError as e:
raise UserInputError(
f"Unable to parse derived parameter expression: "
f"{expression_str}. Error: {e}"
) from e
if isinstance(expression, (Float, Integer)):
raise UserInputError(
"Derived parameters must have at least one parameter in "
Expand Down
8 changes: 8 additions & 0 deletions ax/core/tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ def test_SpecialCharMetricNames(self) -> None:
parsed = parse_objective_expression(names[0])
self.assertNotEqual(str(parsed), names[0])

def test_parse_objective_expression_sympify_error(self) -> None:
"""Test that unparseable expressions raise UserInputError."""
with self.assertRaisesRegex(
UserInputError,
"Unable to parse objective expression",
):
parse_objective_expression("m1 +* m2")

def test_UniqueId(self) -> None:
"""Test _unique_id used for sorting."""
obj = Objective(expression="m1", metric_name_to_signature={"m1": "m1"})
Expand Down
11 changes: 11 additions & 0 deletions ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,17 @@ def test_invalid_inputs(self) -> None:
name="x", parameter_type=ParameterType.FLOAT, expression_str="y ** 2"
)

# test unparseable expression
with self.assertRaisesRegex(
UserInputError,
"Unable to parse derived parameter expression",
):
DerivedParameter(
name="x",
parameter_type=ParameterType.FLOAT,
expression_str="a +* b",
)

def test_eq(self) -> None:
param2 = DerivedParameter(
name="x", parameter_type=ParameterType.FLOAT, expression_str="2.0 * a + 1.0"
Expand Down
7 changes: 6 additions & 1 deletion ax/utils/common/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]:
raise UserInputError("Objective expression string must not be empty.")

sanitized = sanitize_name(expression_str, sanitize_parens=True)
parsed = sympify(sanitized)
try:
parsed = sympify(sanitized)
except SympifyError as e:
raise UserInputError(
f"Unable to parse objective expression: {expression_str}. Error: {e}"
) from e

if isinstance(parsed, tuple):
if any(not isinstance(p, Expr) for p in parsed):
Expand Down
Loading