From 21d7b147da04693b0095f7f9443f65fd36ff31a9 Mon Sep 17 00:00:00 2001 From: Fabian Peddinghaus Date: Fri, 17 Apr 2026 21:23:36 +0200 Subject: [PATCH] [ARITH] Expose allow_override parameter in Python Analyzer.bind() The C++ Analyzer::Bind() already supports allow_override, but the FFI bridge always used the default (false). This change threads the optional argument through the FFI layer and the Python wrapper so callers can rebind variables without triggering an error. --- python/tvm/arith/analyzer.py | 12 ++++++++++-- src/arith/analyzer.cc | 5 +++-- tests/python/arith/test_arith_simplify.py | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index fc5d3c9aea04..ea70c4de3d0f 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -266,7 +266,12 @@ def can_prove( """ return self._can_prove(expr, strength) - def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None: + def bind( + self, + var: tirx.Var, + expr: tirx.PrimExpr | ir.Range, + allow_override: bool = False, + ) -> None: """Bind a variable to the expression. Parameters @@ -276,8 +281,11 @@ def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None: expr : Union[tirx.PrimExpr, ir.Range] The expression or the range to bind to. + + allow_override : bool + Whether to allow overriding an existing binding for the variable. """ - return self._bind(var, expr) + return self._bind(var, expr, allow_override) def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: """Create a constraint scope. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f823e9efca95..7f3734266bea 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -326,10 +326,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } else if (name == "bind") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + bool allow_override = args.size() >= 3 && args[2].cast(); if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value()); + self->Bind(args[0].cast(), opt_range.value(), allow_override); } else { - self->Bind(args[0].cast(), args[1].cast()); + self->Bind(args[0].cast(), args[1].cast(), allow_override); } }); } else if (name == "can_prove") { diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index b367735c1f36..d30109fc447c 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -134,6 +134,20 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_bind_allow_override(): + ana = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + + ana.bind(x, tvm.ir.Range(0, 10)) + ana.bind(x, tvm.ir.Range(0, 5), allow_override=True) + assert ana.can_prove(x < 5) + + with pytest.raises( + tvm.error.TVMError, match="Trying to update var 'x' with a different const bound" + ): + ana.bind(x, tvm.ir.Range(0, 3)) + + def test_simplify_floor_mod_with_linear_offset(): """ Test that the floor_mod is simplified correctly when the offset is linear.