Allow freezing of FunctionGraph for hashing#1908

Open
jessegrabowski wants to merge 10 commits into pymc-devs:v3 from jessegrabowski:hashable-inner-graphs

Conversation

@jessegrabowski
Member

Closes #1606

LLM disclosure: this PR made heavy use of Claude in the planning and first cut stages, though I was heavily involved. Still, the code should be subject to extra scrutiny as a result.

The purpose of the PR is to refactor Ops with inner graphs to allow comparison. The linked issue has an exhaustive discussion of the factors at play. There was an attempt in the aesara days to attack this, but it was perhaps too aggressive: it cons-hashed all Apply nodes, which necessitated changes across the codebase. @ricardoV94 suggested a weakref dict approach for subgraphs. This is implemented at the Op level. The plan is for Ops that have inner graphs (Composite, ScalarLoop, Scan, OpFromGraph, etc.) to have a _cache class attribute and implement the op-specific logic for caching, pickling, unpickling, etc. It didn't look super generalizable to me at first blush, but we can argue about it maybe.
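As a rough illustration of the weakref-dict caching idea (a sketch only — MyInnerGraphOp and its key are hypothetical stand-ins, not the PR's actual API):

```python
import weakref


class MyInnerGraphOp:
    # One cache per Op class; entries disappear automatically once
    # no one else holds a reference to the cached instance.
    _cache = weakref.WeakValueDictionary()

    def __new__(cls, key):
        cached = cls._cache.get(key)
        if cached is not None:
            return cached
        instance = super().__new__(cls)
        instance.key = key
        cls._cache[key] = instance
        return instance


a = MyInnerGraphOp(("sin", "mul"))
b = MyInnerGraphOp(("sin", "mul"))
assert a is b  # same structural key -> the same cached Op instance
```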

Changes to FunctionGraph:

  • FunctionGraph now has a method freeze that returns a FrozenFunctionGraph.
  • The FrozenFunctionGraph does cons-hashing of Apply nodes within its scope only
  • It generates a hash based on its inner graph
  • Two FrozenFunctionGraphs with the same inner graph will evaluate as equal, but their Apply nodes won't be references to the same objects (this is the "conservatism" of my approach)

Specific implementation details:

  • The structural_hash of a FrozenFunctionGraph is built from a list of 3-tuples: (name, type, inputs), plus the outputs. For constants, inputs is replaced with the hash of the input data.
  • Equality between FrozenFunctionGraphs is done by comparing hashes, then falling back to equal_computations if the hash comparison is inconclusive.
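A toy sketch of how such a structural hash might be assembled from (name, type, inputs) tuples (illustrative only; structural_hash here is a stand-in, not the PR's implementation):

```python
def structural_hash(nodes, outputs):
    # nodes: list of (name, type_repr, input_positions) triples in
    # topological order; input_positions index earlier entries, so the
    # hash depends only on structure, not on variable identities/names.
    spec = tuple((name, type_repr, tuple(ins)) for name, type_repr, ins in nodes)
    return hash((spec, tuple(outputs)))


# Two structurally identical graphs hash the same, regardless of the
# Python objects used to build them.
g1 = [("input", "float64", ()), ("sin", "float64", (0,))]
g2 = [("input", "float64", ()), ("sin", "float64", (0,))]
assert structural_hash(g1, [1]) == structural_hash(g2, [1])
```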

A consequence of the cons-hashing in this approach is that the inner graph is de-duplicated when we call fg.freeze(). So a MergeOptimizer pass is no longer required. Usage is demonstrated on the Composite Op. If we like the approach I can move forward with refactoring other Ops, but I wanted to stop here and discuss the approach.

Code example:

import pytensor.tensor as pt
import pytensor

a, b, c, d = pt.dscalars('a', 'b', 'c', 'd')
eq1 = pt.sin(a) * b ** 2
eq2 = pt.sin(c) * d ** 2

with pytensor.config.change_flags(optimizer_verbose=True):
    f = pytensor.function([a, b, c, d], [eq1, eq2])

f.dprint()

Result:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A] 1
 ├─ a [id B]
 └─ b [id C]
Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D] 0
 ├─ c [id E]
 └─ d [id F]

Inner graphs:

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id A]
 ← mul [id G]
    ├─ sin [id H]
    │  └─ *0-<float64> [id I]
    └─ sqr [id J]
       └─ *1-<float64> [id K]

Composite{(sin(*0-<float64>) * sqr(*1-<float64>))} [id D]
 ← mul [id G]
    └─ ···

Member

@ricardoV94 ricardoV94 left a comment


Why did you not go all out?

If you already deduplicate and do internal hash-cons you are one step away from getting hashing for free across different FunctionGraphs. Just do the hash-cons globally. Then FrozenFunctionGraph([x, y], [foo(x, y)]) is equal to another FunctionGraph if and only if fgraph.outputs == other_fgraph.outputs. No need for recursive hashing or expensive equal_computations.

As it stands you are not doing much better than sneaking a default MergeOptimizer into __init__ and adding a FunctionGraph class that has no replace mode.

And cheap hashing/equality is not just a nice-to-have; it's really valuable for not slowing down compilation. In some of my benchmarks on previous work, some graphs could spend inordinate time on equality checks.

Comments regardless of which way we go:

  • Don't create FrozenFunctionGraph as a subclass of FunctionGraph; let's push the general principle of shared abstract classes, with no subclassing of concrete classes. Then you don't need check_frozen; the methods simply don't exist for the frozen class.
  • You could create a FrozenApply that uses tuples for inputs/outputs instead of lists. That will help ensure immutability, because all our current rewrite machinery works by overriding entries in those lists. Accidentally trying to mutate a graph would almost certainly fail there.
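The tuple suggestion can be sketched in a few lines (FrozenNode is a hypothetical stand-in for the proposed FrozenApply, not the PR's code):

```python
class FrozenNode:
    """Immutable stand-in for an Apply node: inputs/outputs are tuples."""

    __slots__ = ("op", "inputs", "outputs")

    def __init__(self, op, inputs, outputs):
        self.op = op
        self.inputs = tuple(inputs)    # tuples instead of lists
        self.outputs = tuple(outputs)


node = FrozenNode("add", ["x", "y"], ["z"])
try:
    node.inputs[0] = "w"  # the usual in-place rewrite pattern
    raise AssertionError("should not be reachable")
except TypeError:
    pass  # tuples reject item assignment, so accidental mutation fails loudly
```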

Member

@ricardoV94 ricardoV94 left a comment


This is starting to look good, how are you feeling about it?

Notes:

  • Add a FrozenFunctionGraph.unfreeze(), that yields a FunctionGraph?
  • Really try to avoid the FrozenConstant stuff
  • Ops with inner graph (at least the ones you touched now) should only have a FrozenFunctionGraph internally (not a mutable one as well). Maybe that's already the case.

We need some follow-up issues open:

  • Optimizing OpFromGraph: There should be an explicit rewrite that creates a new OpFromGraph with its updated frozen graph (so it is also reflected immediately in dprint). We should never do any further rewrites of the internal fgraph during compilation.
  • Scan/Minimize/Root: Use the new FrozenFunctionGraph as well. This should immediately address #1601
  • When compiling OpFromGraph in jitted contexts we should try to avoid recreating inner numba/jax functions when the same OFG is compiled multiple times in a function; this will likely speed up compilation. In the C backend that already happens due to the caching of _fn. That's how we can deliver on the promised compilation speedups, and it's especially relevant for a library like pytensor-ml that may want to chain hundreds of the same "LayerOp"s in sequence

@@ -4140,38 +4116,17 @@ def prepare_node(self, node, storage_map, compute_map, impl):
def __eq__(self, other):
if self is other:
Member


can't we have regular __props__ based equality/hashing now?

Member

@ricardoV94 ricardoV94 Mar 9, 2026


@jessegrabowski this still stands. With proper fgraph equality, we could have these inner graph Ops behave like other Ops based on __props__ (simpler mental model for devs). __props__ = ("fgraph",) (and whatever else not in the fgraph that influences behavior) ?

Member


reminder

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch 2 times, most recently from 78ee1a9 to eda51d2 Compare March 8, 2026 19:18
@ricardoV94
Member

ricardoV94 commented Mar 8, 2026

I left some comments as I checked the changes. I need to think/discuss a bit about the spec thing, and the desire to have consistent hashing across runtimes. If you remove that, the complexity of this PR drops quite a bit, but maybe this is also fine.

Can you confirm this was only needed for the C-backend, and that it would also work if whatever relies on that called something like __stable_hash__ instead of __hash__, that does the fingerprint / spec thing?

Besides that, this PR looks amazing, and it's a game changer for working with inner graph ops. We really need those to work well.

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from eda51d2 to 7202ca3 Compare March 9, 2026 00:15
@jessegrabowski
Member Author

I removed the spec stuff and simplified the PR down somewhat.

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch 3 times, most recently from 4a7bea8 to 445731f Compare March 9, 2026 00:48
out.__reduce_ex__ = _make_frozen_output_reduce(out) # type: ignore[method-assign]
instance.tag = Scratchpad()
cls._cache[cache_key] = instance
return instance
Member


Do we need the FrozenApply to be hash-consed? Isn't it enough if the input/output variables are? Wondering if we can remove some extra code that way. The Apply doesn't do much anyway

Member Author


This is tied up in the current identity-based equality scheme. If we remove the FrozenApply interning, FrozenFunctionGraph.__init__ will create new Apply nodes with new output Variables, so we lose output1 is output2 and thus fg1 == fg2.

That's not to say we couldn't move the equality check responsibility inside FFG, but it's a slightly different design.

assert op1 != op_different

# inline flag participates in equality
op_inline = OpFromGraph([x, y], [e], inline=True)
Member

@ricardoV94 ricardoV94 Mar 9, 2026


this is probably correct, but it's something that we have to think about. Some Ops have properties that affect their behavior in pytensor but not the computational meaning.

Those in theory should not be part of __props__ and affect equality. This would allow MergeOptimizer to merge the nodes with the same inputs but different OFG, which I guess is what we would want for the final compiled graph (e.g., it doesn't matter if node1 has a different gradient than node2 when you compile it at the end).

But on the other hand we don't want to confuse the two Ops, because when we "freeze -> unfreeze" for example, we don't want to lose those attributes.

We need to work on this in the future. There are different degrees of "equality" we need for different things. Maybe a custom __eq__ even in the presence of __props__ is what we need, but I don't think MergeOptimizer looks at __props__ specifically.

# OFG is hashable, and different OFGs have different hashes
assert hash(op1) != hash(op_inline)

def test_equality_shared_variables(self):
Member

@ricardoV94 ricardoV94 Mar 9, 2026


This special behavior of shared variables is something I want to get rid of already for v3, but fine to test here as it's still a thing

ofg_nodes = [n for n in fg.toposort() if isinstance(n.op, OpFromGraph)]
assert len(ofg_nodes) == 1

# Different inputs are different graphs, so both nodes survive
Member


Wondering if in the rebuild_collect_shared that's used at the beginning of the function compilation, we will still merge the Op (if not the node ofc). Because that would be nice, only one compilation instead of 2. Just a curiosity, not something that needs to be tested here

Member

@ricardoV94 ricardoV94 left a comment


This is amazing. Small notes throughout, just want us to think on whether we also need the FrozenApply to be hash-consed. Maybe yes, I dunno why I'm drawing a distinction (I think I didn't like the extra pickling code needed).

@ricardoV94
Member

ricardoV94 commented Mar 9, 2026

I think with this PR we'll stop seeing the actual optimized inner graph in the compiled function? Something we should follow up on: optimizing the inner graph should be an explicit rewrite, not something that happens at make_node / dispatch time

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch 3 times, most recently from c7bce17 to f404421 Compare March 10, 2026 00:42
@jessegrabowski
Member Author

The remaining test failures on this PR are due to it changing the dprint of functions with inner graphs. In particular, we no longer use the i0, i1, add -> o0 notation:

E         Full diff:
E           [
E         -     'Composite{(i0 + (i1 - i2))} 4',
E         +     'Composite{(*0-<float64> + (*1-<float64> - *2-<float64>))} 4',
E               '├─ A',
E               '├─ ExpandDims{axis=0} v={0: [0]} 3',
E               '│  └─ CGemv{inplace} d={0: [0]} 2',
E               "│     ├─ AllocEmpty{dtype='float64'} 1",
E               '│     │  └─ Shape_i{0} 0',
E               '│     │     └─ B',
E               '│     ├─ 1.0',
E               '│     ├─ B',
E               '│     ├─ <Vector(float64, shape=(?,))>',
E               '│     └─ 0.0',
E               '└─ D',
E               '',
E               'Inner graphs:',
E               '',
E         -     'Composite{(i0 + (i1 - i2))}',
E         +     'Composite{(*0-<float64> + (*1-<float64> - *2-<float64>))}',
E         -     "← add 'o0'",
E         ?     ^     - ----
E         +     '← add',
E         ?     ^
E         -     '├─ i0',
E         +     '├─ *0-<float64>',
E               '└─ sub',
E         -     '├─ i1',
E         -     '└─ i2',
E         +     '├─ *1-<float64>',
E         +     '└─ *2-<float64>',
E               '',
E           ]

Do you want me to fight to restore it, or just update the test's expectations?

@ricardoV94
Member

what about keeping that i0/o0 in the frozen graphs anyway? It's nicer?

Otherwise, yes we can always change print test results. They are there to make us decide consciously


def __init__(self, op, inputs, output_types):
# All initialization is done in __new__
pass
Member


Why is it in __new__ and not __init__?

Member Author


Because we want to return cached variables for the outputs. __new__ is run before an object exists, so we have time to do the output injection. If we put the setup in __init__, on a cache hit we would:

  • call __new__
  • put the cached outputs into self.outputs
  • call __init__
  • overwrite the cached outputs with new outputs, breaking the identity-based equality
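The failure mode above can be shown with a toy interning class (a sketch, not the PR's code; Interned and its cache are invented for illustration):

```python
class Interned:
    _cache = {}

    def __new__(cls, key):
        if key in cls._cache:
            return cls._cache[key]   # cache hit: hand back the canonical object
        self = super().__new__(cls)
        self.outputs = [object()]    # canonical outputs, created exactly once
        cls._cache[key] = self
        return self

    # Deliberately no __init__: Python calls __init__ even on a cache hit,
    # so any `self.outputs = ...` there would clobber the canonical outputs
    # and break identity-based equality.


a, b = Interned("k"), Interned("k")
assert a is b
assert a.outputs[0] is b.outputs[0]  # identity preserved across "creations"
```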

Comment on lines +909 to +912
# Give each output Variable a __reduce__ that resolves to the
# canonical output on unpickle, avoiding fresh Variable objects.
for out in instance.outputs:
out.__reduce_ex__ = _make_frozen_output_reduce(out) # type: ignore[method-assign]
Member


ELI5?

Member Author


FrozenApply is interned, but its output variables are not. When we unpickle, Python makes a new object, so a is not unpickle(pickle(a)). Our variable equality is identity based, so a == unpickle(pickle(a)) fails. FrozenFunctionGraph also depends on comparing outputs by equality, so without this we end up with fg1 != fg2.

The patch makes it so that when unpickling, we go look for the (interned) "canonical" version of the output (e.g. cached_apply.outputs[0]) and use that. Note that the variable itself isn't interned, just the apply. So this patch is basically letting pickle know about and use this relationship, rather than creating orphan duplicates.
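The mechanism can be sketched in isolation: a custom __reduce__ tells pickle to resolve to the canonical object on load instead of rebuilding a fresh one. (Canonical, _registry, and _lookup are hypothetical names; the PR patches __reduce_ex__ per output variable rather than per class.)

```python
import pickle

_registry = {}   # maps a key to the canonical (interned) object


def _lookup(key):
    return _registry[key]


class Canonical:
    def __init__(self, key):
        self.key = key
        _registry[key] = self

    def __reduce__(self):
        # Instead of serializing state, tell pickle to call
        # _lookup(self.key) on load, returning the interned instance.
        return (_lookup, (self.key,))


a = Canonical("out0")
assert pickle.loads(pickle.dumps(a)) is a  # identity survives the round trip
```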

Member


If the apply is interned, the variables are too (and vice versa). I assumed we were going to intern the variables, because you may have variables without an apply but not the other way around

Member Author


We intern the variables, but only transitively. We have to reach into the FrozenApply to get them.

return __reduce_ex__


class FrozenApply(Apply):
Member


It still seems like we're getting a lot of complexity for this FrozenApply thing that is not present in regular interned Variables? Is it due to pickling? Why is it trickier?

Member Author

@jessegrabowski jessegrabowski Mar 28, 2026


The point of FrozenApply is to allow equality-based comparison of FrozenFunctionGraph outputs. I went through an intermediate plan that used a tuple spec of (op, input_refs, n_outputs) per node that would also have worked, but then we're maintaining that machinery. If we want to be able to do fg1 == fg2, we have to have an abstraction somewhere that allows robust serialization/hashing of nodes (including constants, which have been a repeated challenge during this PR)

Member

@ricardoV94 ricardoV94 Mar 28, 2026


My point is that the variables being hash-consed already achieves that; you could even have stuck with regular Apply objects holding the hash-consed variables as inputs/outputs.

I suggested a frozen apply just so it would use tuples and reduce the risk of accidentally mutating them, but they never seemed necessary (to me) for the goal of hash/equality

Member Author


I guess it's not necessary per se, but there has to be some kind of hashable topological representation somewhere so we can rebuild the graph. The tuple-spec or "fingerprint" was one approach: it interns the variables directly and stores the whole graph topology as a list of nested tuples. I settled on FrozenApply because it encodes the same topology in a representation that feels more "pytensor native". The downside is that it adds an intermediate object that we have to go through to get the variables themselves.

@@ -4140,38 +4116,17 @@ def prepare_node(self, node, storage_map, compute_map, impl):
def __eq__(self, other):
if self is other:
Member


reminder

def clone(self) -> "Scan":
res = copy(self)
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
res.fgraph = res.fgraph.clone(clone_inner_graphs=True) # type: ignore[attr-defined]
Member


No need to clone fgraph, it's frozen. Clone method can be all removed?

Member Author


scan is weird because all of the rewrites end up using clone_replace on the fgraph. Currently, I'm storing the fgraph twice: once frozen (for equality/hashing) and once unfrozen (for rewrites). If we go down to just the frozen, we'll have to rewrite all the rewrites to unfreeze -> mutate -> freeze, which we then need to make sure devs know about.

An alternative would be to have FrozenApply.clone or clone_with_new_inputs unfreeze, but that seems like an anti-pattern: the method is called "clone" but it would actually mutate the object by unfreezing it.

Member Author

@jessegrabowski jessegrabowski Mar 30, 2026


I started looking into ripping out the fgraph from scan and just using frozen_fgraph, but it's a much larger refactor that implicates all the scan rewrites. I want to handle it in a follow up PR.

For now, scan is using both: frozen_fgraph for equality checks and fgraph as a mutable representation

@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from e415091 to 82ae0bb Compare March 30, 2026 03:32
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from 82ae0bb to 8c5a4fb Compare March 30, 2026 03:34
@jessegrabowski jessegrabowski changed the base branch from main to v3 March 30, 2026 03:35
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from 8c5a4fb to 25c9561 Compare March 30, 2026 03:37
@jessegrabowski jessegrabowski marked this pull request as ready for review March 30, 2026 03:38
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from 5d85d71 to c67105a Compare March 30, 2026 03:40
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from c67105a to 1277ecd Compare March 31, 2026 03:48
@jessegrabowski jessegrabowski force-pushed the hashable-inner-graphs branch from 1277ecd to 221e801 Compare March 31, 2026 11:34
@jessegrabowski
Member Author

The failure seems unrelated, but you never know. I can't reproduce it locally.

Successfully merging this pull request may close these issues:

Equality of Ops with InnerGraph