diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index fc5d3c9aea04..ea70c4de3d0f 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -266,7 +266,12 @@ def can_prove( """ return self._can_prove(expr, strength) - def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None: + def bind( + self, + var: tirx.Var, + expr: tirx.PrimExpr | ir.Range, + allow_override: bool = False, + ) -> None: """Bind a variable to the expression. Parameters @@ -276,8 +281,11 @@ def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None: expr : Union[tirx.PrimExpr, ir.Range] The expression or the range to bind to. + + allow_override : bool + Whether to allow overriding an existing binding for the variable. """ - return self._bind(var, expr) + return self._bind(var, expr, allow_override) def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: """Create a constraint scope. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f823e9efca95..7f3734266bea 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -326,10 +326,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } else if (name == "bind") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + bool allow_override = args.size() >= 3 && args[2].cast(); if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value()); + self->Bind(args[0].cast(), opt_range.value(), allow_override); } else { - self->Bind(args[0].cast(), args[1].cast()); + self->Bind(args[0].cast(), args[1].cast(), allow_override); } }); } else if (name == "can_prove") { diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index b367735c1f36..d30109fc447c 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -134,6 +134,20 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_bind_allow_override(): + ana = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + + ana.bind(x, tvm.ir.Range(0, 10)) + ana.bind(x, tvm.ir.Range(0, 5), allow_override=True) + assert ana.can_prove(x < 5) + + with pytest.raises( + tvm.error.TVMError, match="Trying to update var 'x' with a different const bound" + ): + ana.bind(x, tvm.ir.Range(0, 3)) + + def test_simplify_floor_mod_with_linear_offset(): """ Test that the floor_mod is simplified correctly when the offset is linear.