From 71cb4e6b0832e156e6bc3c8a71abfbf6c143df31 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 17:30:48 +0200 Subject: [PATCH 01/10] Fix FrozenFunctionGraph.bind for constant outputs --- pytensor/graph/fg.py | 2 +- tests/graph/test_fg.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index a785958f4e..4b4250283f 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1099,7 +1099,7 @@ def bind( [o.type() for o in node.outputs], ) memo.update(zip(node.outputs, new_node.outputs)) - return [memo[out] for out in self.outputs] + return [out if isinstance(out, Constant) else memo[out] for out in self.outputs] def unfreeze(self) -> "FunctionGraph": """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 2d33112373..cd91fd8ca4 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -988,3 +988,14 @@ def test_value_dependent_output_type_collision(self): # ``bind`` reproduces each graph's own output type, not a collided one assert ffg1.bind([x, s1])[0].type.shape == (2, 3) assert ffg2.bind([x, s2])[0].type.shape == (3, 2) + + def test_bind_constant_output(self): + """bind must handle constants that appear directly as outputs.""" + x = float64("x") + c = ScalarConstant(float64, 42.0) + ffg = FunctionGraph([x], [add(x, c), c]).freeze() + + y = float64("y") + bound = ffg.bind({ffg.inputs[0]: y}) + assert len(bound) == 2 + assert bound[1] is c From 0cc26b2c8ec0da13765accb0c4f2feca2e47065c Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 29 Apr 2026 19:32:53 +0200 Subject: [PATCH 02/10] Don't constant-fold `Alloc` consumed by `Subtensor` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join` and batched-`Blockwise` as protected client ops, but not `Subtensor`. `local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into `alloc(val[...], *new_shape)` — preserving the Alloc structure that downstream rewrites like `local_blockwise_alloc_inputs` depend on. Folding the Alloc here short-circuited that lift and produced broadcast-equivalent `Constant` matrices whose batch dim was no longer type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the surrounding `Blockwise(Reshape)`. Surfaced by the lazy-kernel `ShapeFeature` (which resolves `Subtensor(Shape(out), const)` to a scalar `Constant` earlier and makes more upstream Allocs constant-foldable), but the fix belongs here — the protection was too narrow. --- pytensor/tensor/basic.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d690a9e9c..6f03bc2ed4 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1762,14 +1762,32 @@ def do_constant_folding(self, fgraph, node): if not clients: return False + from pytensor.tensor.blas import Gemv, Ger + from pytensor.tensor.blas_c import CGemv, CGer + from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + IncSubtensor, + Subtensor, + ) + for client, idx in clients: client_op = client.op if isinstance(client_op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False - # Op's through which Alloc can be lifted - elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join): + # Op's through which Alloc can be lifted. ``Subtensor`` is + # included because ``local_subtensor_of_alloc`` rewrites + # ``alloc(val, *shape)[idx]`` into ``alloc(val[...], *new_shape)``, + # preserving the Alloc structure that downstream rewrites + # (e.g. ``local_blockwise_alloc_inputs``) rely on. Folding the + # Alloc here would short-circuit that lift and produce a + # broadcast-equivalent constant whose batch dim is no longer + # type-broadcastable. + elif isinstance( + client_op, Elemwise | DimShuffle | Alloc | Join | Subtensor + ): return False # Same for Blockwise, unless it has no batch_dims elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client): @@ -1779,13 +1797,13 @@ def do_constant_folding(self, fgraph, node): idx == 0 and isinstance( client_op, - pytensor.tensor.subtensor.IncSubtensor - | pytensor.tensor.subtensor.AdvancedIncSubtensor1 - | pytensor.tensor.subtensor.AdvancedIncSubtensor - | pytensor.tensor.blas.Gemv - | pytensor.tensor.blas_c.CGemv - | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer, + IncSubtensor + | AdvancedIncSubtensor1 + | AdvancedIncSubtensor + | Gemv + | CGemv + | Ger + | CGer, ) ): # Ops that will work inplace on the Alloc. So if they From 137d2d255d7232707b956c7783219cda2cea7987 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 30 Apr 2026 23:38:05 +0200 Subject: [PATCH 03/10] Drop `fgraph` parameter from `Op.infer_shape` signature Breaking API change: the `fgraph` argument was unused by every in-tree `infer_shape` implementation. Removing it makes `infer_shape` a pure function of `(node, input_shapes)`, simpler to call from outside an fgraph context (e.g. ShapeFeature's lazy kernel build) and tighter as a contract. External Ops with custom `infer_shape(self, fgraph, node, input_shapes)` must drop the `fgraph` parameter. --- pytensor/assumptions/specify.py | 2 +- pytensor/breakpoint.py | 2 +- pytensor/compile/builders.py | 2 +- pytensor/compile/ops.py | 10 +++--- pytensor/graph/op.py | 19 ++++++++++ pytensor/ifelse.py | 2 +- pytensor/raise_op.py | 2 +- pytensor/scan/op.py | 2 +- pytensor/sparse/basic.py | 36 +++++++++---------- pytensor/sparse/math.py | 32 ++++++++--------- pytensor/sparse/rewriting.py | 2 +- pytensor/tensor/basic.py | 28 +++++++-------- pytensor/tensor/blas.py | 12 +++---- pytensor/tensor/blockwise.py | 10 ++---- pytensor/tensor/elemwise.py | 6 ++-- pytensor/tensor/extra_ops.py | 18 +++++----- pytensor/tensor/fourier.py | 2 +- pytensor/tensor/linalg/constructors.py | 2 +- .../tensor/linalg/decomposition/cholesky.py | 2 +- pytensor/tensor/linalg/decomposition/eigen.py | 6 ++-- pytensor/tensor/linalg/decomposition/lu.py | 4 +-- pytensor/tensor/linalg/decomposition/qr.py | 2 +- pytensor/tensor/linalg/decomposition/schur.py | 4 +-- pytensor/tensor/linalg/decomposition/svd.py | 2 +- pytensor/tensor/linalg/inverse.py | 6 ++-- pytensor/tensor/linalg/products.py | 2 +- pytensor/tensor/linalg/solvers/core.py | 2 +- .../tensor/linalg/solvers/linear_control.py | 2 +- pytensor/tensor/linalg/summary.py | 4 +-- pytensor/tensor/math.py | 4 +-- pytensor/tensor/random/op.py | 2 +- pytensor/tensor/reshape.py | 4 +-- pytensor/tensor/rewriting/numba.py | 3 +- pytensor/tensor/rewriting/shape.py | 10 +++--- pytensor/tensor/shape.py | 8 ++--- pytensor/tensor/signal/conv.py | 2 +- pytensor/tensor/sort.py | 4 +-- pytensor/tensor/subtensor.py | 12 +++---- pytensor/xtensor/basic.py | 2 +- tests/compile/test_ops.py | 2 +- tests/sparse/test_basic.py | 2 +- tests/tensor/rewriting/test_shape.py | 4 +-- tests/tensor/test_blockwise.py | 4 +-- tests/tensor/test_elemwise.py | 8 ++--- 44 files changed, 152 insertions(+), 144 deletions(-) diff --git a/pytensor/assumptions/specify.py b/pytensor/assumptions/specify.py index ebf2282c89..201e7d6282 100644 --- a/pytensor/assumptions/specify.py +++ b/pytensor/assumptions/specify.py @@ -39,7 +39,7 @@ def make_node(self, x): out = x.type() return Apply(self, [x], [out]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback( diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index f9c74950dc..c2913cfea1 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage): def pullback(self, inputs, outputs, output_gradients): return [disconnected_type(), *output_gradients] - def infer_shape(self, fgraph, inputs, input_shapes): + def infer_shape(self, inputs, input_shapes): # Return the shape of every input but the condition (first input) return input_shapes[1:] diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 4060c38365..97c438b0ea 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -885,7 +885,7 @@ def connection_pattern(self, node): self._connection_pattern = ret return ret - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # TODO: Use `fgraph.shape_feature` to do this instead. out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes) diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 57d34bbf25..028c854854 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp): def make_node(self, x): return Apply(self, [x], [x.type()]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback(self, args, outputs, g_outs): @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes @@ -251,8 +251,8 @@ def __reduce__(self): ) return load_back, (mod, name) - def _infer_shape(self, fgraph, node, input_shapes): - return self.__infer_shape(fgraph, node, input_shapes) + def _infer_shape(self, node, input_shapes): + return self.__infer_shape(node, input_shapes) def as_op(itypes, otypes, infer_shape=None): @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None): It takes an optional infer_shape parameter that should be a callable with this signature: - def infer_shape(fgraph, node, input_shapes): + def infer_shape(node, input_shapes): ... return output_shapes diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index ebfa067b71..e048b4188f 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -1,3 +1,4 @@ +import inspect import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Sequence @@ -120,6 +121,24 @@ class Op(MetaObject): as nodes with these Ops must be rebuilt even if the input types haven't changed. """ + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + method = cls.__dict__.get("infer_shape") + if method is None: + return + params = inspect.signature(method).parameters + if len(params) == 4: + warnings.warn( + f"{cls.__module__}.{cls.__qualname__}.infer_shape takes a " + "deprecated `fgraph` parameter; drop it from the signature. " + "The parameter will be passed as None.", + DeprecationWarning, + stacklevel=2, + ) + cls.infer_shape = lambda self, node, input_shapes, _old=method: _old( + self, None, node, input_shapes + ) + def make_node(self, *inputs: Variable) -> Apply: """Construct an `Apply` node that represent the application of this operation to the given inputs. diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 3e7264fcf9..31e0feba97 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -103,7 +103,7 @@ def __str__(self): args.append("inplace") return f"if{{{','.join(args)}}}" - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): # By construction, corresponding then/else pairs have the same number # of dimensions diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index c2962e4d7d..87f52c3c4b 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props): def c_code_cache_version(self): return (2,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def do_constant_folding(self, fgraph, node): diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 8c92c819c2..a0cb073f54 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2251,7 +2251,7 @@ def perform(self, node, inputs, output_storage): self.t_call = t_call self.t_fn = t_fn - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # input_shapes correspond to the shapes of node.inputs for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 771135b1c6..a6bf82edf7 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -494,7 +494,7 @@ def pullback(self, inputs, outputs, gout): disconnected_type(), ] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # node.inputs[3] is of length as we only support sparse matrix. return [(node.inputs[3][0], node.inputs[3][1])] @@ -584,7 +584,7 @@ def perform(self, node, inputs, outputs): g_out[0] = gout_data - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -629,7 +629,7 @@ def pullback(self, inputs, outputs, outputs_gradients): else: return [Cast(inputs[0].dtype)(gz)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return ins_shapes def __str__(self): @@ -742,7 +742,7 @@ def pullback(self, inputs, outputs, gout): else: return [SparseFromDense(x.type.format)(gz)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -806,7 +806,7 @@ def pullback(self, inputs, outputs, gout): ) return (gx,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -820,7 +820,7 @@ class GetItemList(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[1][0], shapes[0][1])] def make_node(self, x, index): @@ -865,7 +865,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItemListGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, index, gz): @@ -958,7 +958,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItem2ListsGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, ind1, ind2, gz): @@ -1139,7 +1139,7 @@ class GetItemScalar(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def make_node(self, x, index): @@ -1239,7 +1239,7 @@ def pullback(self, inputs, outputs, gout): assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (transpose(gz),) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0][::-1]] @@ -1288,7 +1288,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1339,7 +1339,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1438,7 +1438,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [square_diagonal(gz)] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): return [(minimum(*shapes[0]),)] @@ -1498,7 +1498,7 @@ def perform(self, node, inputs, outputs): def pullback(self, inputs, outputs, output_grad): return [output_grad[0]] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes def __str__(self): @@ -1614,7 +1614,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[1] for shape in ins_shapes) return [(ins_shapes[0][0], d)] @@ -1711,7 +1711,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[0] for shape in ins_shapes) return [(d, ins_shapes[0][1])] @@ -1800,7 +1800,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [gz] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes @@ -1880,7 +1880,7 @@ def perform(self, node, inp, out_): (data, indices, indptr), shape=out_shape, dtype=values.dtype ) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x = node.inputs[0] return [[x[0], x[1]]] diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 107df82bfc..a65a4e3614 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -327,7 +327,7 @@ def pullback(self, inputs, outputs, gout): r = psb.SparseFromDense(o_format)(r) return [r] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): r = None if self.axis is None: r = [()] @@ -406,7 +406,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -467,7 +467,7 @@ def pullback(self, inputs, outputs, gout): derivative = {True: gz, False: None} return [derivative[b] for b in is_continuous] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -509,7 +509,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_dense_variable(gz) return psb.sp_ones_like(x) * gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -569,7 +569,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, sp_sum(gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -699,7 +699,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return y * gz, x * gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -788,7 +788,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return y * gz, psb.dense_from_sparse(x * gz) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -871,7 +871,7 @@ def pullback(self, inputs, outputs, gout): return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -989,7 +989,7 @@ def perform(self, node, inputs, outputs): self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format) ) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1034,7 +1034,7 @@ def perform(self, node, inputs, outputs): o = np.asarray(o) out[0] = o - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1284,7 +1284,7 @@ def pullback(self, inputs, outputs, gout): rval[1] = psb.dense_from_sparse(rval[1]) return rval - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1412,7 +1412,7 @@ def pullback(self, inputs, outputs, gout): (g_out,) = gout return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1596,7 +1596,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1731,7 +1731,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1829,7 +1829,7 @@ def pullback(self, inputs, outputs, gout): return rval - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[2]] @@ -1842,7 +1842,7 @@ class Dot(Op): def __str__(self): return "Sparse" + self.__class__.__name__ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes x, y = node.inputs if x.ndim == 2 and y.ndim == 2: diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index 2866f1aec0..1a2c75a6a8 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -176,7 +176,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ return code - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[3]] def c_code_cache_version(self): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 6f03bc2ed4..48528e7696 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -637,7 +637,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(s) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -698,7 +698,7 @@ def perform(self, node, inputs, output_storage): # not using .item() because that returns a Python scalar, not a numpy scalar output_storage[0][0] = inputs[0][()] - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -1379,7 +1379,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.eye(n, m, k, dtype=self.dtype) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] @@ -1708,7 +1708,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (5,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs[1:]] def connection_pattern(self, node): @@ -1984,7 +1984,7 @@ def c_code(self, node, name, inp, out_, props): """ return ret - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [(len(ishapes),)] def pullback(self, inputs, outputs, output_gradients): @@ -2272,7 +2272,7 @@ def perform(self, node, inputs, outputs_storage): for out_storage, out in zip(outputs_storage, split_outs, strict=False): out_storage[0] = out - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): axis = node.inputs[1] splits = node.inputs[2] shp_x, _shp_axis, _shp_splits = in_shapes @@ -2728,7 +2728,7 @@ def pullback(self, inputs, outputs, grads): return rval - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, ge # ishapes[0] contains the size of the axis on which we join @@ -3282,7 +3282,7 @@ def make_node(self, start, stop, step): return Apply(self, inputs, outputs) @config.change_flags(warn_float64="ignore") - def infer_shape(self, fgraph, node, i_shapes): + def infer_shape(self, node, i_shapes): from pytensor.tensor.math import ceil, maximum # Note start, stop and step can be float numbers. @@ -3659,7 +3659,7 @@ def perform(self, node, inp, out): self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): from pytensor.tensor.math import maximum shp_x = in_shapes[0] @@ -3911,7 +3911,7 @@ def pullback(self, inputs, outputs, gout): x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2)) return [x_grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): from pytensor.tensor.math import clip, minimum (in_shape,) = shapes @@ -4243,7 +4243,7 @@ def __init__(self, mode): assert mode in ("raise", "wrap", "clip") self.mode = mode - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): a_shape, choices_shape = shapes if choices_shape is None: # choices is a TypedList, not a tensor; no shape to broadcast @@ -4274,9 +4274,7 @@ def make_node(self, a, choices): choice = as_tensor_variable(choices) choice_dtype = choice.dtype - (out_shape,) = self.infer_shape( - None, None, [shape_tuple(a), shape_tuple(choice)] - ) + (out_shape,) = self.infer_shape(None, [shape_tuple(a), shape_tuple(choice)]) static_out_shape = () for s in out_shape: @@ -4379,7 +4377,7 @@ def c_code(self, node, name, inputs, out_, sub): """ return str - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs] def c_code_cache_version(self): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 0b0dcdfc2d..edb70ac0e9 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -243,7 +243,7 @@ def perform(self, node, inputs, out_storage): out += y out_storage[0][0] = np.asarray(out, dtype=y.dtype) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -316,7 +316,7 @@ def perform(self, node, inputs, output_storage): A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) output_storage[0][0] = A - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -941,7 +941,7 @@ def perform(self, node, inp, out): z += a * np.dot(x, y) zout[0] = z - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): z_shape, _, x_shape, y_shape, _ = input_shapes return [ ( @@ -1146,7 +1146,7 @@ def make_node(self, x, y): def perform(self, node, inputs, output_storage): output_storage[0][0] = np.dot(*inputs) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = """ @@ -1249,7 +1249,7 @@ def perform(self, node, inp, out): e.args = (*e.args, x.shape, y.shape) raise - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz @@ -1638,7 +1638,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index ecc4ad92d1..2c8a2c99bc 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,6 @@ from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType -from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op @@ -321,9 +320,7 @@ def make_node(self, *inputs): def batch_ndim(self, node: Apply) -> int: return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0])) - def infer_shape( - self, fgraph, node, input_shapes - ) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, input_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor import broadcast_shape from pytensor.tensor.shape import Shape_i @@ -354,13 +351,10 @@ def extract_core_shape_from_infer_shape(): return_dummy_inputs=True, propagate_unbatched_core_inputs=True, ) - dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) core_input_shapes = [ input_shape[batch_ndims:] for input_shape in input_shapes ] - core_output_shapes = core_op_infer_shape( - dummy_fgraph, dummy_core_node, core_input_shapes - ) + core_output_shapes = core_op_infer_shape(dummy_core_node, core_input_shapes) if not dummy_core_inputs: # All inputs are unbatched, so the core_shape can be used as is diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a2a6e2a89f..6366ec42d5 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -257,7 +257,7 @@ def perform(self, node, inp, out): new_shape.insert(augm, 1) out[0][0] = res.reshape(new_shape) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishp,) = shapes # transpose rval = [ishp[i] for i in self.shuffle] @@ -763,7 +763,7 @@ def _check_runtime_broadcast(node, inputs): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." ) - def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, i_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor.extra_ops import broadcast_shape out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) @@ -1434,7 +1434,7 @@ def perform(self, node, inp, out): output[0] = np.asarray(out, dtype=out_dtype) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes axis = self.axis if axis is None: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 61a8003cfd..34542e9bdb 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -150,7 +150,7 @@ def make_node(self, x, v, sorter=None): raise TypeError("sorter must be an integer vector", sorter.type) return Apply(self, [x, v, sorter], [out_type()]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] def perform(self, node, inputs, output_storage): @@ -340,7 +340,7 @@ def pullback(self, inputs, outputs, output_gradients): f'{type(self).__name__}: unknown gradient for mode "{self.mode}"' ) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes def c_code(self, node, name, inames, onames, sub): @@ -717,7 +717,7 @@ def pullback(self, inputs, outputs, gout): return [gx, disconnected_type()] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) @@ -849,7 +849,7 @@ def perform(self, node, inputs, out_): (out,) = out_ out[0] = np.bartlett(M) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): temp = node.inputs[0] M = ptb.switch(lt(temp, 0), ptb.cast(0, temp.dtype), temp) return [[M]] @@ -892,7 +892,7 @@ class FillDiagonal(Op): # See function fill_diagonal for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val): @@ -993,7 +993,7 @@ class FillDiagonalOffset(Op): # See function fill_diagonal_offset for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val, offset): @@ -1240,7 +1240,7 @@ def perform(self, node, inputs, output_storage): else: output_storage[0][0] = outs - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): [x_shape] = i0_shapes shape0_op = Shape_i(0) out_shapes = [(shape0_op(out),) for out in node.outputs] @@ -1310,7 +1310,7 @@ def make_node(self, indices, dims): [out_type() for _i in range(ptb.get_vector_length(dims))], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] * len(node.outputs) def perform(self, node, inp, out): @@ -1387,7 +1387,7 @@ def make_node(self, *inp): [out_type()], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def perform(self, node, inp, out): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 9d35955c6f..260fc2dc09 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -100,7 +100,7 @@ def make_node(self, a, n, axis): ], ) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): shape_a = in_shapes[0] n = node.inputs[1] axis = node.inputs[2] diff --git a/pytensor/tensor/linalg/constructors.py b/pytensor/tensor/linalg/constructors.py index 96ced60f9e..fbcc6d486a 100644 --- a/pytensor/tensor/linalg/constructors.py +++ b/pytensor/tensor/linalg/constructors.py @@ -34,7 +34,7 @@ def pullback(self, inputs, outputs, gout): ] return [gout[0][slc] for slc in slices] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): first, second = unzip(shapes, n=2, strict=True) return [(pt.add(*first), pt.add(*second))] diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 0fa7a34b3b..556ead9bf2 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -33,7 +33,7 @@ def __init__( if self.overwrite_a: self.destroy_map = {0: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def make_node(self, x): diff --git a/pytensor/tensor/linalg/decomposition/eigen.py b/pytensor/tensor/linalg/decomposition/eigen.py index f2afe7cf39..ba617d11f2 100644 --- a/pytensor/tensor/linalg/decomposition/eigen.py +++ b/pytensor/tensor/linalg/decomposition/eigen.py @@ -65,7 +65,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = w.astype(dtype, copy=False) outputs[1][0] = v.astype(dtype, copy=False) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shapes,) = shapes n, _ = x_shapes @@ -206,7 +206,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self return type(self)(**new_props) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n,), (n, n)] @@ -416,7 +416,7 @@ def make_node(self, a, b=None): w = vector(dtype=out_dtype, shape=(N,)) return Apply(self, inputs, [w]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [ (n,), diff --git a/pytensor/tensor/linalg/decomposition/lu.py b/pytensor/tensor/linalg/decomposition/lu.py index 2b4edb621b..db8ac60fe5 100644 --- a/pytensor/tensor/linalg/decomposition/lu.py +++ b/pytensor/tensor/linalg/decomposition/lu.py @@ -42,7 +42,7 @@ def __init__(self, *, permute_l=False, overwrite_a=False, p_indices=False): if self.overwrite_a: self.destroy_map = {0: [0]} if self.permute_l else {1: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] if self.permute_l: return [(n, n), (n, n)] @@ -258,7 +258,7 @@ def make_node(self, A): return Apply(self, [A], [LU, pivots]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n, n), (n,)] diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 9e9270259d..9d81e8eb7b 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -103,7 +103,7 @@ def make_node(self, x): return Apply(self, [x], outputs) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape diff --git a/pytensor/tensor/linalg/decomposition/schur.py b/pytensor/tensor/linalg/decomposition/schur.py index 737ca9dc1d..af3e2035ea 100644 --- a/pytensor/tensor/linalg/decomposition/schur.py +++ b/pytensor/tensor/linalg/decomposition/schur.py @@ -146,7 +146,7 @@ def perform(self, node, inputs, outputs): T_out[0] = T Z_out[0] = Z - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0], shapes[0]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": @@ -489,7 +489,7 @@ def perform(self, node, inputs, outputs): alpha_out[0] = alpha beta_out[0] = beta - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): A_shape, B_shape = shapes if self.return_eigenvalues: return [A_shape, B_shape, (A_shape[0],), (A_shape[0],), A_shape, B_shape] diff --git a/pytensor/tensor/linalg/decomposition/svd.py b/pytensor/tensor/linalg/decomposition/svd.py index 584eae1419..361b5cdd34 100644 --- a/pytensor/tensor/linalg/decomposition/svd.py +++ b/pytensor/tensor/linalg/decomposition/svd.py @@ -93,7 +93,7 @@ def perform(self, node, inputs, outputs): (s,) = outputs s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape K = ptm.minimum(M, N) diff --git a/pytensor/tensor/linalg/inverse.py b/pytensor/tensor/linalg/inverse.py index 82878c8777..6c205d6ae7 100644 --- a/pytensor/tensor/linalg/inverse.py +++ b/pytensor/tensor/linalg/inverse.py @@ -61,7 +61,7 @@ def pullback(self, inputs, outputs, g_outputs): ).T return [grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [list(reversed(shapes[0]))] @@ -159,7 +159,7 @@ def pushforward(self, inputs, outputs, eval_points): return [-matrix_dot(xi, ev, xi)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes @@ -187,7 +187,7 @@ def perform(self, node, inputs, outputs): (x,) = outputs x[0] = np.linalg.tensorinv(a, self.ind) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): sp = shapes[0][self.ind :] + shapes[0][: self.ind] return [sp] diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 42d49b10f4..542b174c4c 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -74,7 +74,7 @@ def pullback(self, inputs, outputs, output_grads): return [expm(aug)[..., :n, n:]] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] diff --git a/pytensor/tensor/linalg/solvers/core.py b/pytensor/tensor/linalg/solvers/core.py index 9805d86c8f..39a46e84a8 100644 --- a/pytensor/tensor/linalg/solvers/core.py +++ b/pytensor/tensor/linalg/solvers/core.py @@ -71,7 +71,7 @@ def make_node(self, A, b): x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): Ashape, Bshape = shapes rows = Ashape[1] if len(Bshape) == 1: diff --git a/pytensor/tensor/linalg/solvers/linear_control.py b/pytensor/tensor/linalg/solvers/linear_control.py index 14478f8e03..2fb47fac34 100644 --- a/pytensor/tensor/linalg/solvers/linear_control.py +++ b/pytensor/tensor/linalg/solvers/linear_control.py @@ -82,7 +82,7 @@ def perform(self, node, inputs, outputs_storage): Y *= scale X[0] = Y - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[2]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": diff --git a/pytensor/tensor/linalg/summary.py b/pytensor/tensor/linalg/summary.py index bba599e17f..c76753885f 100644 --- a/pytensor/tensor/linalg/summary.py +++ b/pytensor/tensor/linalg/summary.py @@ -71,7 +71,7 @@ def pullback(self, inputs, outputs, g_outputs): (x,) = inputs return [gz * self(x) * matrix_inverse(x).T] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def __str__(self): @@ -106,7 +106,7 @@ def perform(self, node, inputs, outputs): except Exception as e: raise ValueError("Failed to compute determinant", x) from e - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(), ()] def __str__(self): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..7e93db8c01 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -254,7 +254,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (3,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes if self.axis is None: return [()] @@ -3106,7 +3106,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [[xshp[0], yshp[1]]] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 34479e142c..7201d631e8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -306,7 +306,7 @@ def extract_batch_shape(p, ps, n): return shape - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): _, size, *dist_params = node.inputs _, _, *param_shapes = input_shapes diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index aab187a4ef..3a88a0a39f 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -64,7 +64,7 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override] return Apply(self, [x], [output_type]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape] = shapes joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int) return [self.output_shapes(input_shape, joined_shape)] @@ -185,7 +185,7 @@ def make_node(self, x, shape): ) return Apply(self, [x, shape], [output]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape, _] = shapes _, shape = node.inputs output_shapes = list(input_shape) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index ecb3435030..0a6d59c4be 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -106,8 +106,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): else: input_shapes = [tuple(inp.shape) for inp in node.inputs] core_shapes = [ - out_shape[batch_ndim:] - for out_shape in op.infer_shape(None, node, input_shapes) + out_shape[batch_ndim:] for out_shape in op.infer_shape(node, input_shapes) ] core_shapes = [ diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 2b2060c3e7..d4bbfa60e1 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -136,12 +136,10 @@ def get_node_infer_shape(self, node): shape_infer = self.default_infer_shape try: - o_shapes = shape_infer( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) + o_shapes = shape_infer(node, [self.shape_of[r] for r in node.inputs]) except ShapeError: o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] + node, [self.shape_of[r] for r in node.inputs] ) except NotImplementedError as e: raise NotImplementedError( @@ -162,7 +160,7 @@ def get_node_infer_shape(self, node): else: warn(msg) o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] + node, [self.shape_of[r] for r in node.inputs] ) return o_shapes @@ -231,7 +229,7 @@ def shape_tuple(self, r): return None return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - def default_infer_shape(self, fgraph, node, i_shapes): + def default_infer_shape(self, node, i_shapes): """Return a list of shape tuple or None for the outputs of node. This function is used for Ops that don't implement infer_shape. diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..74c445f2f5 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -81,7 +81,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(np.shape(x), dtype="int64") - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [[len(in_shapes[0])]] def connection_pattern(self, node): @@ -297,7 +297,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [()] def connection_pattern(self, node): @@ -452,7 +452,7 @@ def perform(self, node, inp, out_): ) out[0] = x - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshape, *_ = shapes shape = node.inputs[1:] # Use x shape if specified dim is None, otherwise the specified shape @@ -727,7 +727,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self(eval_points[0], *inputs[1:], return_list=True) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, maximum, mul # inputs[1] can contain at most one value of '-1', meaning the actual diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index c3133e3c15..51dd796a52 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -82,7 +82,7 @@ def make_node(self, in1, in2, full_mode): out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [in1, in2, full_mode], [out]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): _, _, full_mode = node.inputs in1_shape, in2_shape, _ = shapes out_shape = [ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index af695d9e42..c911be988d 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -54,7 +54,7 @@ def perform(self, node, inputs, output_storage): z = output_storage[0] z[0] = np.sort(a, axis, self.kind) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] @@ -185,7 +185,7 @@ def perform(self, node, inputs, output_storage): dtype=node.outputs[0].dtype, ) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 403331cef6..4638fa3e70 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -922,7 +922,7 @@ def perform(self, node, inputs, out_): cdata = unflatten_index_variables(index_variables, self.idx_list) out[0] = np.asarray(x.__getitem__(tuple(cdata))) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): def _is_constant(const, x): return isinstance(const, Constant) and const.data.item() == x @@ -1767,7 +1767,7 @@ def add_to_zview(self, name, x, fail): {fail}; }}""" - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def pushforward(self, inputs, outputs, eval_points): @@ -1945,7 +1945,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] @@ -2295,7 +2295,7 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = x - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, _y, _ilist = ishapes return [x] @@ -2469,7 +2469,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): def is_bool_index(idx): return ( isinstance(idx, np.bool_ | bool) @@ -2695,7 +2695,7 @@ def perform(self, node, inputs, out_): else: np.add.at(out[0], tuple(full_indices), y) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [ishapes[0]] def connection_pattern(self, node): diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 3e02c75ce9..09a8d8fe1f 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -30,7 +30,7 @@ class XTypeCastOp(TypeCastingOp): This is like a `ViewOp` but without the expectation the input and output have identical types. """ - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def vectorize_node( diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index a30ed6475d..1954c6cbb6 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -65,7 +65,7 @@ def test_infer_shape(self): x = dmatrix("x") y = dvector("y") - def infer_shape(fgraph, node, shapes): + def infer_shape(node, shapes): _x, y = shapes return [y] diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 741d5f01a5..b3c663b5f0 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -310,7 +310,7 @@ def grad(self, inputs, gout): else: return (gz,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def test_grad_fail(self): diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 9ab4e591f1..f017d539fe 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -162,7 +162,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - # def infer_shape(self, fgraph, node, (xshp,)): + # def infer_shape(self, node, (xshp,)): # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] identity_noshape = IdentityNoShape() @@ -179,7 +179,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - def infer_shape(self, fgraph, node, xshp_): + def infer_shape(self, node, xshp_): # Could also just return. (xshp,) = xshp_ return (xshp,) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index dfded3fbc3..c4023b3ea5 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -310,7 +310,7 @@ def perform(self, node, inputs, outputs): c[0] = np.arange(a.size + b.size, dtype=config.floatX) d[0] = np.arange(a.sum() + b.sum(), dtype=config.floatX) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # First output shape depends only on input_shapes # Second output shape depends on input values a_identity, b_identity = node.inputs @@ -362,7 +362,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): y = node.outputs[0] # Apparently it's valid to return integers in infer_shape. # DimShuffle does this. Modify test if that is no longer allowed. diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 913a1036ff..6b72864763 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -873,7 +873,7 @@ def test_partial_static_shape_info(self): x_inferred_shape = (ps.constant(1), ps.constant(1)) res_shape = z.owner.op.infer_shape( - None, z.owner, [x_inferred_shape, x_inferred_shape] + z.owner, [x_inferred_shape, x_inferred_shape] ) assert len(res_shape) == 1 @@ -902,7 +902,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), ) in_1_shape = (ps.constant(1), ps.constant(1)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_1_shape]) for out in outs: assert out[0].eval() == 1 assert out[1].eval() == 1 @@ -911,7 +911,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3)) ) in_2_shape = (ps.constant(3), ps.constant(3)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_2_shape]) for out in outs: assert out[0].eval() == 3 assert out[1].eval() == 3 @@ -924,7 +924,7 @@ def test_shape_types(self): assert isinstance(z.owner.op, Elemwise) - (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) + (out_shape,) = z.owner.op.infer_shape(z.owner, [(lscalar(), 1), (50, 10)]) assert all(isinstance(v.type, TensorType) for v in out_shape) From fa63548635e24a08891005c3621644f682559ce0 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 17:30:54 +0200 Subject: [PATCH 04/10] Do not leak stale variables in ShapeFeature --- pytensor/compile/builders.py | 81 +- pytensor/tensor/random/rewriting/basic.py | 2 +- pytensor/tensor/rewriting/basic.py | 4 +- pytensor/tensor/rewriting/shape.py | 936 +++++++--------------- pytensor/tensor/rewriting/subtensor.py | 25 +- pytensor/tensor/shape.py | 16 +- pytensor/tensor/utils.py | 15 +- tests/compile/test_builders.py | 4 +- tests/tensor/random/test_basic.py | 8 +- tests/tensor/rewriting/test_shape.py | 26 +- tests/xtensor/test_rewriting.py | 17 +- 11 files changed, 381 insertions(+), 753 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 97c438b0ea..910b8d842d 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from itertools import chain from typing import cast from pytensor.compile.maker import function @@ -24,26 +23,13 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError def infer_shape(outs, inputs, input_shapes): - """ - Compute the shape of the outputs given the shape of the inputs of an PyTensor - graph. - - We do it this way to avoid compiling the inner function just to get - the shape. Changes to ShapeFeature could require changes in this function. - - """ - # We use a ShapeFeature because it has all the necessary logic - # inside. We don't use the full ShapeFeature interface, but we - # let it initialize itself with an empty fgraph, otherwise we will - # need to do it manually - - # TODO: ShapeFeature should live elsewhere + """Compute the shape of the outputs given the shape of the inputs of a PyTensor graph.""" from pytensor.tensor.rewriting.shape import ShapeFeature for inp, inp_shp in zip(inputs, input_shapes, strict=True): @@ -51,43 +37,36 @@ def infer_shape(outs, inputs, input_shapes): assert len(inp_shp) == inp.type.ndim shape_feature = ShapeFeature() - fgraph = FunctionGraph([], [], features=[shape_feature]) - for v in chain.from_iterable(s for s in input_shapes if s is not None): - # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before - if (node := v.owner) is not None: - fgraph.import_node(node, import_missing=True) - - # Initialize shape_of with the input shapes - for inp, inp_shp in zip(inputs, input_shapes, strict=True): - shape_feature.set_shape(inp, inp_shp, override=True) + output_shapes = [shape_feature.shape_tuple(o) for o in outs] - def local_traverse(out): - """ - Go back in the graph, from out, adding computable shapes to shape_of. + # Shape expressions for root inputs are Shape_i(inp, i). + # Replace those with the caller-provided input_shapes. + replacements = {} + for inp, shp in zip(inputs, input_shapes, strict=True): + if shp is None: + continue + per_dim = shape_feature._shape_i_cache.get(inp) + if per_dim is None: + continue + for i, s in enumerate(shp): + cached = per_dim.get(i) + if cached is not None: + replacements[cached] = s + + if replacements: + flat = [s for tup in output_shapes if tup is not None for s in tup] + flat_replaced = graph_replace(flat, replacements, strict=False) + result = [] + idx = 0 + for tup in output_shapes: + if tup is None: + result.append(None) + else: + result.append(tuple(flat_replaced[idx : idx + len(tup)])) + idx += len(tup) + return result - """ - if out in shape_feature.shape_of: - # Its shape is already known - return - elif out.owner is None: - # This is an input of the graph - shape_feature.init_r(out) - else: - # Recurse over inputs - for inp in out.owner.inputs: - if inp not in shape_feature.shape_of: - local_traverse(inp) - - # shape_feature.on_import does not actually use an fgraph - # It will call infer_shape and set_shape appropriately - dummy_fgraph = None - shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy") - - ret = [] - for o in outs: - local_traverse(o) - ret.append(shape_feature.shape_of[o]) - return ret + return output_shapes def construct_nominal_fgraph( diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index b658cc98cb..ac01e83ef6 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -274,7 +274,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # Use shape_feature to facilitate inferring final shape. # Check that neither the RV nor the old Subtensor are in the shape graph. - output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None) + output_shape = shape_feature.shape_tuple(indexed_rv) if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)): return None diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 99522237ce..48ef217c73 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -124,8 +124,8 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: def get_simplified_shape(x: TensorVariable, *, fgraph) -> tuple: """Return a simplified shape tuple for ``x``: shape_feature → static → ``x.shape``.""" try: - return fgraph.shape_feature.shape_of[x] - except (AttributeError, KeyError): + return fgraph.shape_feature.shape_tuple(x) + except AttributeError: pass static_shape = x.type.shape diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index d4bbfa60e1..6369c2ad4d 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -1,6 +1,4 @@ -import traceback -from io import StringIO -from typing import cast as type_cast +from collections import deque from warnings import warn import numpy as np @@ -15,8 +13,6 @@ copy_stack_trace, node_rewriter, ) -from pytensor.graph.traversal import ancestors -from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.tensor.basic import ( Alloc, MakeVector, @@ -30,13 +26,12 @@ stack, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError, ShapeError +from pytensor.tensor.exceptions import ShapeError from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, register_useless, - topo_constant_folding, ) from pytensor.tensor.shape import ( Reshape, @@ -50,614 +45,289 @@ AdvancedIncSubtensor1, IncSubtensor, Subtensor, - get_idx_list, ) -from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -class ShapeFeature(Feature): - r"""A `Feature` that tracks shape information in a graph. - - This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with - `Shape_i` and `MakeVector` `Op`\s. - - This `Feature` and its associated rewrites have several goals: - - 1. to "lift" `Shape`\s to as close to the inputs as possible, - 2. to infer the shape of every node in the graph in terms of the - input shapes, and - 3. remove fill `Op`\s (e.g. `Second`) from the graph. - - Lifting shapes as close to the inputs as possible is important for - canonicalization because it is very bad form to have to compute - something just to know how big it will be. Firstly, it is a waste - of time to compute such outputs. But it is important to get rid - of these outputs as early as possible in the compilation process - because the extra computations make it appear as if many internal - graph nodes have multiple clients. Many rewrites refuse to - work on nodes with multiple clients. - - Lifting is done by using an `.infer_shape` function if one is - present, or else using a conservative default. An Op that - supports shape-lifting should define a infer_shape(self, fgraph, node, - input_shapes) function. The argument input_shapes is a tuple of - tuples... there is an interior tuple for each input to the node. - The tuple has as many elements as dimensions. The element in - position i of tuple j represents the i'th shape component of the - j'th input. The function should return a tuple of tuples. One - output tuple for each node.output. Again, the i'th element of the - j'th output tuple represents the output[j].shape[i] of the - function. If an output is not a TensorType, then None should be - returned instead of a tuple for that output. - - For example the infer_shape for a matrix-matrix product would accept - input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). - - Inferring the shape of internal nodes in the graph is important - for doing size-driven rewrites. If we know how big various - intermediate results will be, we can estimate the cost of many Ops - accurately, and generate c-code that is specific [e.g. unrolled] - to particular sizes. - - In cases where you cannot figure out the shape, raise a ShapeError. +class _ShapeOfProxy: + """Dict-like proxy so ``shape_feature.shape_of[var]`` keeps working.""" - Notes - ----- - To use this shape information in rewrites, use the - ``shape_of`` dictionary. + def __init__(self, feature): + self._feature = feature - For example: + def __getitem__(self, var): + result = self._feature.shape_tuple(var) + if result is None: + raise KeyError(var) + return result - .. code-block:: python + def __contains__(self, var): + return hasattr(var.type, "ndim") - try: - shape_of = fgraph.shape_feature.shape_of - except AttributeError: - # This can happen when the mode doesn't include the ShapeFeature. - return - shape_of_output_zero = shape_of[node.output[0]] +class ShapeFeature(Feature): + r"""Lazy `Feature` that provides shape information for the graph it tracks. - The ``shape_of_output_zero`` symbol will contain a tuple, whose - elements are either integers or symbolic integers. + Shapes are derived on demand by calling each ``Op.infer_shape`` on the + current (live) node inputs, recursing toward the graph inputs. Use: - TODO: check to see if the symbols are necessarily - non-constant... or are integer literals sometimes PyTensor - constants?? That would be confusing. + - ``get_shape(var, i)`` / ``shape_tuple(var)`` for the shape expressed in + terms of the graph inputs (recursing through the ancestors of ``var``); + - ``same_shape(x, y)`` to statically compare two shapes. + Inferred shapes are cached per node, but the cache is invalidated whenever + an ancestor input changes (``on_change_input``), so the expressions handed + out always reference live graph variables. """ - def get_node_infer_shape(self, node): - try: - shape_infer = node.op.infer_shape - except AttributeError: - shape_infer = self.default_infer_shape - - try: - o_shapes = shape_infer(node, [self.shape_of[r] for r in node.inputs]) - except ShapeError: - o_shapes = self.default_infer_shape( - node, [self.shape_of[r] for r in node.inputs] - ) - except NotImplementedError as e: - raise NotImplementedError( - "Code called by infer_shape failed raising a " - "NotImplementedError. Raising NotImplementedError to " - "indicate that a shape cannot be computed is no longer " - "supported, and one should now use ShapeError " - f"instead. The original exception message is: {e}" - ).with_traceback(e.__traceback__) - except Exception as e: - msg = ( - f"Failed to infer_shape from Op {node.op}.\nInput shapes: " - f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " - f"{type(e)}\nException message: {e!s}\nTraceback: {traceback.format_exc()}" - ) - if config.on_shape_error == "raise": - raise Exception(msg).with_traceback(e.__traceback__) - else: - warn(msg) - o_shapes = self.default_infer_shape( - node, [self.shape_of[r] for r in node.inputs] - ) - - return o_shapes - - def get_shape(self, var, idx): - """Rewrites can call this to get a `Shape_i`. - - It is better to call this then use directly ``shape_of[var][idx]`` - as this method should update `shape_of` if needed. - - TODO: Up to now, we don't update it in all cases. Update in all cases. - """ - r = self.shape_of[var][idx] - if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables - ): - assert var.owner - node = var.owner - # recur on inputs - for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) - assert len(o_shapes) == len(node.outputs) - - # Only change the variables and dimensions that would introduce - # extra computation - for new_shps, out in zip(o_shapes, node.outputs, strict=True): - if not hasattr(out.type, "ndim"): - continue - - merged_shps = list(self.shape_of[out]) - changed = False - for i in range(out.type.ndim): - n_r = merged_shps[i] - if ( - n_r.owner - and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables - ): - changed = True - merged_shps[i] = new_shps[i] - if changed: - self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") + def __init__(self): + self.fgraph: FunctionGraph | None = None + # node -> tuple of (tuple of shape vars) per output, lazily populated. + # The cached shape of a node embeds the shapes of its whole ancestor cone, + # so a change anywhere above a cached node invalidates it. + self._cache: dict = {} + # Nodes whose input changed since the last query. Invalidation is deferred + # to the next query, which walks ``fgraph.clients`` downstream from these + # nodes; the live graph already encodes the dependencies, so no reverse map + # is kept. + self._stale: set = set() + # var -> {i: Shape_i(i)(v)}, ensures Apply identity for leaves + self._shape_i_cache: dict = {} + self.lscalar_one = constant(1, dtype="int64") + # Compat: scheduled replacements for local_track_shape_i + self.scheduled: dict = {} + + def _shape_i_var(self, v, i): + per_dim = self._shape_i_cache.get(v) + if per_dim is not None: + cached = per_dim.get(i) + if cached is not None: + return cached else: - s = Shape_i(i)(r) - try: - s = get_scalar_constant_value(s) - except NotScalarConstantError: - pass - return s + per_dim = {} + self._shape_i_cache[v] = per_dim + if hasattr(v.type, "shape") and v.type.shape[i] is not None: + res = constant(v.type.shape[i], dtype="int64") + else: + res = Shape_i(i)(v) + per_dim[i] = res + return res + + def _coerce_shape_element(self, element, node): + """Validate and normalize a single shape element from infer_shape.""" + if isinstance(element, np.ndarray): + if element.ndim != 0: + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"ndarray for shape element: {element!r}" + ) + element = element.item() + if isinstance(element, Variable): + if element.type.dtype not in integer_dtypes: + raise TypeError( + f"infer_shape for {node.op} returned a non-integer " + f"Variable for shape element: {element!r}" + ) + if getattr(element.type, "ndim", 0): + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"Variable for shape element: {element!r}" + ) + if element.type.dtype != "int64": + if isinstance(element, Constant): + return constant(int(element.data), dtype="int64") + return cast(element, "int64") + return element + if isinstance(element, int | np.integer): + if int(element) < 0: + raise ValueError( + f"infer_shape for {node.op} returned a negative shape: {int(element)}" + ) + return constant(int(element), dtype="int64") + raise TypeError( + f"infer_shape for {node.op} returned an unsupported " + f"shape element of type {type(element).__name__}: {element!r}" + ) - def shape_tuple(self, r): - """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): - # This happen for NoneConst. - return None - return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) + def _get_node_shapes(self, node): + """Return validated per-output shape tuples for ``node``.""" + if self._stale: + self._flush_stale() + + cache = self._cache + cached = cache.get(node) + if cached is not None: + return cached + + # Fill the cache for ``node`` and its ancestor cone bottom-up with an + # explicit stack. Recursing through ``get_shape`` would overflow the + # Python stack on deep graphs; here a node is computed only once every + # input-producing node it needs is cached, so ``_compute_node_shapes`` + # reads its input shapes straight from the cache instead of recursing. + stack = [node] + while stack: + top = stack[-1] + if top in cache: + stack.pop() + continue + ready = True + for inp in top.inputs: + inp_node = inp.owner + if ( + inp_node is not None + and inp_node not in cache + and hasattr(inp.type, "ndim") + # A fully-static input shape is read without touching its + # owner (see ``get_shape``), so don't bother caching it. + and any(s is None for s in inp.type.shape) + ): + stack.append(inp_node) + ready = False + if ready: + stack.pop() + cache[top] = self._compute_node_shapes(top) - def default_infer_shape(self, node, i_shapes): - """Return a list of shape tuple or None for the outputs of node. + return cache[node] - This function is used for Ops that don't implement infer_shape. - Ops that do implement infer_shape should use the i_shapes parameter, - but this default implementation ignores it. + def _compute_node_shapes(self, node): + """Call ``infer_shape`` for ``node`` and validate the result. + The caller must have cached the input-producing nodes already, so + ``get_shape`` resolves them from the cache without recursing. """ - rval = [] - for r in node.outputs: - try: - rval.append(self.shape_tuple(r)) - except AttributeError: - rval.append(None) - return rval - - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. - - The s_i argument was produced by the infer_shape() of an Op subclass. - - var: the variable that correspond to s_i. This is just for - error reporting. + input_shapes = [] + for inp in node.inputs: + if hasattr(inp.type, "ndim"): + input_shapes.append( + tuple(self.get_shape(inp, j) for j in range(inp.type.ndim)) + ) + else: + input_shapes.append(None) - """ - assert s_i is not None - - if s_i == 1: - return self.lscalar_one - if isinstance(s_i, float) and int(s_i) == s_i: - s_i = int(s_i) - if isinstance(s_i, np.integer | int) or ( - isinstance(s_i, np.ndarray) and s_i.ndim == 0 - ): - # this shape is a constant - if s_i < 0: - msg = "There is a negative shape in the graph!" - msg += get_variable_trace_string(var) - # The rest of the pipeline don't handle correctly this - # case. So we have 2 choices, stop compilation or - # consider the shape as unknown. As we have more - # chance to give the stack trace here then later, I - # choose that options as it would give better error - # message. - raise AssertionError(msg) - return constant(s_i, dtype="int64") - if isinstance(s_i, tuple | list): - # this dimension is the same as many of the inputs - # which tells us that if one of the inputs is known, - # the others all become known. - # TODO: should be implemented in Elemwise, and Dot - # - # worst case, we loop over shape_of and replace things - raise NotImplementedError(s_i) - - # s_i is x.shape[i] for some x, we change it to shape_of[x][i] - if ( - s_i.owner - and isinstance(s_i.owner.op, Subtensor) - and s_i.owner.inputs[0].owner - and isinstance(s_i.owner.inputs[0].owner.op, Shape) - ): - assert s_i.type.ndim == 0 - assert len(s_i.owner.op.idx_list) == 1 - - # The current Subtensor always put constant index in the graph. - # This was not True in the past. So call the Subtensor function - # that will return the right index. - idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) - assert len(idx) == 1 - idx = idx[0] + output_shapes = None + shape_infer = getattr(node.op, "infer_shape", None) + if shape_infer is not None: try: - i = get_scalar_constant_value(idx) - except NotScalarConstantError: + output_shapes = shape_infer(node, input_shapes) + except ShapeError: pass - else: - # Executed only if no exception was raised - x = s_i.owner.inputs[0].owner.inputs[0] - # x should already have been imported, and should be in shape_of. - s_i = self.shape_of[x][i] - - if s_i.type.dtype in integer_dtypes: - if getattr(s_i.type, "ndim", 0): - raise TypeError("Shape element must be scalar", s_i) - return s_i - else: - raise TypeError( - "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) - ) - - def set_shape(self, r, s, override=False): - """Assign the shape `s` to previously un-shaped variable `r`. - - Parameters - ---------- - r : a variable - s : None or a tuple of symbolic integers - override : If False, it mean r is a new object in the fgraph. - If True, it mean r is already in the fgraph and we want to - override its shape. - - """ - if not override: - assert r not in self.shape_of, "r already in shape_of" - if s is None: - self.shape_of[r] = s - else: - if not isinstance(s, tuple | list): - raise TypeError("shapes must be tuple/list", (r, s)) - - if r.type.ndim != len(s): - sio = StringIO() - pytensor.printing.debugprint(r, file=sio, print_type=True) - raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" + except NotImplementedError: + pass + except Exception as exc: + if config.on_shape_error == "raise": + raise + warn( + f"Failed to infer_shape from Op {node.op}: " + f"{type(exc).__name__}: {exc}" ) - shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals( - get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + result = [] + for k, out in enumerate(node.outputs): + if not hasattr(out.type, "ndim"): + result.append(None) + continue + sh = None + if output_shapes is not None and k < len(output_shapes): + sh = output_shapes[k] + if sh is None or not isinstance(sh, list | tuple): + result.append( + tuple(self._shape_i_var(out, j) for j in range(out.type.ndim)) ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. - - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. - - """ - # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] + continue + coerced = [] + for j, s in enumerate(sh): + coerced.append(self._coerce_shape_element(s, node)) + result.append(tuple(coerced)) - # If other_shape has no information, call is pointless. - if other_shape is None: - return + return tuple(result) - if r in self.shape_of: - r_shape = self.shape_of[r] - else: - # If no info is known on r's shape, use other_shape - self.set_shape(r, other_shape) - return - if ( - other_r.owner - and r.owner - and other_r.owner.inputs == r.owner.inputs - and other_r.owner.op == r.owner.op - ): - # We are doing a merge, so the two shape graphs will be the - # same. This is only done so that we call `ancestors` less - # frequently. - return - - # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] - for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner - and isinstance(ps.owner.op, Shape_i) - and ps.owner.op.i == i - and ps.owner.inputs[0] in (r, other_r) - ): - # If other_shape[i] is uninformative, use r_shape[i]. - # For now, we consider 2 cases of uninformative other_shape[i]: - # - Shape_i(i)(other_r); - # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(other_shape[i]) - elif other_shape[i] == r_shape[i]: - # This mean the shape is equivalent - # We do not want to do the ancestor check in those cases - merged_shape.append(r_shape[i]) - elif any( - ( - r_shape[i] == anc - or ( - anc.owner - and isinstance(anc.owner.op, Shape) - and anc.owner.inputs[0] == r - ) - ) - for anc in ancestors([other_shape[i]]) - ): - # Another case where we want to use r_shape[i] is when - # other_shape[i] actually depends on r_shape[i]. In that case, - # we do not want to substitute an expression with another that - # is strictly more complex. Such a substitution could also lead - # to cycles: if (in the future) r_shape[i] gets replaced by an - # expression of other_shape[i], other_shape[i] may end up - # depending on itself. - merged_shape.append(r_shape[i]) - else: - merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "shape") - or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1) - ) - or self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - get_scalar_constant_value( - merged_shape[i], - only_process_constants=True, - raise_not_constant=False, - ) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def set_shape_i(self, r, i, s_i): - """Replace element i of shape_of[r] by s_i""" - assert r in self.shape_of - prev_shape = self.shape_of[r] - # prev_shape is a tuple, so we cannot change it inplace, - # so we build another one. - new_shape = [] - for j, s_j in enumerate(prev_shape): - if j == i: - new_shape.append(self.unpack(s_i, r)) - else: - new_shape.append(s_j) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[idx] != 1 - or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals( - get_scalar_constant_value(new_shape[idx], raise_not_constant=False) - ) - for idx in range(r.type.ndim) + def get_shape(self, var, idx): + """Return a symbolic expression for ``var.shape[idx]``.""" + if hasattr(var.type, "shape") and var.type.shape[idx] is not None: + return constant(var.type.shape[idx], dtype="int64") + + node = var.owner + if node is None: + return self._shape_i_var(var, idx) + + node_shapes = self._get_node_shapes(node) + out_idx = node.outputs.index(var) + sh = node_shapes[out_idx] + if sh is not None: + return sh[idx] + return self._shape_i_var(var, idx) + + def shape_tuple(self, var): + if not hasattr(var.type, "ndim"): + return None + return tuple(self.get_shape(var, i) for i in range(var.type.ndim)) + + @property + def shape_of(self): + """Deprecated back-compat shim. Use ``shape_tuple(var)`` instead.""" + warn( + "ShapeFeature.shape_of is deprecated; use shape_tuple(var) instead.", + DeprecationWarning, + stacklevel=2, ) - self.shape_of[r] = tuple(new_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def init_r(self, r): - """Register r's shape in the shape_of dictionary.""" - if r not in self.shape_of: - self.set_shape(r, self.shape_tuple(r)) - - def make_vector_shape(self, r): - return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") + return _ShapeOfProxy(self) def on_attach(self, fgraph): if hasattr(fgraph, "shape_feature"): raise AlreadyThere("This FunctionGraph already has a ShapeFeature") - - if hasattr(self, "fgraph") and self.fgraph != fgraph: + if self.fgraph is not None and self.fgraph is not fgraph: raise Exception("This ShapeFeature is already attached to a graph") - self.fgraph = fgraph - fgraph.shape_feature = self - # Must be local to the object as otherwise we reuse the same - # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") - assert self.lscalar_one.type.dtype == "int64" - - self.fgraph = fgraph - # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} - # Variable -> - self.scheduled = {} - # shape var -> graph v - self.shape_of_reverse_index = {} - - for node in fgraph.toposort(): - self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): - self.shape_of = {} - self.scheduled = {} - self.shape_of_reverse_index = {} + self._cache.clear() + self._stale.clear() + self._shape_i_cache.clear() + self.scheduled.clear() self.fgraph = None - del fgraph.shape_feature - - def on_import(self, fgraph, node, reason): - if node.outputs[0] in self.shape_of: - # this is a revert, not really an import - for r in node.outputs + node.inputs: - assert r in self.shape_of - return - - for i, r in enumerate(node.inputs): - # make sure we have shapes for the inputs - self.init_r(r) - - o_shapes = self.get_node_infer_shape(node) + if hasattr(fgraph, "shape_feature"): + del fgraph.shape_feature - # this is packed information - # an element of o_shapes is either None or a tuple - # elements of the tuple can be either strings, or ints - if len(o_shapes) != len(node.outputs): - raise Exception( - f'The infer_shape method for the Op "{node.op}" returned a list ' - f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " - f" != len(node.outputs) = {len(node.outputs)}" - ) + def _flush_stale(self): + """Drop cached shapes for the changed nodes and everything downstream of them. - # Ensure shapes are in 'int64'. This is to make sure the assert - # found in the `local_useless_subtensor` rewrite does not fail. - for sh_idx, sh in enumerate(o_shapes): - if sh is None: - continue - if not isinstance(sh, list | tuple): - raise ValueError( - f"infer_shape of {node} didn't return a list of" - f" list. It returned '{o_shapes}'" - ) - new_shape = [] - for i, d in enumerate(sh): - # Note: we ignore any shape element that is not typed (i.e., - # does not have a 'dtype' attribute). This means there may - # still remain int elements that are int32 on 32-bit platforms, - # but this works with `local_useless_subtensor`, so for now we - # keep it this way. See #266 for a better long-term fix. - if getattr(d, "dtype", "int64") != "int64": - assert d.dtype in discrete_dtypes, (node, d.dtype) - assert str(d.dtype) != "uint64", node - new_shape += sh[len(new_shape) : i + 1] - if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") - else: - casted_d = cast(d, "int64") - new_shape[i] = casted_d - if new_shape: - # We replace the shape with wrong dtype by the one with - # 'int64'. - new_shape += sh[len(new_shape) :] - o_shapes[sh_idx] = tuple(new_shape) - - for r, s in zip(node.outputs, o_shapes, strict=True): - self.set_shape(r, s) + Walks ``fgraph.clients`` from each stale node, so a change above a cached node + evicts it no matter how many levels down it sits. + """ + cache = self._cache + clients = self.fgraph.clients + queue = deque(self._stale) + seen = set(self._stale) + self._stale = set() + while queue: + node = queue.popleft() + cache.pop(node, None) + for out in node.outputs: + for client_node, _ in clients.get(out, ()): + if client_node not in seen: + seen.add(client_node) + queue.append(client_node) + + def on_prune(self, fgraph, node, reason): + self._cache.pop(node, None) + self._stale.discard(node) + for out in node.outputs: + self._shape_i_cache.pop(out, None) def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) - - # This tells us that r and new_r must have the same shape if - # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) - - # change_input happens in two cases: - # 1) we are trying to get rid of r, or - # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - # This schedule is processed by `local_track_shape_i`. - for shpnode, idx in fgraph.clients[r] + [(node, i)]: - if isinstance(shpnode.op, Shape_i): - idx = shpnode.op.i - repl = self.shape_of[new_r][idx] - if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. - continue - if ( - repl.owner - and repl.owner.inputs[0] is shpnode.inputs[0] - and isinstance(repl.owner.op, Shape_i) - and repl.owner.op.i == shpnode.op.i - ): - # The replacement is a shape_i of the same - # input. So no need to do this equivalent - # replacement. - continue + if r is new_r: + return + # Defer invalidation: the next query flushes this node and its downstream cone. + self._stale.add(node) - if shpnode.outputs[0] in ancestors([repl]): - raise InconsistencyError( - "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" - ) - - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, - # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] - for k in unscheduled: - del self.scheduled[k] - - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): - # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + # Schedule Shape_i(r) replacements for local_track_shape_i + if hasattr(r.type, "ndim"): + for shpnode, _idx in fgraph.clients.get(r, []): + if isinstance(getattr(shpnode, "op", None), Shape_i): + self.scheduled[shpnode] = new_r def same_shape( self, @@ -666,63 +336,27 @@ def same_shape( dim_x: int | None = None, dim_y: int | None = None, ) -> bool: - """Return ``True`` if `x` and `y` have the same shape. - - Parameters - ========== - x - The `Variable` for which its shape is to be compared with `y`'s shape. - y - The `Variable` for which its shape is to be compared with `x`'s shape. - dim_x - If non ``None``, compare only the dimension of `x` equal to - `dim_x`. - dim_y - If non ``None``, compare only the dimension of `y` equal to - `dim_y`. - - """ - sx = self.shape_of[x] - sy = self.shape_of[y] - + """Return True if we can statically prove x and y have the same shape.""" + sx = self.shape_tuple(x) + sy = self.shape_tuple(y) if sx is None or sy is None: return False - if dim_x is not None: - sx = [sx[dim_x]] - + sx = (sx[dim_x],) if dim_y is not None: - sy = [sy[dim_y]] - + sy = (sy[dim_y],) if len(sx) != len(sy): return False - - # Canonicalize the graphs so that comparisons are reasonable - # TODO FIXME: This should *not* need to be performed manually here. - # Instead, the shape information in `self.shape_of` should be operated - # upon alongside all the other elements in a `FunctionGraph` (e.g. as - # if `self.shape_of.values()` were additional outputs). - shapes_fg = FunctionGraph( - outputs=sx + sy, - # features=[self], - clone=True, - # copy_inputs=False, - ) - from pytensor.graph.rewriting.utils import rewrite_graph - - canon_shapes_fg = type_cast( - FunctionGraph, - rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), - ) - canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - for dx, dy in zip(sx, sy, strict=True): - if not equal_computations([dx], [dy]): + if dx is dy: + continue + if isinstance(dx, Constant) and isinstance(dy, Constant): + if dx.data == dy.data: + continue return False - + if equal_computations([dx], [dy]): + continue + return False return True def clone(self): @@ -1300,7 +934,7 @@ def local_shape_to_shape_i(fgraph, node): if not hasattr(fgraph, "shape_feature"): return shape_feature = fgraph.shape_feature - ret = shape_feature.make_vector_shape(node.inputs[0]) + ret = as_tensor_variable(shape_feature.shape_tuple(node.inputs[0]), dtype="int64") # We need to copy over stack trace from input to output copy_stack_trace(node.outputs[0], ret) @@ -1312,44 +946,40 @@ def local_shape_to_shape_i(fgraph, node): @register_canonicalize @node_rewriter([Shape_i]) def local_track_shape_i(fgraph, node): - """ - Update `Shape_i` nodes to match `ShapeFeature`'s internal state. - - This rewrite is essential for propagating shape information during graph - transformations (like lowering). When a node is replaced or updated, - `ShapeFeature` calculates the shape of the new node and "schedules" - dependent `Shape_i` nodes for update, so they use the latest inferred graph. - - If we start with an fgraph containing the two nodes below: - >> out = OpWithoutInferShape(a, b) - >> out_shape_i = Shape_i(out) + """Replace ``Shape_i(v, i)`` with the inferred shape expression. - And then rewrite - >> new_out = OpWithInferShape(a, b) - >> fgraph.replace(out, new_out) - - We end up with - >> out_shape_i == Shape_i(new_out) + When ``v.owner.op`` has ``infer_shape``, ``get_shape(v, i)`` returns + a non-``Shape_i`` expression. Rewriting the literal ``Shape_i(v, i)`` + with that expression lets downstream rewrites see the inferred form + and typically lets the original producer of ``v`` be pruned when only + its shape is consumed. + """ + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is None: + return False - If installed, ShapeFeature will do this work in the background - >> new_out_shape = infer_shape(new_out) # Usually some f(a, b) - >> fgraph.shape_feature.scheduled[out_shape_i.owner] = new_out_shape + # Handle scheduled replacements from on_change_input + replacement = shape_feature.scheduled.pop(node, None) + if replacement is not None: + return [shape_feature.get_shape(replacement, node.op.i)] - And this rewrite will ultimately propagate the inference back to the fgraph - >> new_out_shape_i = fgraph.shape_feature.scheduled[out_shape_i.owner][i] - >> fgraph.replace(out_shape_i, new_out_shape_i) + [v] = node.inputs + if v.owner is None: + return False - """ - try: - shape_feature = fgraph.shape_feature - except AttributeError: + i = node.op.i + new_shape = shape_feature.get_shape(v, i) + if new_shape is None: return False - if node not in shape_feature.scheduled: + # Avoid replacing Shape_i(v, i) with itself + if new_shape.owner is node or ( + isinstance(new_shape, Variable) + and new_shape.owner is not None + and isinstance(new_shape.owner.op, Shape_i) + and new_shape.owner.op.i == i + and new_shape.owner.inputs[0] is v + ): return False - # Don't unschedule node as it could be reinserted in the - # fgraph as we don't change it in the shapefeature internal - # structure. - replacement = shape_feature.scheduled[node] - return [shape_feature.shape_of[replacement][node.op.i]] + return [new_shape] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index ef51cae306..e045648816 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -883,12 +883,12 @@ def _local_subtensor_merge_rewrite(fgraph, node, *, merge_integer_index): indices_outer = unflatten_index_variables(outer_index_vars, node.op.idx_list) try: - xshape = fgraph.shape_feature.shape_of[x] + xshape = fgraph.shape_feature.shape_tuple(x) except AttributeError: xshape = tuple(x.shape) try: - ushape = fgraph.shape_feature.shape_of[u] + ushape = fgraph.shape_feature.shape_tuple(u) except AttributeError: ushape = tuple(u.shape) @@ -1201,7 +1201,7 @@ def local_useless_subtensor(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature cdata = get_constant_idx( node.op.idx_list, @@ -1223,7 +1223,7 @@ def local_useless_subtensor(fgraph, node): # is not a useless subtensor return False - length_pos = shape_of[node.inputs[0]][pos] + length_pos = shape_feature.get_shape(node.inputs[0], pos) if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize @@ -1327,12 +1327,12 @@ def local_useless_AdvancedSubtensor1(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature # get length of the indexed tensor along the first axis try: length = get_scalar_constant_value( - shape_of[node.inputs[0]][0], only_process_constants=True + shape_feature.get_shape(node.inputs[0], 0), only_process_constants=True ) except NotScalarConstantError: return False @@ -2417,7 +2417,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): # need it for this optimization, so don't continue. return False - shape_of = shape_feature.shape_of same_shape = shape_feature.same_shape # Get the subtensor of `x` indexed by `i` in order to compare @@ -2431,22 +2430,12 @@ def local_useless_inc_subtensor_alloc(fgraph, node): else: raise Exception("Should never happen!") - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - # `xi` may have more dimensions than `y` since the subtensor ops # do automatic broadcasting of the increment internally. Thus, we # need to make the leading implicitly broadcasted dimensions # explicit for shape comparison later. if xi.ndim > y.ndim: y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") # Build `z_broad` explicitly to include extra implicit dimensions. z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable @@ -2479,7 +2468,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): if ( z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 + and shape_feature.get_shape(y, k) != 1 ) ] diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 74c445f2f5..1b43eaec6e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -339,21 +339,7 @@ def shape_i(var, i, fgraph=None): """ if fgraph and hasattr(fgraph, "shape_feature"): - shape_feature = fgraph.shape_feature - shape_of = shape_feature.shape_of - - def recur(node): - if node.outputs[0] not in shape_of: - for inp in node.inputs: - if inp.owner: - recur(inp.owner) - # If the output var isn't marked as being in the graph, - # we need to add it in the ShapeFeature. - shape_feature.on_import(fgraph, node, "graph.ops.shape_i") - - if var not in shape_of: - recur(var.owner) - return shape_of[var][i] + return fgraph.shape_feature.get_shape(var, i) # If we are not able to use the shape feature, we should not put # Shape_i in the graph. Otherwise, the shape feature optimization diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 1a7e681c22..8662faded3 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -83,11 +83,16 @@ def shape_of_variables( shape_feature = fgraph.shape_feature input_dims = [ - dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp] + dimension + for inp in fgraph.inputs + for dimension in shape_feature.shape_tuple(inp) ] output_dims = [ - dimension for shape in shape_feature.shape_of.values() for dimension in shape + dimension + for var in fgraph.variables + if hasattr(var.type, "ndim") + for dimension in shape_feature.shape_tuple(var) ] compute_shapes = pytensor.function(input_dims, output_dims) @@ -105,8 +110,10 @@ def shape_of_variables( sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) l = {} - for var in shape_feature.shape_of: - l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var]) + for var in fgraph.variables: + shape = shape_feature.shape_tuple(var) + if shape is not None: + l[var] = tuple(sym_to_num_dict[sym] for sym in shape) return l diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index f8180773cc..c6f1f1e74c 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -460,8 +460,8 @@ def test_infer_shape(self): fg = FunctionGraph(outputs=[op_var[1]], clone=False) opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer()) - assert opt_res.shape_feature.shape_of[x] is None - assert opt_res.shape_feature.shape_of[z][0].data == 2 + assert opt_res.shape_feature.shape_tuple(x) is None + assert opt_res.shape_feature.shape_tuple(z)[0].data == 2 def test_make_node_shared(self): """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`.""" diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..96cecdc333 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -284,7 +284,7 @@ def test_normal_ShapeFeature(): clone=False, features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt, sd_pt], [s1, s2, d_rv], mode=py_mode, on_unused_input="ignore") s1_val, s2_val, d_rv_val = f(3, np.array(1.0, dtype=config.floatX)) @@ -657,7 +657,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt], [s1, s2], mode=py_mode) s1_val, s2_val = f(2) @@ -679,7 +679,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2, s3, s4 = fg.shape_feature.shape_of[d_rv] + s1, s2, s3, s4 = fg.shape_feature.shape_tuple(d_rv) mean_val = np.array([[0, 1, 2]], dtype=config.floatX) f = function([mean, cov], [s1, s2, s3, s4], mode=py_mode, on_unused_input="ignore") @@ -810,7 +810,7 @@ def test_dirichlet_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) assert M_pt in graph_inputs([s1]) assert N_pt in graph_inputs([s2]) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index f017d539fe..1b8696dd66 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -17,7 +17,7 @@ from pytensor.graph.type import Type from pytensor.tensor.basic import alloc, as_tensor_variable from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import add, exp, maximum +from pytensor.tensor.math import add, cos, exp, maximum, sin from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.shape import ( ShapeFeature, @@ -613,6 +613,30 @@ def test_vector_dim_err(self): shape_feature.same_shape(x, o, 0, 1) +def test_get_shape_resolves_through_chain(): + """get_shape should resolve to the deepest input, not intermediate ops.""" + x = matrix("x") + w = matrix("w") + inner = cos(x) + y = sin(exp(inner.T)) + + fg = FunctionGraph([x, w], [y], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + + s = sf.get_shape(y, 0) + utt.assert_equal_computations([s], [Shape_i(1)(x)]) + + # Changing an input invalidates the cached shape of the changed node and of + # every node downstream of it, not just the node whose input changed. Here + # ``inner`` feeds the transpose, with ``exp`` and ``sin`` further downstream, + # so the re-query must resolve through ``w`` rather than returning the stale + # ``x``-based shape held by the downstream nodes' caches. + fg.replace(inner, cos(w)) + s = sf.get_shape(y, 0) + utt.assert_equal_computations([s], [Shape_i(1)(w)]) + + def test_useless_specify_shape(): x = tensor("x", shape=(None, 5, 3)) diff --git a/tests/xtensor/test_rewriting.py b/tests/xtensor/test_rewriting.py index da076b1824..2da37fa919 100644 --- a/tests/xtensor/test_rewriting.py +++ b/tests/xtensor/test_rewriting.py @@ -17,8 +17,21 @@ def test_infer_shape_db_handles_xtensor_lowering(): [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [(x.values.sum(0)).shape[0]]) - # With ShapeFeature - fgraph = FunctionGraph([x], [shape_y], features=[ShapeFeature()], copy_inputs=False) + # With ShapeFeature — force caching shape of XRV output before lowering + sf = ShapeFeature() + fgraph = FunctionGraph([x], [shape_y], features=[sf], copy_inputs=False) + # Force get_shape on the XRV sum output (y) before any rewriting lowers it. + # This caches a shape expression referencing the XRV variable. + y_in_graph = [ + v + for v in fgraph.variables + if hasattr(v.type, "ndim") and v.type.ndim == 1 and v is not x + ] + for v in y_in_graph: + try: + sf.get_shape(v, 0) + except Exception: + pass infer_shape_db.default_query.rewrite(fgraph) [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [Shape_i(1)(x)]) From 7041e88e6afdc0705a11b31214e923d985d5d2e4 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 28 May 2026 14:57:56 +0200 Subject: [PATCH 05/10] Cache OFG shape graph --- pytensor/compile/builders.py | 64 ++++++++++++++++++++---------- tests/benchmarks/test_rewriting.py | 51 ++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 20 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 910b8d842d..8cf67c5190 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -865,29 +865,53 @@ def connection_pattern(self, node): return ret def infer_shape(self, node, shapes): - # TODO: Use `fgraph.shape_feature` to do this instead. - out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes) - - # Clone the output shape so that shape are computed from outer inputs. - # Note: - # Here we could do it more simply like: - # `ret = [pytensor.clone_replace(shp, replace=repl) for shp in out_shp]` - # But doing it multiple time could duplicate common subgraph between - # each shape call. PyTensor optimizer will clean this up later, but this - # will make extra work for the optimizer. - - repl = dict(zip(self.inner_inputs, node.inputs, strict=True)) - clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] - cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) + try: + template = self._inner_shape_template + frozen = self._inner_shape_frozen + shape_i_keys = self._inner_shape_i_keys + except AttributeError: + from pytensor.tensor.rewriting.shape import ShapeFeature + + sf = ShapeFeature() + inner_inputs = self.inner_inputs + template = [sf.shape_tuple(o) for o in self.inner_outputs] + flat_shapes = [s for tup in template if tup is not None for s in tup] + + # Symbolic dims are Shape_i(inner_input, dim) vars. Expose them as graph + # inputs (which act as blockers) so bind can swap them for the caller's + # shapes; inner_inputs are exposed too so Ops returning inputs as shape + # components get the caller's inputs. + shape_i_vars = [] + shape_i_keys = [] + for i, inp in enumerate(inner_inputs): + per_dim = sf._shape_i_cache.get(inp) + if per_dim is None: + continue + for dim, var in per_dim.items(): + shape_i_vars.append(var) + shape_i_keys.append((i, dim)) + + frozen = FrozenFunctionGraph([*inner_inputs, *shape_i_vars], flat_shapes) + self._inner_shape_template = template + self._inner_shape_frozen = frozen + self._inner_shape_i_keys = shape_i_keys + + replacements = list(node.inputs) + for i, dim in shape_i_keys: + shp = shapes[i] + replacements.append(node.inputs[i].shape[dim] if shp is None else shp[dim]) + + bound_shapes = frozen.bind(dict(zip(frozen.inputs, replacements, strict=True))) + ret = [] - used = 0 - for i, out_shape in enumerate(out_shapes): - if out_shape is None: + idx = 0 + for tup in template: + if tup is None: ret.append(None) else: - nb = len(out_shape) - ret.append(cloned[used : used + nb]) - used += nb + nb = len(tup) + ret.append(bound_shapes[idx : idx + nb]) + idx += nb return ret diff --git a/tests/benchmarks/test_rewriting.py b/tests/benchmarks/test_rewriting.py index 5b0b86ba6f..7303d230de 100644 --- a/tests/benchmarks/test_rewriting.py +++ b/tests/benchmarks/test_rewriting.py @@ -1,9 +1,13 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt +import pytensor.xtensor as px from pytensor import config from pytensor.graph import FunctionGraph +from pytensor.graph.rewriting import rewrite_graph +from pytensor.xtensor.shape import stack as xstack def _large_fuseable_graph(n): @@ -66,3 +70,50 @@ def rewrite_func(): assert rewrite_func() == expected_n_repl benchmark.pedantic(rewrite_func, rounds=7, iterations=5) + + +def _xtensor_attention_graph(n_layers): + B, T, E, H, HD = 4, 32, 64, 4, 16 + rng = np.random.default_rng(0) + + def attn(x): + Wqkv = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, 3, H, HD))), + dims=("embd", "qkv", "head", "hd"), + ) + Wproj = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, E))), + dims=("embd", "embd_out"), + ) + qkv = px.dot(x, Wqkv, dim="embd") + q = qkv.isel(qkv=0).rename(time="time_q") + k = qkv.isel(qkv=1).rename(time="time_k") + v = qkv.isel(qkv=2).rename(time="time_k") + s = px.dot(q, k, dim="hd") / np.sqrt(HD) + mask = px.as_xtensor( + pt.tril(pt.ones((T, T), dtype="bool")), + dims=("time_q", "time_k"), + ) + a = px.math.softmax(px.where(mask, s, np.float64(-1e9)), dim="time_k") + o = xstack(px.dot(a, v, dim="time_k"), embd=("head", "hd")) + return px.dot(o, Wproj, dim="embd").rename(time_q="time", embd_out="embd") + + x_t = pt.tensor("x", shape=(B, T, E)) + x = px.as_xtensor(x_t, dims=("batch", "time", "embd")) + for _ in range(n_layers): + x = attn(x) + return x_t, x.values.sum() + + +@pytest.mark.parametrize("n_layers", [2, 3, 4]) +def test_xtensor_attention_rewrite_benchmark(n_layers, benchmark): + x_t, loss = _xtensor_attention_graph(n_layers) + + def rewrite_once(): + lowered = rewrite_graph(loss, include=("lower_xtensor",), clone=True) + grad = pt.grad(lowered, x_t) + return rewrite_graph( + [lowered, grad], include=("fast_run",), exclude=("inplace",), clone=True + ) + + benchmark(rewrite_once) From 4ba5572e8e32cd9fe55988cd8fd70c660b6be1d4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 29 May 2026 16:08:17 +0200 Subject: [PATCH 06/10] Build OFG shape graph via FrozenFunctionGraph.from_structural_inputs Add `FrozenFunctionGraph.from_structural_inputs`, the structural-matching dual of `bind`: inputs may be interior expressions, matched against the outputs by structure (via interning) and rewired to the input boundary. `OpFromGraph.infer_shape` now uses it to express the inner-output shapes as a frozen function of the inner inputs plus one slot per input dimension, replacing the manual blocker list and `shape_i_keys` bookkeeping. The dense, positional layout lets `bind` fill slots straight from the caller's shapes. --- pytensor/compile/builders.py | 40 +++++++++++----------- pytensor/graph/fg.py | 31 +++++++++++++++++ tests/graph/test_fg.py | 65 +++++++++++++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 22 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 8cf67c5190..123f64bf0d 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -26,6 +26,7 @@ from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError +from pytensor.tensor.shape import Shape_i def infer_shape(outs, inputs, input_shapes): @@ -868,7 +869,6 @@ def infer_shape(self, node, shapes): try: template = self._inner_shape_template frozen = self._inner_shape_frozen - shape_i_keys = self._inner_shape_i_keys except AttributeError: from pytensor.tensor.rewriting.shape import ShapeFeature @@ -877,31 +877,29 @@ def infer_shape(self, node, shapes): template = [sf.shape_tuple(o) for o in self.inner_outputs] flat_shapes = [s for tup in template if tup is not None for s in tup] - # Symbolic dims are Shape_i(inner_input, dim) vars. Expose them as graph - # inputs (which act as blockers) so bind can swap them for the caller's - # shapes; inner_inputs are exposed too so Ops returning inputs as shape - # components get the caller's inputs. - shape_i_vars = [] - shape_i_keys = [] - for i, inp in enumerate(inner_inputs): - per_dim = sf._shape_i_cache.get(inp) - if per_dim is None: - continue - for dim, var in per_dim.items(): - shape_i_vars.append(var) - shape_i_keys.append((i, dim)) - - frozen = FrozenFunctionGraph([*inner_inputs, *shape_i_vars], flat_shapes) + # Express the inner-output shapes as a frozen function of the inner + # inputs plus each input's per-dim size. from_structural_inputs rewires + # every Shape_i(inner_input, dim) occurrence to the matching input, so + # bind can later swap in the caller's shapes. One slot per input dim: + # static or unused dims become dead inputs, keeping the layout positional. + shape_inputs = [ + Shape_i(dim)(inp) + for inp in inner_inputs + for dim in range(getattr(inp.type, "ndim", 0)) + ] + frozen = FrozenFunctionGraph.from_structural_inputs( + [*inner_inputs, *shape_inputs], flat_shapes + ) self._inner_shape_template = template self._inner_shape_frozen = frozen - self._inner_shape_i_keys = shape_i_keys + # frozen.inputs is [*inner_inputs, *per-dim sizes]; mirror that layout. replacements = list(node.inputs) - for i, dim in shape_i_keys: - shp = shapes[i] - replacements.append(node.inputs[i].shape[dim] if shp is None else shp[dim]) + for shp in shapes: + if shp is not None: + replacements.extend(shp) - bound_shapes = frozen.bind(dict(zip(frozen.inputs, replacements, strict=True))) + bound_shapes = frozen.bind(replacements) ret = [] idx = 0 diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 4b4250283f..b43c3cdd8b 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1031,6 +1031,37 @@ def _resolve_input(inp, memo=memo): self._variables: frozenset[Variable] | None = None self._clients: dict[Variable, list[ClientType]] | None = None + @classmethod + def from_structural_inputs( + cls, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + ) -> "FrozenFunctionGraph": + """Freeze ``outputs``, allowing ``inputs`` to be *interior* expressions. + + Structural-matching dual of `bind`: where `bind` maps inputs to values, + this lifts chosen sub-expressions up to inputs. An ``input`` produced by + an `Apply` is matched against ``outputs`` by structure (not identity) and + every occurrence is rewired to it; root inputs behave as in the + constructor. Intermediate inputs absent from ``outputs`` become dead + inputs. The signature preserves the order of ``inputs``. ``outputs`` must + be computable from ``inputs`` alone — any root they still depend on + directly must itself appear in ``inputs``, else the rewired graph is + orphaned. + """ + # Discover the true graph roots (as FunctionGraph(inputs=None) does) to + # seed the staged freeze; the caller's `inputs` may be intermediate. + # Freezing inputs and outputs together interns each intermediate input + # onto the same object as its occurrences in the outputs, so the + # re-freeze can rewire them — which requires intermediate inputs to be + # *built*, not blocked, hence only roots seed the freeze. + roots = [ + v for v in graph_inputs([*inputs, *outputs]) if not isinstance(v, Constant) + ] + interned = cls(roots, [*inputs, *outputs]) + n_inputs = len(inputs) + return cls(interned.outputs[:n_inputs], interned.outputs[n_inputs:]) + def __reduce__(self): return FrozenFunctionGraph, (self.inputs, self.outputs) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index cd91fd8ca4..4dad31763e 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -4,7 +4,7 @@ import pytest from pytensor.configdefaults import config -from pytensor.graph.basic import NominalVariable +from pytensor.graph.basic import NominalVariable, equal_computations from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph, Output from pytensor.graph.utils import MissingInputError from pytensor.printing import debugprint @@ -999,3 +999,66 @@ def test_bind_constant_output(self): bound = ffg.bind({ffg.inputs[0]: y}) assert len(bound) == 2 assert bound[1] is c + + def test_from_structural_inputs_only_root_inputs(self): + """All inputs are roots: behaves like the plain constructor.""" + x, y = float64("x"), float64("y") + out = add(x, y) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y], [out]) + assert len(ffg.inputs) == 2 + + a, b = float64("a"), float64("b") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b], strict=True))) + assert equal_computations([res], [add(a, b)]) + + def test_from_structural_inputs_only_intermediate_inputs(self): + """Inputs may be only intermediate expressions; roots are found automatically.""" + x, y = float64("x"), float64("y") + # out depends on x, y only through the product. + out = add(mul(x, y), mul(x, y)) + + # The passed expression is matched by structure, not identity. + prod = mul(x, y) + assert prod is not out.owner.inputs[0] + + ffg = FrozenFunctionGraph.from_structural_inputs([prod], [out]) + assert len(ffg.inputs) == 1 + + p = float64("p") + [res] = ffg.bind({ffg.inputs[0]: p}) + # Both occurrences rewire to the single input. + assert equal_computations([res], [add(p, p)]) + + def test_from_structural_inputs_mixed_inputs(self): + """A root input and an intermediate input, both live.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, mul(x, y)], [out]) + assert len(ffg.inputs) == 2 + + a, p = float64("a"), float64("p") + # x is used directly; root y is dropped (it feeds only the lifted product). + [res] = ffg.bind(dict(zip(ffg.inputs, [a, p], strict=True))) + assert equal_computations([res], [add(p, a)]) + + def test_from_structural_inputs_dead_inputs(self): + """A dead root input and a dead intermediate input are retained but ignored.""" + x, y = float64("x"), float64("y") + out = add(x, x) # uses neither y nor the product + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y, mul(x, y)], [out]) + assert len(ffg.inputs) == 3 + + a, b, p = float64("a"), float64("b"), float64("p") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b, p], strict=True))) + assert equal_computations([res], [add(a, a)]) + + def test_from_structural_inputs_unreachable_output_raises(self): + """Outputs needing a root absent from the inputs cannot be expressed.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) # needs x directly, not only via the product + + with pytest.raises(ValueError): + FrozenFunctionGraph.from_structural_inputs([mul(x, y)], [out]) From 55f46f6e61d5600beba12f13cc81e3c085885cab Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 11 Jun 2026 04:35:06 +0200 Subject: [PATCH 07/10] Move numba core shape test next to the numba rewrite tests --- pytensor/tensor/random/rewriting/numba.py | 2 +- pytensor/tensor/rewriting/numba.py | 22 +++----- tests/tensor/rewriting/test_blockwise.py | 50 +--------------- tests/tensor/rewriting/test_numba.py | 69 +++++++++++++++++++++++ 4 files changed, 80 insertions(+), 63 deletions(-) create mode 100644 tests/tensor/rewriting/test_numba.py diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 8d128ec698..894a08b893 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -70,7 +70,7 @@ def introduce_explicit_core_shape_rv(fgraph, node): else: core_shape = as_tensor(core_shape) - [core_shape] = simplify_core_shape_graphs([core_shape]) + [core_shape] = simplify_core_shape_graphs([core_shape], fgraph) new_outs = ( RandomVariableWithCoreShape( diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 0a6d59c4be..eb508a0275 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -7,22 +7,18 @@ from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.rewriting.shape import ShapeFeature -from pytensor.tensor.shape import Shape, Shape_i -def simplify_core_shape_graphs(core_shapes): - """Simplify core shape expressions by canonicalizing shape arithmetic. +def simplify_core_shape_graphs(core_shapes, fgraph): + """Canonicalize the fresh shape arithmetic built by infer_shape. - Temporarily detaches Shape/Shape_i outputs from their owners so that - rewrite_graph operates only on the arithmetic above them (e.g. - constant-folding static shapes, eliminating dead Switch branches). + The rewrite is in place, so ``fgraph`` variables are detached to avoid + mutating it behind its features' backs. """ - shape_boundary = [ - var - for var in ancestors(core_shapes) - if var.owner is not None and isinstance(var.owner.op, (Shape, Shape_i)) - ] - saved_owners = [(v, v.owner, v.index) for v in shape_boundary] + graph_boundary = fgraph.variables.intersection( + ancestors(core_shapes, blockers=fgraph.variables) + ) + saved_owners = [(v, v.owner, v.index) for v in graph_boundary] for v, _, _ in saved_owners: v.owner = None try: @@ -121,7 +117,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): # If Blockwise shows up in the shape graph we can't introduce the core shape return None - core_shapes = simplify_core_shape_graphs(core_shapes) + core_shapes = simplify_core_shape_graphs(core_shapes, fgraph) return BlockwiseWithCoreShape( [*node.inputs, *core_shapes], diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py index 4e40507158..c511e21f66 100644 --- a/tests/tensor/rewriting/test_blockwise.py +++ b/tests/tensor/rewriting/test_blockwise.py @@ -9,15 +9,11 @@ from pytensor.graph.traversal import apply_ancestors from pytensor.scalar import log as scalar_log from pytensor.tensor import add, alloc, iscalar, matrix, scalar, tensor, tensor3 -from pytensor.tensor.basic import MakeVector, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.linalg.inverse import MatrixPinv -from pytensor.tensor.math import add as tensor_add -from pytensor.tensor.math import maximum, minimum from pytensor.tensor.rewriting.blockwise import local_useless_blockwise -from pytensor.tensor.shape import Reshape, Shape_i -from pytensor.tensor.signal.conv import Convolve1d +from pytensor.tensor.shape import Reshape def test_useless_blockwise_of_elemwise(): @@ -190,47 +186,3 @@ def test_blockwise_reshape(): new_y.eval({"x": test_x}, mode=no_rewrites), rewritten_y.eval({"x": test_x}, mode=no_rewrites), ) - - -@pytest.mark.parametrize( - "mode, x_shape, k_shape", - [ - ("valid", (3, 10), (3, 4)), - ("full", (3, 10), (3, 4)), - ("valid", (None, None), (None, None)), - ("full", (None, None), (None, None)), - ], -) -def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): - x = tensor("x", shape=x_shape) - k = tensor("k", shape=k_shape) - out = Blockwise(Convolve1d())( - x, k, constant(mode == "full", dtype="bool"), return_list=True - ) - fn = function([x, k], out, mode="NUMBA") - - [bwcs_node] = [ - n - for n in fn.maker.fgraph.apply_nodes - if isinstance(n.op, BlockwiseWithCoreShape) - ] - core_shape = bwcs_node.inputs[-1] - - static = all(s is not None for s in x_shape + k_shape) - if static: - n, kk = x_shape[1], k_shape[1] - expected_len = (n + kk - 1) if mode == "full" else (n - kk + 1) - expected = constant(np.array([expected_len])) - assert equal_computations([core_shape], [expected]) - else: - n = Shape_i(1)(x) - kk = Shape_i(1)(k) - if mode == "full": - expected = MakeVector("int64")( - tensor_add(constant(-1, dtype="int64"), n, kk) - ) - else: - expected = MakeVector("int64")( - constant(1, dtype="int64") + maximum(n, kk) - minimum(n, kk) - ) - assert equal_computations([core_shape], [expected], in_xs=[x, k], in_ys=[x, k]) diff --git a/tests/tensor/rewriting/test_numba.py b/tests/tensor/rewriting/test_numba.py new file mode 100644 index 0000000000..8267f1f728 --- /dev/null +++ b/tests/tensor/rewriting/test_numba.py @@ -0,0 +1,69 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import function +from pytensor.compile import optdb +from pytensor.compile.aliasing import add_supervisor_to_fgraph +from pytensor.compile.io import In +from pytensor.compile.mode import get_mode +from pytensor.compile.ops import DeepCopyOp +from pytensor.graph.basic import equal_computations +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.basic import MakeVector, alloc, constant +from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape +from pytensor.tensor.math import add as tensor_add +from pytensor.tensor.math import maximum, minimum +from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.rewriting.shape import ShapeFeature +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.signal import convolve1d +from pytensor.tensor.signal.conv import Convolve1d + + +def count_ops(fgraph, op_type): + return sum(isinstance(node.op, op_type) for node in fgraph.apply_nodes) + + +@pytest.mark.parametrize( + "mode, x_shape, k_shape", + [ + ("valid", (3, 10), (3, 4)), + ("full", (3, 10), (3, 4)), + ("valid", (None, None), (None, None)), + ("full", (None, None), (None, None)), + ], +) +def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): + x = pt.tensor("x", shape=x_shape) + k = pt.tensor("k", shape=k_shape) + out = Blockwise(Convolve1d())( + x, k, constant(mode == "full", dtype="bool"), return_list=True + ) + fn = function([x, k], out, mode="NUMBA") + + [bwcs_node] = [ + n + for n in fn.maker.fgraph.apply_nodes + if isinstance(n.op, BlockwiseWithCoreShape) + ] + core_shape = bwcs_node.inputs[-1] + + static = all(s is not None for s in x_shape + k_shape) + if static: + n, kk = x_shape[1], k_shape[1] + expected_len = (n + kk - 1) if mode == "full" else (n - kk + 1) + expected = constant(np.array([expected_len])) + assert equal_computations([core_shape], [expected]) + else: + n = Shape_i(1)(x) + kk = Shape_i(1)(k) + if mode == "full": + expected = MakeVector("int64")( + tensor_add(constant(-1, dtype="int64"), n, kk) + ) + else: + expected = MakeVector("int64")( + constant(1, dtype="int64") + maximum(n, kk) - minimum(n, kk) + ) + assert equal_computations([core_shape], [expected], in_xs=[x, k], in_ys=[x, k]) From a99b1e0c451ceffcbd54273c586882b0d5dc2892 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 11 Jun 2026 16:59:37 +0200 Subject: [PATCH 08/10] Constant-fold Subtensor of Shape for any static dim local_subtensor_shape_constant only folded broadcastable dims, deferring the general case to the ShapeFeature, which is not present everywhere canonicalize runs (e.g. rewrite_subgraph). --- pytensor/tensor/rewriting/subtensor_lift.py | 16 +++++----------- tests/tensor/rewriting/test_subtensor.py | 3 ++- tests/tensor/rewriting/test_subtensor_lift.py | 11 +++++++++++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index e170c2c61b..59dc8eb7dc 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1104,12 +1104,6 @@ def local_subtensor_shape_constant(fgraph, node): TensorConstant{1} - TODO: Something like `local_shape_to_shape_i` should be a general - canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were - the case, we could change this to only operate on `Shape_i`\s. - Currently, we're not handling them because they should only appear when - `ShapeFeature` is present, and it will also simplify/remove them. - """ shape = node.inputs[0] @@ -1130,7 +1124,7 @@ def local_subtensor_shape_constant(fgraph, node): return False try: - shape_parts = shape_arg.type.broadcastable[idx_val] + shape_parts = shape_arg.type.shape[idx_val] except IndexError: # An out-of-bounds index here is an error in the source graph # (e.g. ``scalar.shape[0]``), but it should fail at runtime rather @@ -1138,10 +1132,10 @@ def local_subtensor_shape_constant(fgraph, node): return False if isinstance(shape_parts, Iterable): - if all(shape_parts): - return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] - elif shape_parts: - return [as_tensor(1, dtype=np.int64)] + if all(s is not None for s in shape_parts): + return [as_tensor(list(shape_parts), dtype=np.int64, ndim=1)] + elif shape_parts is not None: + return [as_tensor(shape_parts, dtype=np.int64)] @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index eb0a34f0e2..4f7503d1d7 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -204,11 +204,12 @@ def test_local_useless_inc_subtensor_no_opt(): result.assert_eval([[1, 2], [3, 4]], [[10, 20], [30, 40]]) # Increment with a non-zero constant target array, same collapse to x + y. + # ``ones`` has a static shape, so ``ones.shape`` folds to the constants (2, 2). ones = pt.ones((2, 2)) s = ones[:, :] o_shape = inc_subtensor(s, specify_shape(y, s.shape)) result = utt.RewriteTester([y], [o_shape]) - result.assert_graph(ones + specify_shape(y, ones.shape)) + result.assert_graph(ones + specify_shape(y, (2, 2))) result.assert_eval([[10, 20], [30, 40]]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 56dd04a2b9..b245f53237 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -1060,6 +1060,17 @@ def test_local_subtensor_shape_constant(): assert isinstance(res, Constant) assert np.array_equal(res.data, [1, 1]) + # Any static dim folds, not just broadcastable ones + x = tensor(dtype=np.float64, shape=(7, None)).shape[0] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert res.data == 7 + + x = _shape(tensor(dtype=np.float64, shape=(None, 3, 7)))[1:] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert np.array_equal(res.data, [3, 7]) + @pytest.mark.parametrize( "original_fn, supported", From 489f4d4b1e64a65bb123492b0dc93a27c930011b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 1 May 2026 17:58:16 +0200 Subject: [PATCH 09/10] Introduce numba core shapes from the wrapped node's own inputs Core shapes built from recursive shape inference can read variables that inplace rewrites destroy, making the wrapper node unschedulable. Expressions that read only the node's own inputs, through constants and fresh Shape_i applies, never conflict with the surrounding graph. They are canonicalized in isolation with the new rewrite_subgraph utility. --- pytensor/graph/rewriting/utils.py | 47 +++++++++++- pytensor/tensor/random/rewriting/numba.py | 28 +++++--- pytensor/tensor/rewriting/numba.py | 48 +++++-------- pytensor/tensor/rewriting/shape.py | 88 ++++++++++++++++++----- tests/tensor/rewriting/test_numba.py | 65 +++++++++++++++-- 5 files changed, 210 insertions(+), 66 deletions(-) diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index 26a2970408..ea8d8cfd0d 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -1,6 +1,6 @@ import copy -from collections.abc import Generator, Sequence -from typing import TYPE_CHECKING, Optional +from collections.abc import Generator, Iterable, Sequence +from typing import TYPE_CHECKING, Optional, cast import pytensor from pytensor.graph.basic import ( @@ -62,6 +62,49 @@ def rewrite_graph( return fgraph.outputs +def rewrite_subgraph( + outputs: Sequence[Variable], + frontier: Iterable[Variable], + include: Sequence[str] = ("canonicalize",), + **kwargs, +) -> list[Variable]: + """Rewrite the subgraph between ``frontier`` and ``outputs`` in isolation. + + The ``frontier`` variables are temporarily detached from their owners, so + they act as inputs of the subgraph: rewrites can neither reach past them + nor modify the graph they belong to. This allows simplifying fresh + expressions that hang off the variables of an existing `FunctionGraph` + without mutating it behind its (and its features') back. + + The rewrite is in place: ``outputs`` must not belong to a `FunctionGraph`. + + Parameters + ---------- + outputs + The outputs of the subgraph to rewrite. + frontier + Variables at which the subgraph stops; every path from ``outputs`` + into the surrounding graph must go through one of them. + include + Rewrite query names, as in `rewrite_graph`. + **kwargs + Keyword arguments passed to `rewrite_graph`. + """ + saved_owners = [(v, v.owner, v.index) for v in frontier] + for v, _, _ in saved_owners: + v.owner = None + try: + rewritten = cast( + Sequence[Variable], + rewrite_graph(list(outputs), include=include, clone=False, **kwargs), + ) + return list(rewritten) + finally: + for v, owner, idx in saved_owners: + v.owner = owner + v.index = idx + + def is_same_graph_with_merge( var1: Variable, var2: Variable, diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 894a08b893..0458929ee1 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -1,6 +1,7 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter +from pytensor.graph.traversal import applys_between from pytensor.tensor import as_tensor, constant from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.rewriting.numba import simplify_core_shape_graphs @@ -15,9 +16,10 @@ def introduce_explicit_core_shape_rv(fgraph, node): that has an extra "non-functional" input that represents the core shape of the random variable. This core_shape is used by the numba backend to pre-allocate the output array. - If available, the core shape is extracted from the shape feature of the graph, - which has a higher chance of having been simplified, optimized, constant-folded. - If missing, we fall back to the op._supp_shape_from_params method. + The core shape is built from ``ShapeFeature.get_non_recursive_shape``, whose + expressions read only the node's own inputs and can therefore be introduced + after inplacing without conflicting with the destroyers in the surrounding + graph. This rewrite is required for the numba backend implementation of RandomVariable. @@ -58,18 +60,26 @@ def introduce_explicit_core_shape_rv(fgraph, node): _next_rng, rv = node.outputs shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) - if shape_feature: - core_shape = [ - shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) - ] - else: - core_shape = op._supp_shape_from_params(op.dist_params(node)) + if shape_feature is None: + shape_feature = ShapeFeature() + + core_shape = [ + shape_feature.get_non_recursive_shape(rv, i) + for i in range(rv.type.ndim - op.ndim_supp, rv.type.ndim) + ] if len(core_shape) == 0: core_shape = constant([], dtype="int64") else: core_shape = as_tensor(core_shape) + if any( + isinstance(node.op, RandomVariable) + for node in applys_between(node.inputs, [core_shape]) + ): + # If the RandomVariable shows up in the shape graph we can't introduce the core shape + return None + [core_shape] = simplify_core_shape_graphs([core_shape], fgraph) new_outs = ( diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index eb508a0275..af0529869e 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -2,7 +2,7 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter -from pytensor.graph.rewriting.utils import rewrite_graph +from pytensor.graph.rewriting.utils import rewrite_subgraph from pytensor.graph.traversal import ancestors, applys_between from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape @@ -10,26 +10,11 @@ def simplify_core_shape_graphs(core_shapes, fgraph): - """Canonicalize the fresh shape arithmetic built by infer_shape. - - The rewrite is in place, so ``fgraph`` variables are detached to avoid - mutating it behind its features' backs. - """ - graph_boundary = fgraph.variables.intersection( + """Canonicalize the fresh shape arithmetic built by ``get_non_recursive_shape``.""" + graph_frontier = fgraph.variables.intersection( ancestors(core_shapes, blockers=fgraph.variables) ) - saved_owners = [(v, v.owner, v.index) for v in graph_boundary] - for v, _, _ in saved_owners: - v.owner = None - try: - core_shapes = list( - rewrite_graph(core_shapes, include=("canonicalize",), clone=False) - ) - finally: - for v, owner, idx in saved_owners: - v.owner = owner - v.index = idx - return core_shapes + return rewrite_subgraph(core_shapes, graph_frontier, include=("canonicalize",)) @node_rewriter([Blockwise]) @@ -40,9 +25,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): that has an extra "non-functional" input that represents the core shape of the Blockwise variable. This core_shape is used by the numba backend to pre-allocate the output array. - If available, the core shape is extracted from the shape feature of the graph, - which has a higher change of having been simplified, optimized, constant-folded. - If missing, we fall back to the op._supp_shape_from_params method. + The core shape is built from ``ShapeFeature.get_non_recursive_shape``, whose + expressions read only the node's own inputs and can therefore be introduced + after inplacing without conflicting with the destroyers in the surrounding + graph. This rewrite is required for the numba backend implementation of Blockwise. @@ -94,16 +80,16 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): batch_ndim = op.batch_ndim(node) shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) - if shape_feature: - core_shapes = [ - [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] - for out in node.outputs - ] - else: - input_shapes = [tuple(inp.shape) for inp in node.inputs] - core_shapes = [ - out_shape[batch_ndim:] for out_shape in op.infer_shape(node, input_shapes) + if shape_feature is None: + shape_feature = ShapeFeature() + + core_shapes = [ + [ + shape_feature.get_non_recursive_shape(out, i) + for i in range(batch_ndim, out.type.ndim) ] + for out in node.outputs + ] core_shapes = [ as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64") diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 6369c2ad4d..982b58aea1 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -75,6 +75,8 @@ class ShapeFeature(Feature): - ``get_shape(var, i)`` / ``shape_tuple(var)`` for the shape expressed in terms of the graph inputs (recursing through the ancestors of ``var``); + - ``get_non_recursive_shape(var, i)`` for the shape expressed in terms of + ``var.owner``'s direct inputs only; - ``same_shape(x, y)`` to statically compare two shapes. Inferred shapes are cached per node, but the cache is invalidated whenever @@ -84,14 +86,15 @@ class ShapeFeature(Feature): def __init__(self): self.fgraph: FunctionGraph | None = None - # node -> tuple of (tuple of shape vars) per output, lazily populated. - # The cached shape of a node embeds the shapes of its whole ancestor cone, - # so a change anywhere above a cached node invalidates it. + # node -> tuple of (tuple of shape vars) per output, recursive shapes only, + # lazily populated. Non-recursive shapes read only the node's own inputs, so + # they are cheap and recomputed on demand rather than cached. self._cache: dict = {} - # Nodes whose input changed since the last query. Invalidation is deferred - # to the next query, which walks ``fgraph.clients`` downstream from these - # nodes; the live graph already encodes the dependencies, so no reverse map - # is kept. + # Nodes whose input changed since the last query. The recursive shape of a node + # embeds the shapes of its whole ancestor cone, so a change anywhere above a + # cached node invalidates it. We defer that to the next query and walk + # ``fgraph.clients`` downstream from these nodes then — the live graph already + # encodes the dependencies, so no reverse-dependency map is kept. self._stale: set = set() # var -> {i: Shape_i(i)(v)}, ensures Apply identity for leaves self._shape_i_cache: dict = {} @@ -115,6 +118,18 @@ def _shape_i_var(self, v, i): per_dim[i] = res return res + @staticmethod + def _fresh_shape_i(v, i): + """Like ``_shape_i_var``, but never reusing a cached variable. + + The cached variable may live in the graph with its own constraints + (e.g. its buffer destroyed by an inplace Op); a fresh read carries no + constraints beyond reading ``v``, which non-recursive consumers rely on. + """ + if hasattr(v.type, "shape") and v.type.shape[i] is not None: + return constant(v.type.shape[i], dtype="int64") + return Shape_i(i)(v) + def _coerce_shape_element(self, element, node): """Validate and normalize a single shape element from infer_shape.""" if isinstance(element, np.ndarray): @@ -151,11 +166,20 @@ def _coerce_shape_element(self, element, node): f"shape element of type {type(element).__name__}: {element!r}" ) - def _get_node_shapes(self, node): - """Return validated per-output shape tuples for ``node``.""" + def _get_node_shapes(self, node, recursive=True): + """Return validated per-output shape tuples for ``node``. + + With ``recursive=False``, input shapes are fresh constant/``Shape_i`` + reads of the node's own inputs instead of inferred expressions, so the + result references no variables beyond those inputs (and nothing is + cached). + """ if self._stale: self._flush_stale() + if not recursive: + return self._compute_node_shapes(node, recursive=False) + cache = self._cache cached = cache.get(node) if cached is not None: @@ -187,21 +211,26 @@ def _get_node_shapes(self, node): ready = False if ready: stack.pop() - cache[top] = self._compute_node_shapes(top) + cache[top] = self._compute_node_shapes(top, recursive=True) return cache[node] - def _compute_node_shapes(self, node): + def _compute_node_shapes(self, node, recursive): """Call ``infer_shape`` for ``node`` and validate the result. - The caller must have cached the input-producing nodes already, so - ``get_shape`` resolves them from the cache without recursing. + Input shapes are taken from ``get_shape`` (recursive) or from fresh + ``Shape_i`` reads of the node's own inputs (non-recursive). In the + recursive case the caller must have cached the input-producing nodes + already, so ``get_shape`` resolves them from the cache without recursing. """ + input_shape_of = self.get_shape if recursive else self._fresh_shape_i + shape_i_var = self._shape_i_var if recursive else self._fresh_shape_i + input_shapes = [] for inp in node.inputs: if hasattr(inp.type, "ndim"): input_shapes.append( - tuple(self.get_shape(inp, j) for j in range(inp.type.ndim)) + tuple(input_shape_of(inp, j) for j in range(inp.type.ndim)) ) else: input_shapes.append(None) @@ -232,9 +261,7 @@ def _compute_node_shapes(self, node): if output_shapes is not None and k < len(output_shapes): sh = output_shapes[k] if sh is None or not isinstance(sh, list | tuple): - result.append( - tuple(self._shape_i_var(out, j) for j in range(out.type.ndim)) - ) + result.append(tuple(shape_i_var(out, j) for j in range(out.type.ndim))) continue coerced = [] for j, s in enumerate(sh): @@ -259,6 +286,33 @@ def get_shape(self, var, idx): return sh[idx] return self._shape_i_var(var, idx) + def get_non_recursive_shape(self, var, idx): + """Return an expression for ``var.shape[idx]`` reading only ``var.owner``'s inputs. + + Unlike ``get_shape``, input shapes are not recursively expanded: the + expression references the direct inputs of ``var.owner`` (through + constants and fresh ``Shape_i`` reads) and nothing else. It can + therefore be introduced next to any consumer of those inputs no matter + what the destroy maps in the surrounding graph are, whereas the + recursion of ``get_shape`` may surface variables that an inplace Op + destroys. + + Works on an unattached feature. + """ + if hasattr(var.type, "shape") and var.type.shape[idx] is not None: + return constant(var.type.shape[idx], dtype="int64") + + node = var.owner + if node is None: + return self._fresh_shape_i(var, idx) + + node_shapes = self._get_node_shapes(node, recursive=False) + out_idx = node.outputs.index(var) + sh = node_shapes[out_idx] + if sh is not None: + return sh[idx] + return self._fresh_shape_i(var, idx) + def shape_tuple(self, var): if not hasattr(var.type, "ndim"): return None diff --git a/tests/tensor/rewriting/test_numba.py b/tests/tensor/rewriting/test_numba.py index 8267f1f728..b210470ce0 100644 --- a/tests/tensor/rewriting/test_numba.py +++ b/tests/tensor/rewriting/test_numba.py @@ -7,7 +7,6 @@ from pytensor.compile.aliasing import add_supervisor_to_fgraph from pytensor.compile.io import In from pytensor.compile.mode import get_mode -from pytensor.compile.ops import DeepCopyOp from pytensor.graph.basic import equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.tensor.basic import MakeVector, alloc, constant @@ -25,6 +24,21 @@ def count_ops(fgraph, op_type): return sum(isinstance(node.op, op_type) for node in fgraph.apply_nodes) +def rewrite_for_numba(inputs, outputs): + fg = FunctionGraph(inputs, outputs) + add_supervisor_to_fgraph(fg, [In(inp) for inp in fg.inputs]) + get_mode("NUMBA").optimizer.rewrite(fg) + return fg + + +def core_shape_of(fgraph): + [bwcs_node] = [ + n for n in fgraph.apply_nodes if isinstance(n.op, BlockwiseWithCoreShape) + ] + *functional_inputs, core_shape = bwcs_node.inputs + return core_shape, functional_inputs + + @pytest.mark.parametrize( "mode, x_shape, k_shape", [ @@ -42,12 +56,7 @@ def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): ) fn = function([x, k], out, mode="NUMBA") - [bwcs_node] = [ - n - for n in fn.maker.fgraph.apply_nodes - if isinstance(n.op, BlockwiseWithCoreShape) - ] - core_shape = bwcs_node.inputs[-1] + core_shape, _ = core_shape_of(fn.maker.fgraph) static = all(s is not None for s in x_shape + k_shape) if static: @@ -67,3 +76,45 @@ def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): constant(1, dtype="int64") + maximum(n, kk) - minimum(n, kk) ) assert equal_computations([core_shape], [expected], in_xs=[x, k], in_ys=[x, k]) + + +def test_introduce_core_shape_aliasing(): + """Graphs whose shape arithmetic gets inplaced, destroying variables that + recursive core shape derivations used to read; they must simply lower. + """ + larger = pt.matrix("larger", shape=(8, None)) + smaller = pt.matrix("smaller", shape=(8, None)) + a = alloc(pt.zeros((1, 1)), 1, larger.shape[1] + smaller.shape[1] - 1) + out = convolve1d(a, larger[:, ::-1], mode="full") + + fg = rewrite_for_numba([larger, smaller], [out]) + assert count_ops(fg, BlockwiseWithCoreShape) == 1 + fg.toposort() + + # Crossed variant, where each core shape mixes dimensions of both inputs + x1 = pt.matrix("x1", shape=(8, None)) + x2 = pt.matrix("x2", shape=(8, None)) + a1 = alloc(pt.zeros((1, 1)), 1, x1.shape[1] + 3) + a2 = alloc(pt.zeros((1, 1)), 1, x2.shape[1] + 5) + convA = convolve1d(a1, x2[:, ::-1], mode="full") + convB = convolve1d(a2, x1[:, ::-1], mode="full") + + fg = rewrite_for_numba([x1, x2], [convA, convB]) + assert count_ops(fg, BlockwiseWithCoreShape) == 2 + fg.toposort() + + +def test_core_shape_simplify_keeps_fgraph_intact(): + """simplify_core_shape_graphs must not rewrite the fgraph's own applies, + like the non-canonical chain feeding the alloc dim here. + """ + x = pt.matrix("x", shape=(8, None)) + a = alloc(pt.zeros((1, 1)), 1, (x.shape[1] + 1) - 1) + out = convolve1d(a, x[:, ::-1], mode="full") + + fg = FunctionGraph([x], [out], clone=False) + fg.attach_feature(ShapeFeature()) + InplaceElemwiseOptimizer().rewrite(fg) + assert any(node.op.destroy_map for node in fg.apply_nodes) + optdb.query("+introduce_explicit_core_shape_blockwise").rewrite(fg) + fg.toposort() From cb4b367a28f1099cfb89b1a8fd43247536c40e7c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 13 Jun 2026 17:21:46 +0200 Subject: [PATCH 10/10] Check HasShape mixin instead of hasattr in ShapeFeature The shape-bearing test was duck-typed as `hasattr(var.type, "ndim")` (carried over from the old ShapeFeature API). Use the canonical `isinstance(var.type, HasShape)` instead: every shape-bearing type (Tensor, Scalar, Sparse, XTensor) subclasses HasShape, so it is equivalent, and on the common tensor/scalar inputs it short-circuits the MRO walk rather than doing a full getattr. --- pytensor/tensor/rewriting/shape.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 982b58aea1..83dc920a15 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -13,6 +13,7 @@ copy_stack_trace, node_rewriter, ) +from pytensor.graph.type import HasShape from pytensor.tensor.basic import ( Alloc, MakeVector, @@ -64,7 +65,7 @@ def __getitem__(self, var): return result def __contains__(self, var): - return hasattr(var.type, "ndim") + return isinstance(var.type, HasShape) class ShapeFeature(Feature): @@ -111,7 +112,7 @@ def _shape_i_var(self, v, i): else: per_dim = {} self._shape_i_cache[v] = per_dim - if hasattr(v.type, "shape") and v.type.shape[i] is not None: + if isinstance(v.type, HasShape) and v.type.shape[i] is not None: res = constant(v.type.shape[i], dtype="int64") else: res = Shape_i(i)(v) @@ -126,7 +127,7 @@ def _fresh_shape_i(v, i): (e.g. its buffer destroyed by an inplace Op); a fresh read carries no constraints beyond reading ``v``, which non-recursive consumers rely on. """ - if hasattr(v.type, "shape") and v.type.shape[i] is not None: + if isinstance(v.type, HasShape) and v.type.shape[i] is not None: return constant(v.type.shape[i], dtype="int64") return Shape_i(i)(v) @@ -202,7 +203,7 @@ def _get_node_shapes(self, node, recursive=True): if ( inp_node is not None and inp_node not in cache - and hasattr(inp.type, "ndim") + and isinstance(inp.type, HasShape) # A fully-static input shape is read without touching its # owner (see ``get_shape``), so don't bother caching it. and any(s is None for s in inp.type.shape) @@ -228,7 +229,7 @@ def _compute_node_shapes(self, node, recursive): input_shapes = [] for inp in node.inputs: - if hasattr(inp.type, "ndim"): + if isinstance(inp.type, HasShape): input_shapes.append( tuple(input_shape_of(inp, j) for j in range(inp.type.ndim)) ) @@ -254,7 +255,7 @@ def _compute_node_shapes(self, node, recursive): result = [] for k, out in enumerate(node.outputs): - if not hasattr(out.type, "ndim"): + if not isinstance(out.type, HasShape): result.append(None) continue sh = None @@ -272,7 +273,7 @@ def _compute_node_shapes(self, node, recursive): def get_shape(self, var, idx): """Return a symbolic expression for ``var.shape[idx]``.""" - if hasattr(var.type, "shape") and var.type.shape[idx] is not None: + if isinstance(var.type, HasShape) and var.type.shape[idx] is not None: return constant(var.type.shape[idx], dtype="int64") node = var.owner @@ -299,7 +300,7 @@ def get_non_recursive_shape(self, var, idx): Works on an unattached feature. """ - if hasattr(var.type, "shape") and var.type.shape[idx] is not None: + if isinstance(var.type, HasShape) and var.type.shape[idx] is not None: return constant(var.type.shape[idx], dtype="int64") node = var.owner @@ -314,7 +315,7 @@ def get_non_recursive_shape(self, var, idx): return self._fresh_shape_i(var, idx) def shape_tuple(self, var): - if not hasattr(var.type, "ndim"): + if not isinstance(var.type, HasShape): return None return tuple(self.get_shape(var, i) for i in range(var.type.ndim)) @@ -378,7 +379,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): self._stale.add(node) # Schedule Shape_i(r) replacements for local_track_shape_i - if hasattr(r.type, "ndim"): + if isinstance(r.type, HasShape): for shpnode, _idx in fgraph.clients.get(r, []): if isinstance(getattr(shpnode, "op", None), Shape_i): self.scheduled[shpnode] = new_r