Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 50 additions & 41 deletions pytensor/compile/rebuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,56 +174,65 @@ def rebuild_collect_shared(
shared_inputs = []

def clone_v_get_shared_updates(v, copy_inputs_over):
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
r"""Clones a variable and its inputs until all are in `clone_d`.

Also, it appends all `SharedVariable`\s met along the way to
`shared_inputs` and their corresponding
`SharedVariable.default_update`\s (when applicable) to `update_d` and
`update_expr`.

"""
# this co-recurses with clone_a
assert v is not None
if v in clone_d:
return clone_d[v]
if v.owner:
owner = v.owner
if owner not in clone_d:
for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over)
clone_node_and_cache(
owner,
clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
)
return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
shared_inputs.append(v)
if v.default_update is not None:
# Check that v should not be excluded from the default
# updates list
if no_default_updates is False or (
isinstance(no_default_updates, list) and v not in no_default_updates
):
# Do not use default_update if a "real" update was
# provided
if v not in update_d:
v_update = v.type.filter_variable(
v.default_update, allow_convert=False
)
if not v.type.is_super(v_update.type):
raise TypeError(
"An update must have a type compatible with "
"the original shared variable"
# Iterative depth-first traversal; recursion exceeds Python's stack on deep graphs
stack = [v]
while stack:
var = stack.pop()
if var in clone_d:
continue
owner = var.owner
if owner is not None:
if owner not in clone_d:
pending = [i for i in owner.inputs if i not in clone_d]
if pending:
stack.append(var)
stack.extend(reversed(pending))
continue
clone_node_and_cache(
owner,
clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
)
clone_d.setdefault(var, var)
continue
if isinstance(var, SharedVariable):
if var not in shared_inputs:
shared_inputs.append(var)
if var.default_update is not None:
# Check that var should not be excluded from the default
# updates list
if no_default_updates is False or (
isinstance(no_default_updates, list)
and var not in no_default_updates
):
# Do not use default_update if a "real" update was
# provided
if var not in update_d:
var_update = var.type.filter_variable(
var.default_update, allow_convert=False
)
update_d[v] = v_update
update_expr.append((v, v_update))
if not copy_inputs_over:
return clone_d.setdefault(v, v.clone())
else:
return clone_d.setdefault(v, v)
if not var.type.is_super(var_update.type):
raise TypeError(
"An update must have a type compatible with "
"the original shared variable"
)
update_d[var] = var_update
update_expr.append((var, var_update))
if not copy_inputs_over:
clone_d.setdefault(var, var.clone())
else:
clone_d.setdefault(var, var)
return clone_d[v]

# initialize the clone_d mapping with the replace dictionary
if replace is None:
Expand Down
7 changes: 7 additions & 0 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,15 @@ def numba_funcify_FunctionGraph(
fgraph: AbstractFunctionGraph,
node=None,
fgraph_name="numba_funcified_fgraph",
ofg_memo=None,
**kwargs,
):
# Memoize compiled OpFromGraph inner functions by Op, so repeated uses of
# equal OpFromGraphs (anywhere in the same compilation, including nested
# inner graphs) reuse a single compiled function
if ofg_memo is None:
ofg_memo = {}
kwargs["ofg_memo"] = ofg_memo
# Collect cache keys of every Op/Constant in the FunctionGraph
# so we can create a global cache key for the whole FunctionGraph
fgraph_can_be_cached = [True]
Expand Down
18 changes: 16 additions & 2 deletions pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,17 @@ def string_deepcopy(x):

@register_funcify_and_cache_key(OpFromGraph)
def numba_funcify_OpFromGraph(
op, node=None, mode=NUMBA.excluding("symbolic_op_recognition"), **kwargs
op,
node=None,
mode=NUMBA.excluding("symbolic_op_recognition"),
ofg_memo=None,
**kwargs,
):
_ = kwargs.pop("storage_map", None)

if ofg_memo is not None and op in ofg_memo:
return ofg_memo[op]

# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
Expand All @@ -70,7 +77,11 @@ def numba_funcify_OpFromGraph(
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
fgraph, squeeze_output=True, fgraph_name="numba_ofg", **kwargs
fgraph,
squeeze_output=True,
fgraph_name="numba_ofg",
ofg_memo=ofg_memo,
**kwargs,
)

if fgraph_cache_key is None:
Expand All @@ -86,6 +97,9 @@ def numba_funcify_OpFromGraph(
).encode()
).hexdigest()

