Rewrite ShapeFeature to not hold live variables #2056

Merged: ricardoV94 merged 10 commits into pymc-devs:main from ricardoV94:shape_feature on Jun 14, 2026

@ricardoV94 ricardoV94 commented Apr 17, 2026

Closes pymc-devs/pymc-extras#673

ShapeFeature reintroducing variables we have lowered/rewritten away is no-bueno

Also, the eager call to infer_shape can be very wasteful and grow exponentially, as seen in #2164

Details:

Replace the eager per-variable dict (shape_of, shape_of_reverse_index, scheduled) with a lazy FrozenFunctionGraph-based shape kernel cache. For each Apply, a kernel built from dummy clones of node.inputs is stored in self._cache[node] and materialized against today's live inputs on demand via a custom frozen-graph walker (graph_replace would mutate globally-interned FrozenApply inputs).

The kernel holds only NominalVariables and Constants, so no live
variable can leak between tests or across rewrites, eliminating by
construction the stale-XRV class of bugs.
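
A toy sketch of the lazy kernel idea, in plain Python rather than PyTensor. Every name here (`Placeholder`, `ShapeKernelCache`, `build_matmul_kernel`) is hypothetical and only illustrates the principle; the PR itself uses FrozenFunctionGraph with NominalVariables. The cached kernel refers to inputs by position only, so it cannot pin a live variable, and it is built on first request rather than eagerly.

```python
from dataclasses import dataclass


# Hypothetical stand-ins; none of these names come from PyTensor.
@dataclass(frozen=True)
class Placeholder:
    """Positional dummy for 'input i'; carries no reference to a live variable."""

    index: int


def build_matmul_kernel(placeholders):
    """Build a shape kernel for a 2-D matmul, written purely in terms of positions."""
    a, b = placeholders

    def kernel(input_shapes):
        # Output shape is (rows of first input, columns of second input).
        return (input_shapes[a.index][0], input_shapes[b.index][1])

    return kernel


class ShapeKernelCache:
    def __init__(self):
        self._cache = {}
        self.builds = 0  # counts how many kernels were actually built

    def get_shape(self, node_key, n_inputs, build, live_input_shapes):
        kernel = self._cache.get(node_key)
        if kernel is None:
            # Lazy: the kernel is built only when a shape is first requested.
            placeholders = tuple(Placeholder(i) for i in range(n_inputs))
            kernel = build(placeholders)
            self._cache[node_key] = kernel
            self.builds += 1
        # Materialize against whatever the inputs are right now.
        return kernel(tuple(live_input_shapes))


cache = ShapeKernelCache()
first = cache.get_shape("matmul_node", 2, build_matmul_kernel, [(3, 4), (4, 5)])
# The node's inputs were replaced by a rewrite; the same kernel is reused.
second = cache.get_shape("matmul_node", 2, build_matmul_kernel, [(7, 4), (4, 2)])
print(first, second, cache.builds)
```

Because the kernel is applied to the current inputs at lookup time, replacing an input never leaves a stale reference inside the cache.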

Back-compat surface (_LazyShapeTuple, _ShapeOfProxy, update_shape,
shape_ir, init_r) is retained and marked as temporary. A regression
test for the stale-XRV scenario replaces the prior xfail.

shape_of_variables switches to builders.infer_shape so it returns to
scalar-dim inputs instead of allocating per-input arrays.

local_track_shape_i no longer depends on the deleted scheduled dict;
it rewrites Shape_i(v, i) to get_shape(v, i) whenever the kernel
produces something other than the trivial fallback.

on_change_input carries r's inferred shape onto new_r as an override
when new_r's Op has no infer_shape, preserving the legacy behavior
where a well-inferred shape survives through a replacement with an
opaque op.

Benchmarks (cxx enabled):

  • radon_repeat 0.78s -> 0.55s (-30%)
  • radon_variants (8) 7.9s -> 7.2s ( -9%)
  • fusion_large 0.22s -> 0.22s (noise)
  • fusion_deep 13ms -> 13ms (noise)

@ricardoV94 (Member, Author)

I added a small patch for OFG (OpFromGraph) to cache the shape graph. With this patch, this PR's lazy ShapeFeature, and #2147, the issue that #2164 is trying to circumvent is already handled.

Before PR
--------------------------------------------------------------------------------------------- benchmark: 2 tests ---------------------------------------------------------------------------------------------
Name (time in s)                                    Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]      4.6451 (1.0)       5.0273 (1.0)       4.7674 (1.0)      0.1500 (1.0)       4.7223 (1.0)      0.1310 (1.0)           1;1  0.2098 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     70.5293 (15.18)    75.3482 (14.99)    73.3600 (15.39)    2.3004 (15.34)    73.7812 (15.62)    3.7426 (28.56)         1;0  0.0136 (0.06)          4           1
test_xtensor_attention_rewrite_benchmark[4]     HEAT DEATH OF THE UNIVERSE
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After lazy Shape Feature
----------------------------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------------------------
Name (time in ms)                                      Min                   Max                  Mean             StdDev                Median                 IQR            Outliers     OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]       604.5540 (1.0)        734.8644 (1.0)        678.0195 (1.0)      58.0627 (2.44)       702.4159 (1.0)      101.5709 (2.87)          1;0  1.4749 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     1,238.5549 (2.05)     1,337.9649 (1.82)     1,266.2889 (1.87)     41.0227 (1.72)     1,248.9357 (1.78)      38.6898 (1.09)          1;0  0.7897 (0.54)          5           1
test_xtensor_attention_rewrite_benchmark[4]     2,128.2901 (3.52)     2,186.8365 (2.98)     2,164.0254 (3.19)     23.8237 (1.0)      2,173.0298 (3.09)      35.4256 (1.0)           1;0  0.4621 (0.31)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After shape graph cache:
------------------------------------------------------------------------------------------------- benchmark: 3 tests ------------------------------------------------------------------------------------------------
Name (time in ms)                                    Min                 Max                Mean             StdDev              Median                 IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]     288.4377 (1.0)      422.6271 (1.0)      327.9444 (1.0)      53.7626 (1.0)      310.3342 (1.0)       38.3331 (1.0)           1;1  3.0493 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     431.9948 (1.50)     601.8571 (1.42)     505.5498 (1.54)     78.9754 (1.47)     457.3631 (1.47)     135.0387 (3.52)          1;0  1.9780 (0.65)          5           1
test_xtensor_attention_rewrite_benchmark[4]     614.3723 (2.13)     783.8840 (1.85)     714.3603 (2.18)     82.6776 (1.54)     756.9681 (2.44)     151.6929 (3.96)          1;0  1.3999 (0.46)          5           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
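
The mean times in the three tables above can be turned into speedup factors with a few lines of arithmetic. This is only a convenience sketch: the numbers are copied from the "Mean" columns (converted to seconds), and the dictionary keys are labels chosen here, not names from the benchmark suite.

```python
# Mean times from the three benchmark tables, in seconds.
# Case [4] has no "before" entry because it never finished.
mean_seconds = {
    "before": {2: 4.7674, 3: 73.3600},
    "lazy_shape_feature": {2: 0.6780195, 3: 1.2662889, 4: 2.1640254},
    "shape_graph_cache": {2: 0.3279444, 3: 0.5055498, 4: 0.7143603},
}


def speedup(baseline, stage, case):
    """How many times faster `stage` is than `baseline` for one benchmark case."""
    return mean_seconds[baseline][case] / mean_seconds[stage][case]


for case in (2, 3):
    for stage in ("lazy_shape_feature", "shape_graph_cache"):
        factor = speedup("before", stage, case)
        print(f"[{case}] before -> {stage}: {factor:.1f}x")
```

For the 3-input case this works out to roughly 58x with the lazy ShapeFeature alone and roughly 145x once the shape graph is cached.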

There's a new benchmark test tracking this. It can still make sense to eagerly rewrite einsum -> core graph for 2 inputs, but in general we can't afford to have inner graph ops be a performance bottleneck, as that's the direction we are moving (see #2110 and #1221).

There were several sources of exponential blow-up that we are addressing here. None of this is a hack, the old code was just dumb.
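
As a generic illustration of how this kind of blow-up arises (not a claim about the exact code paths fixed in this PR), consider a graph where every level uses the previous level twice. Recursing into every input without a cache visits an exponential number of nodes, while a memoized walk visits each node once. The helper names below are made up for the example.

```python
from functools import lru_cache


def make_diamond_chain(depth):
    """Return {node: parents}, where each level reuses the previous level twice."""
    graph = {0: ()}
    for level in range(1, depth + 1):
        graph[level] = (level - 1, level - 1)
    return graph


def count_visits_naive(graph, node):
    """Recurse into every parent every time, with no caching."""
    return 1 + sum(count_visits_naive(graph, parent) for parent in graph[node])


def count_visits_cached(graph, node):
    """Same walk, but each node's result is computed only once."""
    visits = 0

    @lru_cache(maxsize=None)
    def visit(n):
        nonlocal visits
        visits += 1
        for parent in graph[n]:
            visit(parent)

    visit(node)
    return visits


graph = make_diamond_chain(10)
naive = count_visits_naive(graph, 10)
cached = count_visits_cached(graph, 10)
print(naive, cached)
```

With depth 10 the naive walk makes 2047 visits against 11 for the cached one, and the gap doubles with every extra level.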

CC @cetagostini

ricardoV94 force-pushed the shape_feature branch 3 times, most recently from c12a879 to df78b91 (May 28, 2026 20:12)

if shape_feature is None:
    shape_feature = ShapeFeature()

core_shape = [

@ricardoV94 (Member, Author)

@jessegrabowski I took a step back and no longer try to be too fancy to get a super-duper shape graph. In practice we get shape_i(node.inputs[i]), which may be simplifiable into shape_i(root_input) or root_input (if it's a dim), but that simplification introduced the whole cycles concern, and shape_i(node.inputs[i]) is just as cheap as shape_i(root_input).

We still clean up the graph with restricted canonicalization, but it shouldn't be possible to add cycles because all the variables we are using were already inputs to the pre-existing node.

If we find the need for better-optimized core_shape graphs, the solution is to either 1) add them to the graph before inplace (maybe as dangling outputs), or 2) use fancy/expensive ways of dealiasing a graph, which you showed wasn't trivial to get right, to my chagrin.

We are still missing a good story for how to do "side graph inference" in our pipeline, I think.

ricardoV94 force-pushed the shape_feature branch 2 times, most recently from 0a8f5e7 to 041b8f2 (June 12, 2026 08:42)

@jessegrabowski jessegrabowski left a comment

Nitpicks abound! I really encourage descriptive variable names, but I'm not going to go to the mat for it.


frozen = FrozenFunctionGraph([*inner_inputs, *shape_i_vars], flat_shapes)
self._inner_shape_template = template
self._inner_shape_frozen = frozen

@jessegrabowski (Member)

oh actually we do freeze already, why not use this for equality checking above?

@ricardoV94 (Member, Author)

Not sure what you meant. What equality checking above? This is freezing the shape graph; we never did this anywhere else?

`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join`
and batched-`Blockwise` as protected client ops, but not `Subtensor`.
`local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into
`alloc(val[...], *new_shape)` — preserving the Alloc structure that
downstream rewrites like `local_blockwise_alloc_inputs` depend on.
Folding the Alloc here short-circuited that lift and produced
broadcast-equivalent `Constant` matrices whose batch dim was no longer
type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the
surrounding `Blockwise(Reshape)`.

Surfaced by the lazy-kernel `ShapeFeature` (which resolves
`Subtensor(Shape(out), const)` to a scalar `Constant` earlier and
makes more upstream Allocs constant-foldable), but the fix belongs
here: the protection was too narrow.

Breaking API change: the `fgraph` argument was unused by every
in-tree `infer_shape` implementation. Removing it makes
`infer_shape` a pure function of `(node, input_shapes)`, simpler
to call from outside an fgraph context (e.g. ShapeFeature's lazy
kernel build) and tighter as a contract.
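
For an Op defined outside the library, the migration amounts to deleting one parameter. The sketch below uses plain-Python stand-ins rather than real PyTensor Ops, and `uses_legacy_signature` is a hypothetical helper (not a PyTensor function) that a downstream project could use to find Ops that still need updating.

```python
import inspect


class LegacyDoubleRows:
    """Stand-in for an external Op that still uses the old signature."""

    def infer_shape(self, fgraph, node, input_shapes):
        [(rows, cols)] = input_shapes
        return [(2 * rows, cols)]


class DoubleRows:
    """Same Op after migration: a pure function of (node, input_shapes)."""

    def infer_shape(self, node, input_shapes):
        [(rows, cols)] = input_shapes
        return [(2 * rows, cols)]


def uses_legacy_signature(op):
    """True if op.infer_shape still expects a leading `fgraph` parameter."""
    # inspect.signature on a bound method already omits `self`.
    params = list(inspect.signature(op.infer_shape).parameters)
    return params[:1] == ["fgraph"]


legacy_flag = uses_legacy_signature(LegacyDoubleRows())
new_flag = uses_legacy_signature(DoubleRows())
# The toy method ignores `node`, so None is enough for the demonstration.
shapes = DoubleRows().infer_shape(None, [(3, 4)])
print(legacy_flag, new_flag, shapes)
```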

External Ops with custom `infer_shape(self, fgraph, node, input_shapes)`
must drop the `fgraph` parameter.

Add `FrozenFunctionGraph.from_structural_inputs`, the structural-matching
dual of `bind`: inputs may be interior expressions, matched against the
outputs by structure (via interning) and rewired to the input boundary.

`OpFromGraph.infer_shape` now uses it to express the inner-output shapes
as a frozen function of the inner inputs plus one slot per input dimension,
replacing the manual blocker list and `shape_i_keys` bookkeeping. The dense,
positional layout lets `bind` fill slots straight from the caller's shapes.

local_subtensor_shape_constant only folded broadcastable dims, deferring
the general case to the ShapeFeature, which is not present everywhere
canonicalize runs (e.g. rewrite_subgraph).

Core shapes built from recursive shape inference can read variables
that inplace rewrites destroy, making the wrapper node unschedulable.
Expressions that read only the node's own inputs, through constants
and fresh Shape_i applies, never conflict with the surrounding graph.
They are canonicalized in isolation with the new rewrite_subgraph
utility.

The shape-bearing test was duck-typed as `hasattr(var.type, "ndim")`
(carried over from the old ShapeFeature API). Use the canonical
`isinstance(var.type, HasShape)` instead: every shape-bearing type
(Tensor, Scalar, Sparse, XTensor) subclasses HasShape, so it is
equivalent, and on the common tensor/scalar inputs it short-circuits
the MRO walk rather than doing a full getattr.
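
The two checks can be compared side by side with toy classes. The `HasShape` class below is a local stand-in for the mixin named in the commit message, and `TensorLikeType` and `RandomGeneratorLikeType` are hypothetical; the sketch only shows that both tests agree when every shape-bearing type subclasses the mixin.

```python
class HasShape:
    """Local stand-in for the mixin that shape-bearing types subclass."""

    ndim = 0


class TensorLikeType(HasShape):
    def __init__(self, ndim):
        self.ndim = ndim


class RandomGeneratorLikeType:
    """A type with no shape at all, such as an RNG state."""


def is_shape_bearing_duck(var_type):
    # Old check: look for an `ndim` attribute.
    return hasattr(var_type, "ndim")


def is_shape_bearing(var_type):
    # New check: ask for the mixin directly.
    return isinstance(var_type, HasShape)


types = [TensorLikeType(2), RandomGeneratorLikeType()]
duck = [is_shape_bearing_duck(t) for t in types]
nominal = [is_shape_bearing(t) for t in types]
print(duck, nominal)
```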
@ricardoV94 ricardoV94 merged commit 8ed8df0 into pymc-devs:main Jun 14, 2026
66 checks passed
@ricardoV94 ricardoV94 deleted the shape_feature branch June 14, 2026 22:07

Development

Successfully merging this pull request may close these issues.

Prior.create_variable(xdist=True) fails compile_logp for centered priors with nested Prior parameters that have dims

2 participants