diff --git a/pytensor/compile/rebuild.py b/pytensor/compile/rebuild.py index e69824534d..29933a51e0 100644 --- a/pytensor/compile/rebuild.py +++ b/pytensor/compile/rebuild.py @@ -174,7 +174,7 @@ def rebuild_collect_shared( shared_inputs = [] def clone_v_get_shared_updates(v, copy_inputs_over): - r"""Clones a variable and its inputs recursively until all are in `clone_d`. + r"""Clones a variable and its inputs until all are in `clone_d`. Also, it appends all `SharedVariable`\s met along the way to `shared_inputs` and their corresponding @@ -182,48 +182,57 @@ def clone_v_get_shared_updates(v, copy_inputs_over): `update_expr`. """ - # this co-recurses with clone_a assert v is not None - if v in clone_d: - return clone_d[v] - if v.owner: - owner = v.owner - if owner not in clone_d: - for i in owner.inputs: - clone_v_get_shared_updates(i, copy_inputs_over) - clone_node_and_cache( - owner, - clone_d, - strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, - ) - return clone_d.setdefault(v, v) - elif isinstance(v, SharedVariable): - if v not in shared_inputs: - shared_inputs.append(v) - if v.default_update is not None: - # Check that v should not be excluded from the default - # updates list - if no_default_updates is False or ( - isinstance(no_default_updates, list) and v not in no_default_updates - ): - # Do not use default_update if a "real" update was - # provided - if v not in update_d: - v_update = v.type.filter_variable( - v.default_update, allow_convert=False - ) - if not v.type.is_super(v_update.type): - raise TypeError( - "An update must have a type compatible with " - "the original shared variable" + # Iterative depth-first traversal; recursion exceeds Python's stack on deep graphs + stack = [v] + while stack: + var = stack.pop() + if var in clone_d: + continue + owner = var.owner + if owner is not None: + if owner not in clone_d: + pending = [i for i in owner.inputs if i not in clone_d] + if pending: + stack.append(var) + 
stack.extend(reversed(pending)) + continue + clone_node_and_cache( + owner, + clone_d, + strict=rebuild_strict, + clone_inner_graphs=clone_inner_graphs, + ) + clone_d.setdefault(var, var) + continue + if isinstance(var, SharedVariable): + if var not in shared_inputs: + shared_inputs.append(var) + if var.default_update is not None: + # Check that var should not be excluded from the default + # updates list + if no_default_updates is False or ( + isinstance(no_default_updates, list) + and var not in no_default_updates + ): + # Do not use default_update if a "real" update was + # provided + if var not in update_d: + var_update = var.type.filter_variable( + var.default_update, allow_convert=False ) - update_d[v] = v_update - update_expr.append((v, v_update)) - if not copy_inputs_over: - return clone_d.setdefault(v, v.clone()) - else: - return clone_d.setdefault(v, v) + if not var.type.is_super(var_update.type): + raise TypeError( + "An update must have a type compatible with " + "the original shared variable" + ) + update_d[var] = var_update + update_expr.append((var, var_update)) + if not copy_inputs_over: + clone_d.setdefault(var, var.clone()) + else: + clone_d.setdefault(var, var) + return clone_d[v] # initialize the clone_d mapping with the replace dictionary if replace is None: diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index c0abd7bebe..4d2def0805 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -480,8 +480,15 @@ def numba_funcify_FunctionGraph( fgraph: AbstractFunctionGraph, node=None, fgraph_name="numba_funcified_fgraph", + ofg_memo=None, **kwargs, ): + # Memoize compiled OpFromGraph inner functions by Op, so repeated uses of + # equal OpFromGraphs (anywhere in the same compilation, including nested + # inner graphs) reuse a single compiled function + if ofg_memo is None: + ofg_memo = {} + kwargs["ofg_memo"] = ofg_memo # Collect cache keys of every Op/Constant in 
the FunctionGraph # so we can create a global cache key for the whole FunctionGraph fgraph_can_be_cached = [True] diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 3691e74b05..74f8dda91a 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -51,10 +51,17 @@ def string_deepcopy(x): @register_funcify_and_cache_key(OpFromGraph) def numba_funcify_OpFromGraph( - op, node=None, mode=NUMBA.excluding("symbolic_op_recognition"), **kwargs + op, + node=None, + mode=NUMBA.excluding("symbolic_op_recognition"), + ofg_memo=None, + **kwargs, ): _ = kwargs.pop("storage_map", None) + if ofg_memo is not None and op in ofg_memo: + return ofg_memo[op] + # Apply inner rewrites # TODO: Not sure this is the right place to do this, should we have a rewrite that # explicitly triggers the optimization of the inner graphs of OpFromGraph? @@ -70,7 +77,11 @@ def numba_funcify_OpFromGraph( output_specs = [Out(o, borrow=False) for o in fgraph.outputs] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( - fgraph, squeeze_output=True, fgraph_name="numba_ofg", **kwargs + fgraph, + squeeze_output=True, + fgraph_name="numba_ofg", + ofg_memo=ofg_memo, + **kwargs, ) if fgraph_cache_key is None: @@ -86,6 +97,9 @@ def numba_funcify_OpFromGraph( ).encode() ).hexdigest() + if ofg_memo is not None: + ofg_memo[op] = (fgraph_fn, ofg_cache_key) + return fgraph_fn, ofg_cache_key diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index e8d05ce7ff..cf177f6183 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -122,7 +122,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): inner_destroyed_untraced_out_idxs.add(untraced_start + j) scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( - op.fgraph, 
fgraph_name="numba_scan" + op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") ) outer_in_names_to_vars = { diff --git a/pytensor/printing.py b/pytensor/printing.py index e3c0345b04..cd9b9bab1f 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -847,12 +847,12 @@ def _show_inner_graph(op): new_prefix_child = prefix + " " print("Inner graphs:", file=_file) - printed_inner_graphs_nodes = set() + printed_inner_graph_ops = set() for ig_var in inner_graph_vars: - if ig_var.owner in printed_inner_graphs_nodes: + if ig_var.owner.op in printed_inner_graph_ops: continue else: - printed_inner_graphs_nodes.add(ig_var.owner) + printed_inner_graph_ops.add(ig_var.owner.op) # This is a work-around to maintain backward compatibility # (e.g. to only print inner graphs that have been compiled through # a call to `Op.prepare_node`) @@ -889,27 +889,31 @@ def _show_inner_graph(op): print("", file=_file) - _debugprint( - ig_var, - prefix=prefix, - depth=depth, - done=done, - print_type=print_type, - print_shape=print_shape, - file=_file, - id_type=id_type, - inner_graph_ops=inner_graph_vars, - stop_on_name=stop_on_name, - inner_to_outer_inputs=inner_to_outer_inputs, - used_ids=used_ids, - op_information=op_information, - assumption_tags=assumption_tags, - parent_node=ig_var.owner, - print_op_info=print_op_info, - print_destroy_map=print_destroy_map, - print_view_map=print_view_map, - is_inner_graph_header=True, + # Header line: the Op, then a single "[id A, B, ...]" listing every + # node whose inner graph is this one (printed once below), then its + # destroy/view maps. Equal Ops have identical inner graphs, so + # membership is grouped by Op equality. It must be computed here, + # not before the loop: nodes nested inside other inner graphs are + # only discovered while printing the bodies above. Output is + # streamed, so a node of an equal Op discovered after this header + # has printed cannot be added to it retroactively. 
+ op = ig_var.owner.op + id_strs = [ + _assign_id(node, used_ids, done, id_type, node.outputs[0]) + # A multi-output node appears once per output var; dedup nodes. + for node in dict.fromkeys( + v.owner for v in inner_graph_vars if v.owner.op == op + ) + ] + tokens = [ + s[4:-1] for s in id_strs if s.startswith("[id ") and s.endswith("]") + ] + ids_str = f" [id {', '.join(tokens)}]" if tokens else "" + destroy_map_str = ( + f" d={op.destroy_map}" if print_destroy_map and op.destroy_map else "" ) + view_map_str = f" v={op.view_map}" if print_view_map and op.view_map else "" + print(f"{op}{ids_str}{destroy_map_str}{view_map_str}", file=_file) if print_fgraph_inputs: for inp in inner_inputs: diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 19185534f0..130cbf9ba2 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -965,7 +965,9 @@ def root_scalar( class RootOp(ScipyVectorWrapperOp): - __props__ = ("method", "jac") + # These __props__ were wrong: they ignore the inner graph, + # making RootOps of different equations compare equal (and get merged) + # __props__ = ("method", "jac") def __init__( self, diff --git a/tests/compile/test_rebuild.py b/tests/compile/test_rebuild.py new file mode 100644 index 0000000000..65928449d9 --- /dev/null +++ b/tests/compile/test_rebuild.py @@ -0,0 +1,16 @@ +import sys + +import pytensor.tensor as pt +from pytensor.compile.rebuild import rebuild_collect_shared + + +def test_rebuild_collect_shared_deep_graph(): + # Cloning must not recurse, or graphs deeper than the interpreter stack fail + x = pt.dscalar("x") + out = x + for i in range(sys.getrecursionlimit() + 500): + out = out + i + + input_variables, cloned_outputs, (clone_d, *_) = rebuild_collect_shared([out], [x]) + assert input_variables == [x] + assert cloned_outputs == [clone_d[out]] diff --git a/tests/graph/utils.py b/tests/graph/utils.py index 2e14fc79a4..5e1bca569a 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py 
@@ -131,7 +131,9 @@ def make_node(self, input): class MyInnerGraphOp(Op, HasInnerGraph): - __props__ = () + # No __props__: an inner-graph Op's identity is its inner graph, which a + # props-based __eq__/__hash__ would ignore (collapsing distinct instances). + # Falling back to object identity is the correct default for this mock. def __init__(self, inner_inputs, inner_outputs): input_replacements = [ diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index 73affe6c8f..b8b8f1b56b 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -121,6 +121,37 @@ def test_OpFromGraph(): compare_numba_and_py([x, y, z], [out], [xv, yv, zv]) +def test_ofg_inner_compilation_reused(monkeypatch): + from pytensor.link.numba.dispatch import compile_ops + + n_inner_compiles = 0 + inner_funcify = compile_ops.numba_funcify_and_cache_key + + def counting_funcify(*args, **kwargs): + nonlocal n_inner_compiles + n_inner_compiles += 1 + return inner_funcify(*args, **kwargs) + + monkeypatch.setattr(compile_ops, "numba_funcify_and_cache_key", counting_funcify) + + x = pt.vector("x") + inner = OpFromGraph([x], [pt.exp(x) + 1]) + inner_equiv = OpFromGraph([x], [pt.exp(x) + 1]) + outer = OpFromGraph([x], [inner(x) * 2]) + + y = pt.vector("y") + out = outer(inner(y)) + inner_equiv(y) + fn = function([y], out, mode="NUMBA") + # Only two distinct OpFromGraphs: `inner` (used directly, nested inside + # `outer`, and via the equivalent `inner_equiv`) and `outer` + assert n_inner_compiles == 2 + + y_test = np.array([0.0, 1.0], dtype=config.floatX) + inner_res = np.exp(y_test) + 1 + expected = (np.exp(inner_res) + 1) * 2 + inner_res + np.testing.assert_allclose(fn(y_test), expected, rtol=1e-6) + + @pytest.mark.filterwarnings("error") def test_ofg_inner_inplace(): x = pt.vector("x") diff --git a/tests/tensor/test_optimize.py b/tests/tensor/test_optimize.py index 3f452591b9..1919277a18 100644 --- a/tests/tensor/test_optimize.py 
+++ b/tests/tensor/test_optimize.py @@ -356,6 +356,21 @@ def root_fn(x, a): utt.verify_grad(root_fn, [x0, a_val], eps=1e-6) +def test_root_distinct_equations_not_merged(): + # Regression test: RootOp __props__ ignored the inner graph, so RootOps of + # different equations compared equal and MergeOptimizer collapsed them + x = pt.scalar("x") + a = pt.scalar("a") + + sq_root, _ = root(x**2 - a, x) + cube_root, _ = root(x**3 - a, x) + + func = pytensor.function([x, a], [sq_root, cube_root]) + sq_res, cube_res = func(1.5, 8.0) + np.testing.assert_allclose(sq_res, np.sqrt(8.0), rtol=1e-6) + np.testing.assert_allclose(cube_res, np.cbrt(8.0), rtol=1e-6) + + def test_root_system_of_equations(): x = pt.tensor("x", shape=(None,)) a = pt.tensor("a", shape=(None,)) diff --git a/tests/test_printing.py b/tests/test_printing.py index b0ff25450d..973c7c5888 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -2,6 +2,7 @@ Tests of printing functionality """ +import importlib.util import io import logging import re @@ -31,8 +32,11 @@ pp, pydotprint, ) +from pytensor.scalar import Composite, float64 from pytensor.tensor import as_tensor_variable +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.linalg import inv +from pytensor.tensor.math import maximum from pytensor.tensor.type import dmatrix, dvector, matrix from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable @@ -498,6 +502,114 @@ def test_debugprint_inner_graph(): assert exp_line.strip() == res_line.strip() +def test_debugprint_inner_graph_shared(): + """Inner-graph `Op`s that compare equal share a single printed body, whose + header lists every node id it applies to (deduped, instead of one body per + occurrence).""" + x = dvector("x") + + def relu_ofg(): + i = dvector("i") + return OpFromGraph([i], [maximum(i, 0)], inline=False, name="Relu") + + # `a` and `b` use distinct-but-equal OFG instances (identical inner graph); + # `out` uses a structurally different OFG. 
+ a = relu_ofg()(x) + b = relu_ofg()(a) + out = OpFromGraph([x], [x + 1], inline=False, name="AddOne")(b) + + lines = debugprint(out, file="str").split("\n") + + exp_res = """AddOne{inline=False} [id A] + └─ Relu{inline=False} [id B] + └─ Relu{inline=False} [id C] + └─ x [id D] + +Inner graphs: + +AddOne{inline=False} [id A] + ← Add [id E] + ├─ i0 [id F] + └─ ExpandDims{axis=0} [id G] + └─ 1 [id H] + +Relu{inline=False} [id B, C] + ← Maximum [id I] + ├─ i0 [id F] + └─ ExpandDims{axis=0} [id J] + └─ 0 [id K] + """ + + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): + assert exp_line.strip() == res_line.strip() + + def add_one_composite(): + xs = float64("xs") + return Composite([xs], [xs + 1.0]) + + d = Elemwise(add_one_composite())(x) + e = Elemwise(add_one_composite())(d) + + lines = debugprint(e, file="str").split("\n") + + exp_res = """Composite{(i0 + 1.0)} [id A] + └─ Composite{(i0 + 1.0)} [id B] + └─ x [id C] + +Inner graphs: + +Composite{(i0 + 1.0)} [id A, B] + ← add [id D] + ├─ i0 [id E] + └─ 1.0 [id F] + """ + + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): + assert exp_line.strip() == res_line.strip() + + # An Op that only appears nested inside other inner graphs: its nodes are + # only discovered while the parent bodies are printed, and the shared + # header must still list every node id + i1 = dvector("i") + a_op = OpFromGraph([i1], [relu_ofg()(i1) + 1], inline=False, name="A") + i2 = dvector("i") + b_op = OpFromGraph([i2], [relu_ofg()(i2) * 2], inline=False, name="B") + + lines = debugprint(a_op(x) + b_op(x), file="str").split("\n") + + exp_res = """Add [id A] + ├─ A{inline=False} [id B] + │ └─ x [id C] + └─ B{inline=False} [id D] + └─ x [id C] + +Inner graphs: + +A{inline=False} [id B] + ← Add [id E] + ├─ Relu{inline=False} [id F] + │ └─ i0 [id G] + └─ ExpandDims{axis=0} [id H] + └─ 1 [id I] + +B{inline=False} [id D] + ← Mul [id J] + ├─ Relu{inline=False} [id K] + │ └─ i0 [id G] + └─ ExpandDims{axis=0} 
[id L] + └─ 2 [id M] + +Relu{inline=False} [id F, K] + ← Maximum [id N] + ├─ i0 [id G] + └─ ExpandDims{axis=0} [id O] + └─ 0 [id P] + """ + + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): + assert exp_line.strip() == res_line.strip() + + def test_get_var_by_id(): r1, r2 = MyVariable("v1"), MyVariable("v2") o1 = MyOp("op1")(r1, r2) @@ -573,6 +685,9 @@ def test_summary_with_profile_optimizer(): assert "Rewriter Profile" in s.getvalue() +@pytest.mark.skipif( + importlib.util.find_spec("rich") is None, reason="rich is not installed" +) class TestDebugprintRich: """Tests for debugprint(..., file="rich"). @@ -581,12 +696,12 @@ class TestDebugprintRich: construct the right tree structure and don't crash on various graph shapes. """ - rich = pytest.importorskip("rich") - def test_return_type(self): + from rich.tree import Tree + x = dvector("x") tree = debugprint(x.sum(), file="rich") - assert isinstance(tree, self.rich.tree.Tree) + assert isinstance(tree, Tree) def test_single_output_has_one_child(self): # One output variable → the hidden root should have exactly one child. @@ -794,8 +909,10 @@ def test_markup_escaping(self): mul_node = tree.children[0].children[0] assert "result" in str(mul_node.label) # Verify Rich can render the tree without raising a markup error. + from rich.console import Console + buf = io.StringIO() - console = self.rich.console.Console(file=buf, highlight=False) + console = Console(file=buf, highlight=False) console.print(tree) # raises MarkupError if escaping is broken def test_deep_shared_node_sentinel_depth(self):