if ofg_memo is not None:
ofg_memo[op] = (fgraph_fn, ofg_cache_key)

return fgraph_fn, ofg_cache_key


Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_destroyed_untraced_out_idxs.add(untraced_start + j)

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan"
op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo")
)

outer_in_names_to_vars = {
Expand Down
50 changes: 27 additions & 23 deletions pytensor/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,12 +847,12 @@ def _show_inner_graph(op):
new_prefix_child = prefix + " "
print("Inner graphs:", file=_file)

printed_inner_graphs_nodes = set()
printed_inner_graph_ops = set()
for ig_var in inner_graph_vars:
if ig_var.owner in printed_inner_graphs_nodes:
if ig_var.owner.op in printed_inner_graph_ops:
continue
else:
printed_inner_graphs_nodes.add(ig_var.owner)
printed_inner_graph_ops.add(ig_var.owner.op)
# This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`)
Expand Down Expand Up @@ -889,27 +889,31 @@ def _show_inner_graph(op):

print("", file=_file)

_debugprint(
ig_var,
prefix=prefix,
depth=depth,
done=done,
print_type=print_type,
print_shape=print_shape,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
assumption_tags=assumption_tags,
parent_node=ig_var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
is_inner_graph_header=True,
# Header line: the Op, then a single "[id A, B, ...]" listing every
# node whose inner graph is this one (printed once below), then its
# destroy/view maps. Equal Ops have identical inner graphs, so
# membership is grouped by Op equality. It must be computed here,
# not before the loop: nodes nested inside other inner graphs are
# only discovered while printing the bodies above. Output is
# streamed, so a node of an equal Op discovered after this header
# has printed cannot be added to it retroactively.
op = ig_var.owner.op
id_strs = [
_assign_id(node, used_ids, done, id_type, node.outputs[0])
# A multi-output node appears once per output var; dedup nodes.
for node in dict.fromkeys(
v.owner for v in inner_graph_vars if v.owner.op == op
)
]
tokens = [
s[4:-1] for s in id_strs if s.startswith("[id ") and s.endswith("]")
]
ids_str = f" [id {', '.join(tokens)}]" if tokens else ""
destroy_map_str = (
f" d={op.destroy_map}" if print_destroy_map and op.destroy_map else ""
)
view_map_str = f" v={op.view_map}" if print_view_map and op.view_map else ""
print(f"{op}{ids_str}{destroy_map_str}{view_map_str}", file=_file)

if print_fgraph_inputs:
for inp in inner_inputs:
Expand Down
4 changes: 3 additions & 1 deletion pytensor/tensor/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,9 @@ def root_scalar(


class RootOp(ScipyVectorWrapperOp):
__props__ = ("method", "jac")
# These __props__ were wrong: they ignore the inner graph,
# making RootOps of different equations compare equal (and get merged)
# __props__ = ("method", "jac")

def __init__(
self,
Expand Down
16 changes: 16 additions & 0 deletions tests/compile/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import sys

import pytensor.tensor as pt
from pytensor.compile.rebuild import rebuild_collect_shared


def test_rebuild_collect_shared_deep_graph():
    """Check `rebuild_collect_shared` on a graph deeper than the recursion limit.

    A chain of additions with more nodes than `sys.getrecursionlimit()` is
    built on top of a single scalar input. Cloning must not recurse once per
    node, otherwise a graph this deep overflows the interpreter stack.
    """
    x = pt.dscalar("x")

    # 500 nodes beyond the recursion limit, so any per-node recursion fails
    chain_length = sys.getrecursionlimit() + 500
    out = x
    for i in range(chain_length):
        out = out + i

    rebuilt = rebuild_collect_shared([out], [x])
    input_variables, cloned_outputs, (clone_d, *_) = rebuilt

    assert input_variables == [x]
    assert cloned_outputs == [clone_d[out]]
4 changes: 3 additions & 1 deletion tests/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def make_node(self, input):


class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
# No __props__: an inner-graph Op's identity is its inner graph, which a
# props-based __eq__/__hash__ would ignore (collapsing distinct instances).
# Falling back to object identity is the correct default for this mock.

def __init__(self, inner_inputs, inner_outputs):
input_replacements = [
Expand Down
31 changes: 31 additions & 0 deletions tests/link/numba/test_compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,37 @@ def test_OpFromGraph():
compare_numba_and_py([x, y, z], [out], [xv, yv, zv])


def test_ofg_inner_compilation_reused(monkeypatch):
    """Check that equal `OpFromGraph`s share one inner-graph compilation.

    `compile_ops.numba_funcify_and_cache_key` is wrapped with a call counter.
    The compiled graph uses `inner` directly, nested inside `outer`, and
    through the separately constructed but equivalent `inner_equiv`, so only
    two inner compilations are expected: one for `inner`, one for `outer`.
    """
    from pytensor.link.numba.dispatch import compile_ops

    wrapped_funcify = compile_ops.numba_funcify_and_cache_key
    # One-element list so the closure can update the count in place
    inner_compile_count = [0]

    def counting_funcify(*args, **kwargs):
        inner_compile_count[0] += 1
        return wrapped_funcify(*args, **kwargs)

    monkeypatch.setattr(compile_ops, "numba_funcify_and_cache_key", counting_funcify)

    x = pt.vector("x")
    inner = OpFromGraph([x], [pt.exp(x) + 1])
    inner_equiv = OpFromGraph([x], [pt.exp(x) + 1])
    outer = OpFromGraph([x], [inner(x) * 2])

    y = pt.vector("y")
    out = outer(inner(y)) + inner_equiv(y)
    fn = function([y], out, mode="NUMBA")
    # Only two distinct OpFromGraphs: `inner` (used directly, nested inside
    # `outer`, and via the equivalent `inner_equiv`) and `outer`
    assert inner_compile_count[0] == 2

    def inner_np(values):
        # NumPy reference for the `inner` OpFromGraph
        return np.exp(values) + 1

    y_test = np.array([0.0, 1.0], dtype=config.floatX)
    inner_res = inner_np(y_test)
    expected = inner_np(inner_res) * 2 + inner_res
    np.testing.assert_allclose(fn(y_test), expected, rtol=1e-6)


@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
x = pt.vector("x")
Expand Down
15 changes: 15 additions & 0 deletions tests/tensor/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,21 @@ def root_fn(x, a):
utt.verify_grad(root_fn, [x0, a_val], eps=1e-6)


def test_root_distinct_equations_not_merged():
    """Regression test: roots of different equations must stay distinct.

    RootOp `__props__` ignored the inner graph, so RootOps of different
    equations compared equal and MergeOptimizer collapsed them. Solving
    ``x**2 - a`` and ``x**3 - a`` in one function must give both the square
    root and the cube root of ``a``.
    """
    x = pt.scalar("x")
    a = pt.scalar("a")

    # Same construction order as the equations listed above: square, then cube
    sq_root, cube_root = (root(x**power - a, x)[0] for power in (2, 3))

    func = pytensor.function([x, a], [sq_root, cube_root])

    x_init, a_val = 1.5, 8.0
    sq_res, cube_res = func(x_init, a_val)
    np.testing.assert_allclose(sq_res, np.sqrt(a_val), rtol=1e-6)
    np.testing.assert_allclose(cube_res, np.cbrt(a_val), rtol=1e-6)


def test_root_system_of_equations():
x = pt.tensor("x", shape=(None,))
a = pt.tensor("a", shape=(None,))
Expand Down
Loading
Loading