Rewrite ShapeFeature to not hold live variables#2056
Conversation
f80be94 to
aba080f
Compare
c37cb6b to
68e1643
Compare
|
Rewrite time on the asv experiment down: https://ricardov94.github.io/pymc-model-catalogue/experiments.html#base=shape_feature_pr2056_base&compare=shape_feature_pr2056
|
d6c4c55 to
2f32495
Compare
|
I added a small patch for OFG to cache the shape graph, with this, this PR lazy ShapeFeature and #2147 the issue that #2164 is trying to circumvent is already handled. There's a new benchmark test tracking this. It can still make sense to eagerly rewrite einsum -> core graph for 2 inputs, but in general we can't afford to have inner graph ops be a performance bottleneck, as that's the direction we are moving (see #2110 and #1221). There were several sources of exponential blow-up that we are addressing here. None of this is a hack, the old code was just dumb. CC @cetagostini |
c12a879 to
df78b91
Compare
06f09bc to
0d31a58
Compare
| if shape_feature is None: | ||
| shape_feature = ShapeFeature() | ||
|
|
||
| core_shape = [ |
There was a problem hiding this comment.
@jessegrabowski I took a step back and instead don't try too fancy to get a super duper shape graph. In practice we get shape_i(node.input[i]), that may be simplifiable into shape_i(root_input) or root_input (if it's a dim), but that introduced the whole cycles concern. a shape_i(root_input) is as cheap as shape_i(node.input[i]).
We still cleanup the graph with restricted canonicalization, but shouldn't be possible to add cycles because all the variables we are using were already inputs to the pre-existing node.
If we find the need to have better optmized core_shape graphs the solution is to 1) Add them in the graph before inplace (maybe as dangling outputs) or fancy/expensive ways of trying to dealias a graph that you showed wasn't trivial to get right to my chagrin.
We are still missing a good history for how to do "side graph inference" in our pipeline I think.
0a8f5e7 to
041b8f2
Compare
jessegrabowski
left a comment
There was a problem hiding this comment.
nitpicks abound! I really encourage descriptive variable names but I'm not going to go to the mats for it.
|
|
||
| frozen = FrozenFunctionGraph([*inner_inputs, *shape_i_vars], flat_shapes) | ||
| self._inner_shape_template = template | ||
| self._inner_shape_frozen = frozen |
There was a problem hiding this comment.
oh actually we do freeze already, why not use this for equality checking above?
There was a problem hiding this comment.
Not sure what you meant? What equality checking above? This is freezing the shape graph, we never did this anywhere alse?
041b8f2 to
6588585
Compare
`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join` and batched-`Blockwise` as protected client ops, but not `Subtensor`. `local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into `alloc(val[...], *new_shape)` — preserving the Alloc structure that downstream rewrites like `local_blockwise_alloc_inputs` depend on. Folding the Alloc here short-circuited that lift and produced broadcast-equivalent `Constant` matrices whose batch dim was no longer type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the surrounding `Blockwise(Reshape)`. Surfaced by the lazy-kernel `ShapeFeature` (which resolves `Subtensor(Shape(out), const)` to a scalar `Constant` earlier and makes more upstream Allocs constant-foldable), but the fix belongs here — the protection was too narrow.
Breaking API change: the `fgraph` argument was unused by every in-tree `infer_shape` implementation. Removing it makes `infer_shape` a pure function of `(node, input_shapes)`, simpler to call from outside an fgraph context (e.g. ShapeFeature's lazy kernel build) and tighter as a contract. External Ops with custom `infer_shape(self, fgraph, node, input_shapes)` must drop the `fgraph` parameter.
eea4eb9 to
24dfe0f
Compare
Add `FrozenFunctionGraph.from_structural_inputs`, the structural-matching dual of `bind`: inputs may be interior expressions, matched against the outputs by structure (via interning) and rewired to the input boundary. `OpFromGraph.infer_shape` now uses it to express the inner-output shapes as a frozen function of the inner inputs plus one slot per input dimension, replacing the manual blocker list and `shape_i_keys` bookkeeping. The dense, positional layout lets `bind` fill slots straight from the caller's shapes.
local_subtensor_shape_constant only folded broadcastable dims, deferring the general case to the ShapeFeature, which is not present everywhere canonicalize runs (e.g. rewrite_subgraph).
Core shapes built from recursive shape inference can read variables that inplace rewrites destroy, making the wrapper node unschedulable. Expressions that read only the node's own inputs, through constants and fresh Shape_i applies, never conflict with the surrounding graph. They are canonicalized in isolation with the new rewrite_subgraph utility.
The shape-bearing test was duck-typed as `hasattr(var.type, "ndim")` (carried over from the old ShapeFeature API). Use the canonical `isinstance(var.type, HasShape)` instead: every shape-bearing type (Tensor, Scalar, Sparse, XTensor) subclasses HasShape, so it is equivalent, and on the common tensor/scalar inputs it short-circuits the MRO walk rather than doing a full getattr.
24dfe0f to
cb4b367
Compare

Closes pymc-devs/pymc-extras#673
ShapeFeature reintroducing variables we have lowered/rewritten away is no-bueno
Also the eager call to infer_shape can be very wasteful and grow up exponential as seen in #2164
Details
Replace the eager per-variable dict (shape_of, shape_of_reverse_index, scheduled) with a lazy FrozenFunctionGraph-based shape kernel cache. For each Apply, a kernel built from dummy clones of node.inputs is stored in self._cache[node] and materialized against today's live inputs on demand via a custom frozen-graph walker (graph_replace would mutate globally-interned FrozenApply inputs).The kernel holds only NominalVariables and Constants, so no live
variable can leak between tests or across rewrites, eliminating by
construction the stale-XRV class of bugs.
Back-compat surface (_LazyShapeTuple, _ShapeOfProxy, update_shape,
shape_ir, init_r) is retained and marked as temporary. A regression
test for the stale-XRV scenario replaces the prior xfail.
shape_of_variables switches to builders.infer_shape so it returns to
scalar-dim inputs instead of allocating per-input arrays.
local_track_shape_i no longer depends on the deleted scheduled dict;
it rewrites Shape_i(v, i) to get_shape(v, i) whenever the kernel
produces something other than the trivial fallback.
on_change_input carries r's inferred shape onto new_r as an override
when new_r's Op has no infer_shape, preserving the legacy behavior
where a well-inferred shape survives through a replacement with an
opaque op.
Benchmarks (cxx enabled):