diff --git a/pytensor/assumptions/specify.py b/pytensor/assumptions/specify.py index ebf2282c89..201e7d6282 100644 --- a/pytensor/assumptions/specify.py +++ b/pytensor/assumptions/specify.py @@ -39,7 +39,7 @@ def make_node(self, x): out = x.type() return Apply(self, [x], [out]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback( diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index f9c74950dc..c2913cfea1 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage): def pullback(self, inputs, outputs, output_gradients): return [disconnected_type(), *output_gradients] - def infer_shape(self, fgraph, inputs, input_shapes): + def infer_shape(self, inputs, input_shapes): # Return the shape of every input but the condition (first input) return input_shapes[1:] diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 4060c38365..123f64bf0d 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from itertools import chain from typing import cast from pytensor.compile.maker import function @@ -24,26 +23,14 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError +from pytensor.tensor.shape import Shape_i def infer_shape(outs, inputs, input_shapes): - """ - Compute the shape of the outputs given the shape of the inputs of an PyTensor - graph. - - We do it this way to avoid compiling the inner function just to get - the shape. Changes to ShapeFeature could require changes in this function. - - """ - # We use a ShapeFeature because it has all the necessary logic - # inside. We don't use the full ShapeFeature interface, but we - # let it initialize itself with an empty fgraph, otherwise we will - # need to do it manually - - # TODO: ShapeFeature should live elsewhere + """Compute the shape of the outputs given the shape of the inputs of a PyTensor graph.""" from pytensor.tensor.rewriting.shape import ShapeFeature for inp, inp_shp in zip(inputs, input_shapes, strict=True): @@ -51,43 +38,36 @@ def infer_shape(outs, inputs, input_shapes): assert len(inp_shp) == inp.type.ndim shape_feature = ShapeFeature() - fgraph = FunctionGraph([], [], features=[shape_feature]) - for v in chain.from_iterable(s for s in input_shapes if s is not None): - # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before - if (node := v.owner) is not None: - fgraph.import_node(node, import_missing=True) - - # Initialize shape_of with the input shapes - for inp, inp_shp in zip(inputs, input_shapes, strict=True): - shape_feature.set_shape(inp, inp_shp, override=True) - - def local_traverse(out): - """ - Go back in the graph, from out, adding computable shapes to shape_of. - - """ - if out in shape_feature.shape_of: - # Its shape is already known - return - elif out.owner is None: - # This is an input of the graph - shape_feature.init_r(out) - else: - # Recurse over inputs - for inp in out.owner.inputs: - if inp not in shape_feature.shape_of: - local_traverse(inp) + output_shapes = [shape_feature.shape_tuple(o) for o in outs] - # shape_feature.on_import does not actually use an fgraph - # It will call infer_shape and set_shape appropriately - dummy_fgraph = None - shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy") + # Shape expressions for root inputs are Shape_i(inp, i). + # Replace those with the caller-provided input_shapes. + replacements = {} + for inp, shp in zip(inputs, input_shapes, strict=True): + if shp is None: + continue + per_dim = shape_feature._shape_i_cache.get(inp) + if per_dim is None: + continue + for i, s in enumerate(shp): + cached = per_dim.get(i) + if cached is not None: + replacements[cached] = s + + if replacements: + flat = [s for tup in output_shapes if tup is not None for s in tup] + flat_replaced = graph_replace(flat, replacements, strict=False) + result = [] + idx = 0 + for tup in output_shapes: + if tup is None: + result.append(None) + else: + result.append(tuple(flat_replaced[idx : idx + len(tup)])) + idx += len(tup) + return result - ret = [] - for o in outs: - local_traverse(o) - ret.append(shape_feature.shape_of[o]) - return ret + return output_shapes def construct_nominal_fgraph( @@ -885,30 +865,51 @@ def connection_pattern(self, node): self._connection_pattern = ret return ret - def infer_shape(self, fgraph, node, shapes): - # TODO: Use `fgraph.shape_feature` to do this instead. - out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes) - - # Clone the output shape so that shape are computed from outer inputs. - # Note: - # Here we could do it more simply like: - # `ret = [pytensor.clone_replace(shp, replace=repl) for shp in out_shp]` - # But doing it multiple time could duplicate common subgraph between - # each shape call. PyTensor optimizer will clean this up later, but this - # will make extra work for the optimizer. - - repl = dict(zip(self.inner_inputs, node.inputs, strict=True)) - clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] - cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) + def infer_shape(self, node, shapes): + try: + template = self._inner_shape_template + frozen = self._inner_shape_frozen + except AttributeError: + from pytensor.tensor.rewriting.shape import ShapeFeature + + sf = ShapeFeature() + inner_inputs = self.inner_inputs + template = [sf.shape_tuple(o) for o in self.inner_outputs] + flat_shapes = [s for tup in template if tup is not None for s in tup] + + # Express the inner-output shapes as a frozen function of the inner + # inputs plus each input's per-dim size. from_structural_inputs rewires + # every Shape_i(inner_input, dim) occurrence to the matching input, so + # bind can later swap in the caller's shapes. One slot per input dim: + # static or unused dims become dead inputs, keeping the layout positional. + shape_inputs = [ + Shape_i(dim)(inp) + for inp in inner_inputs + for dim in range(getattr(inp.type, "ndim", 0)) + ] + frozen = FrozenFunctionGraph.from_structural_inputs( + [*inner_inputs, *shape_inputs], flat_shapes + ) + self._inner_shape_template = template + self._inner_shape_frozen = frozen + + # frozen.inputs is [*inner_inputs, *per-dim sizes]; mirror that layout. + replacements = list(node.inputs) + for shp in shapes: + if shp is not None: + replacements.extend(shp) + + bound_shapes = frozen.bind(replacements) + ret = [] - used = 0 - for i, out_shape in enumerate(out_shapes): - if out_shape is None: + idx = 0 + for tup in template: + if tup is None: ret.append(None) else: - nb = len(out_shape) - ret.append(cloned[used : used + nb]) - used += nb + nb = len(tup) + ret.append(bound_shapes[idx : idx + nb]) + idx += nb return ret diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 57d34bbf25..028c854854 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp): def make_node(self, x): return Apply(self, [x], [x.type()]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback(self, args, outputs, g_outs): @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes @@ -251,8 +251,8 @@ def __reduce__(self): ) return load_back, (mod, name) - def _infer_shape(self, fgraph, node, input_shapes): - return self.__infer_shape(fgraph, node, input_shapes) + def _infer_shape(self, node, input_shapes): + return self.__infer_shape(node, input_shapes) def as_op(itypes, otypes, infer_shape=None): @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None): It takes an optional infer_shape parameter that should be a callable with this signature: - def infer_shape(fgraph, node, input_shapes): + def infer_shape(node, input_shapes): ... return output_shapes diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index a785958f4e..b43c3cdd8b 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1031,6 +1031,37 @@ def _resolve_input(inp, memo=memo): self._variables: frozenset[Variable] | None = None self._clients: dict[Variable, list[ClientType]] | None = None + @classmethod + def from_structural_inputs( + cls, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + ) -> "FrozenFunctionGraph": + """Freeze ``outputs``, allowing ``inputs`` to be *interior* expressions. + + Structural-matching dual of `bind`: where `bind` maps inputs to values, + this lifts chosen sub-expressions up to inputs. An ``input`` produced by + an `Apply` is matched against ``outputs`` by structure (not identity) and + every occurrence is rewired to it; root inputs behave as in the + constructor. Intermediate inputs absent from ``outputs`` become dead + inputs. The signature preserves the order of ``inputs``. ``outputs`` must + be computable from ``inputs`` alone — any root they still depend on + directly must itself appear in ``inputs``, else the rewired graph is + orphaned. + """ + # Discover the true graph roots (as FunctionGraph(inputs=None) does) to + # seed the staged freeze; the caller's `inputs` may be intermediate. + # Freezing inputs and outputs together interns each intermediate input + # onto the same object as its occurrences in the outputs, so the + # re-freeze can rewire them — which requires intermediate inputs to be + # *built*, not blocked, hence only roots seed the freeze. + roots = [ + v for v in graph_inputs([*inputs, *outputs]) if not isinstance(v, Constant) + ] + interned = cls(roots, [*inputs, *outputs]) + n_inputs = len(inputs) + return cls(interned.outputs[:n_inputs], interned.outputs[n_inputs:]) + def __reduce__(self): return FrozenFunctionGraph, (self.inputs, self.outputs) @@ -1099,7 +1130,7 @@ def bind( [o.type() for o in node.outputs], ) memo.update(zip(node.outputs, new_node.outputs)) - return [memo[out] for out in self.outputs] + return [out if isinstance(out, Constant) else memo[out] for out in self.outputs] def unfreeze(self) -> "FunctionGraph": """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index ebfa067b71..e048b4188f 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -1,3 +1,4 @@ +import inspect import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Sequence @@ -120,6 +121,24 @@ class Op(MetaObject): as nodes with these Ops must be rebuilt even if the input types haven't changed. """ + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + method = cls.__dict__.get("infer_shape") + if method is None: + return + params = inspect.signature(method).parameters + if len(params) == 4: + warnings.warn( + f"{cls.__module__}.{cls.__qualname__}.infer_shape takes a " + "deprecated `fgraph` parameter; drop it from the signature. " + "The parameter will be passed as None.", + DeprecationWarning, + stacklevel=2, + ) + cls.infer_shape = lambda self, node, input_shapes, _old=method: _old( + self, None, node, input_shapes + ) + def make_node(self, *inputs: Variable) -> Apply: """Construct an `Apply` node that represent the application of this operation to the given inputs. diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index 26a2970408..ea8d8cfd0d 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -1,6 +1,6 @@ import copy -from collections.abc import Generator, Sequence -from typing import TYPE_CHECKING, Optional +from collections.abc import Generator, Iterable, Sequence +from typing import TYPE_CHECKING, Optional, cast import pytensor from pytensor.graph.basic import ( @@ -62,6 +62,49 @@ def rewrite_graph( return fgraph.outputs +def rewrite_subgraph( + outputs: Sequence[Variable], + frontier: Iterable[Variable], + include: Sequence[str] = ("canonicalize",), + **kwargs, +) -> list[Variable]: + """Rewrite the subgraph between ``frontier`` and ``outputs`` in isolation. + + The ``frontier`` variables are temporarily detached from their owners, so + they act as inputs of the subgraph: rewrites can neither reach past them + nor modify the graph they belong to. This allows simplifying fresh + expressions that hang off the variables of an existing `FunctionGraph` + without mutating it behind its (and its features') back. + + The rewrite is in place: ``outputs`` must not belong to a `FunctionGraph`. + + Parameters + ---------- + outputs + The outputs of the subgraph to rewrite. + frontier + Variables at which the subgraph stops; every path from ``outputs`` + into the surrounding graph must go through one of them. + include + Rewrite query names, as in `rewrite_graph`. + **kwargs + Keyword arguments passed to `rewrite_graph`. + """ + saved_owners = [(v, v.owner, v.index) for v in frontier] + for v, _, _ in saved_owners: + v.owner = None + try: + rewritten = cast( + Sequence[Variable], + rewrite_graph(list(outputs), include=include, clone=False, **kwargs), + ) + return list(rewritten) + finally: + for v, owner, idx in saved_owners: + v.owner = owner + v.index = idx + + def is_same_graph_with_merge( var1: Variable, var2: Variable, diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 3e7264fcf9..31e0feba97 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -103,7 +103,7 @@ def __str__(self): args.append("inplace") return f"if{{{','.join(args)}}}" - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): # By construction, corresponding then/else pairs have the same number # of dimensions diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index c2962e4d7d..87f52c3c4b 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props): def c_code_cache_version(self): return (2,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def do_constant_folding(self, fgraph, node): diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 8c92c819c2..a0cb073f54 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2251,7 +2251,7 @@ def perform(self, node, inputs, output_storage): self.t_call = t_call self.t_fn = t_fn - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # input_shapes correspond to the shapes of node.inputs for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 771135b1c6..a6bf82edf7 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -494,7 +494,7 @@ def pullback(self, inputs, outputs, gout): disconnected_type(), ] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # node.inputs[3] is of length as we only support sparse matrix. return [(node.inputs[3][0], node.inputs[3][1])] @@ -584,7 +584,7 @@ def perform(self, node, inputs, outputs): g_out[0] = gout_data - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -629,7 +629,7 @@ def pullback(self, inputs, outputs, outputs_gradients): else: return [Cast(inputs[0].dtype)(gz)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return ins_shapes def __str__(self): @@ -742,7 +742,7 @@ def pullback(self, inputs, outputs, gout): else: return [SparseFromDense(x.type.format)(gz)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -806,7 +806,7 @@ def pullback(self, inputs, outputs, gout): ) return (gx,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -820,7 +820,7 @@ class GetItemList(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[1][0], shapes[0][1])] def make_node(self, x, index): @@ -865,7 +865,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItemListGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, index, gz): @@ -958,7 +958,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItem2ListsGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, ind1, ind2, gz): @@ -1139,7 +1139,7 @@ class GetItemScalar(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def make_node(self, x, index): @@ -1239,7 +1239,7 @@ def pullback(self, inputs, outputs, gout): assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (transpose(gz),) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0][::-1]] @@ -1288,7 +1288,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1339,7 +1339,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1438,7 +1438,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [square_diagonal(gz)] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): return [(minimum(*shapes[0]),)] @@ -1498,7 +1498,7 @@ def perform(self, node, inputs, outputs): def pullback(self, inputs, outputs, output_grad): return [output_grad[0]] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes def __str__(self): @@ -1614,7 +1614,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[1] for shape in ins_shapes) return [(ins_shapes[0][0], d)] @@ -1711,7 +1711,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[0] for shape in ins_shapes) return [(d, ins_shapes[0][1])] @@ -1800,7 +1800,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [gz] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes @@ -1880,7 +1880,7 @@ def perform(self, node, inp, out_): (data, indices, indptr), shape=out_shape, dtype=values.dtype ) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x = node.inputs[0] return [[x[0], x[1]]] diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 107df82bfc..a65a4e3614 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -327,7 +327,7 @@ def pullback(self, inputs, outputs, gout): r = psb.SparseFromDense(o_format)(r) return [r] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): r = None if self.axis is None: r = [()] @@ -406,7 +406,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -467,7 +467,7 @@ def pullback(self, inputs, outputs, gout): derivative = {True: gz, False: None} return [derivative[b] for b in is_continuous] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -509,7 +509,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_dense_variable(gz) return psb.sp_ones_like(x) * gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -569,7 +569,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, sp_sum(gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -699,7 +699,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return y * gz, x * gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -788,7 +788,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return y * gz, psb.dense_from_sparse(x * gz) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -871,7 +871,7 @@ def pullback(self, inputs, outputs, gout): return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -989,7 +989,7 @@ def perform(self, node, inputs, outputs): self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format) ) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1034,7 +1034,7 @@ def perform(self, node, inputs, outputs): o = np.asarray(o) out[0] = o - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1284,7 +1284,7 @@ def pullback(self, inputs, outputs, gout): rval[1] = psb.dense_from_sparse(rval[1]) return rval - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1412,7 +1412,7 @@ def pullback(self, inputs, outputs, gout): (g_out,) = gout return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1596,7 +1596,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1731,7 +1731,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1829,7 +1829,7 @@ def pullback(self, inputs, outputs, gout): return rval - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[2]] @@ -1842,7 +1842,7 @@ class Dot(Op): def __str__(self): return "Sparse" + self.__class__.__name__ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes x, y = node.inputs if x.ndim == 2 and y.ndim == 2: diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index 2866f1aec0..1a2c75a6a8 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -176,7 +176,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ return code - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[3]] def c_code_cache_version(self): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d690a9e9c..48528e7696 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -637,7 +637,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(s) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -698,7 +698,7 @@ def perform(self, node, inputs, output_storage): # not using .item() because that returns a Python scalar, not a numpy scalar output_storage[0][0] = inputs[0][()] - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -1379,7 +1379,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.eye(n, m, k, dtype=self.dtype) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] @@ -1708,7 +1708,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (5,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs[1:]] def connection_pattern(self, node): @@ -1762,14 +1762,32 @@ def do_constant_folding(self, fgraph, node): if not clients: return False + from pytensor.tensor.blas import Gemv, Ger + from pytensor.tensor.blas_c import CGemv, CGer + from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + IncSubtensor, + Subtensor, + ) + for client, idx in clients: client_op = client.op if isinstance(client_op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False - # Op's through which Alloc can be lifted - elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join): + # Op's through which Alloc can be lifted. ``Subtensor`` is + # included because ``local_subtensor_of_alloc`` rewrites + # ``alloc(val, *shape)[idx]`` into ``alloc(val[...], *new_shape)``, + # preserving the Alloc structure that downstream rewrites + # (e.g. ``local_blockwise_alloc_inputs``) rely on. Folding the + # Alloc here would short-circuit that lift and produce a + # broadcast-equivalent constant whose batch dim is no longer + # type-broadcastable. + elif isinstance( + client_op, Elemwise | DimShuffle | Alloc | Join | Subtensor + ): return False # Same for Blockwise, unless it has no batch_dims elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client): @@ -1779,13 +1797,13 @@ def do_constant_folding(self, fgraph, node): idx == 0 and isinstance( client_op, - pytensor.tensor.subtensor.IncSubtensor - | pytensor.tensor.subtensor.AdvancedIncSubtensor1 - | pytensor.tensor.subtensor.AdvancedIncSubtensor - | pytensor.tensor.blas.Gemv - | pytensor.tensor.blas_c.CGemv - | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer, + IncSubtensor + | AdvancedIncSubtensor1 + | AdvancedIncSubtensor + | Gemv + | CGemv + | Ger + | CGer, ) ): # Ops that will work inplace on the Alloc. So if they @@ -1966,7 +1984,7 @@ def c_code(self, node, name, inp, out_, props): """ return ret - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [(len(ishapes),)] def pullback(self, inputs, outputs, output_gradients): @@ -2254,7 +2272,7 @@ def perform(self, node, inputs, outputs_storage): for out_storage, out in zip(outputs_storage, split_outs, strict=False): out_storage[0] = out - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): axis = node.inputs[1] splits = node.inputs[2] shp_x, _shp_axis, _shp_splits = in_shapes @@ -2710,7 +2728,7 @@ def pullback(self, inputs, outputs, grads): return rval - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, ge # ishapes[0] contains the size of the axis on which we join @@ -3264,7 +3282,7 @@ def make_node(self, start, stop, step): return Apply(self, inputs, outputs) @config.change_flags(warn_float64="ignore") - def infer_shape(self, fgraph, node, i_shapes): + def infer_shape(self, node, i_shapes): from pytensor.tensor.math import ceil, maximum # Note start, stop and step can be float numbers. @@ -3641,7 +3659,7 @@ def perform(self, node, inp, out): self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): from pytensor.tensor.math import maximum shp_x = in_shapes[0] @@ -3893,7 +3911,7 @@ def pullback(self, inputs, outputs, gout): x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2)) return [x_grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): from pytensor.tensor.math import clip, minimum (in_shape,) = shapes @@ -4225,7 +4243,7 @@ def __init__(self, mode): assert mode in ("raise", "wrap", "clip") self.mode = mode - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): a_shape, choices_shape = shapes if choices_shape is None: # choices is a TypedList, not a tensor; no shape to broadcast @@ -4256,9 +4274,7 @@ def make_node(self, a, choices): choice = as_tensor_variable(choices) choice_dtype = choice.dtype - (out_shape,) = self.infer_shape( - None, None, [shape_tuple(a), shape_tuple(choice)] - ) + (out_shape,) = self.infer_shape(None, [shape_tuple(a), shape_tuple(choice)]) static_out_shape = () for s in out_shape: @@ -4361,7 +4377,7 @@ def c_code(self, node, name, inputs, out_, sub): """ return str - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs] def c_code_cache_version(self): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 0b0dcdfc2d..edb70ac0e9 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -243,7 +243,7 @@ def perform(self, node, inputs, out_storage): out += y out_storage[0][0] = np.asarray(out, dtype=y.dtype) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -316,7 +316,7 @@ def perform(self, node, inputs, output_storage): A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) output_storage[0][0] = A - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -941,7 +941,7 @@ def perform(self, node, inp, out): z += a * np.dot(x, y) zout[0] = z - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): z_shape, _, x_shape, y_shape, _ = input_shapes return [ ( @@ -1146,7 +1146,7 @@ def make_node(self, x, y): def perform(self, node, inputs, output_storage): output_storage[0][0] = np.dot(*inputs) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = """ @@ -1249,7 +1249,7 @@ def perform(self, node, inp, out): e.args = (*e.args, x.shape, y.shape) raise - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz @@ -1638,7 +1638,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index ecc4ad92d1..2c8a2c99bc 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,6 @@ from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType -from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op @@ -321,9 +320,7 @@ def make_node(self, *inputs): def batch_ndim(self, node: Apply) -> int: return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0])) - def infer_shape( - self, fgraph, node, input_shapes - ) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, input_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor import broadcast_shape from pytensor.tensor.shape import Shape_i @@ -354,13 +351,10 @@ def extract_core_shape_from_infer_shape(): return_dummy_inputs=True, propagate_unbatched_core_inputs=True, ) - dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) core_input_shapes = [ input_shape[batch_ndims:] for input_shape in input_shapes ] - core_output_shapes = core_op_infer_shape( - dummy_fgraph, dummy_core_node, core_input_shapes - ) + core_output_shapes = core_op_infer_shape(dummy_core_node, core_input_shapes) if not dummy_core_inputs: # All inputs are unbatched, so the core_shape can be used as is diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a2a6e2a89f..6366ec42d5 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -257,7 +257,7 @@ def perform(self, node, inp, out): new_shape.insert(augm, 1) out[0][0] = res.reshape(new_shape) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishp,) = shapes # transpose rval = [ishp[i] for i in self.shuffle] @@ -763,7 +763,7 @@ def _check_runtime_broadcast(node, inputs): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." ) - def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, i_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor.extra_ops import broadcast_shape out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) @@ -1434,7 +1434,7 @@ def perform(self, node, inp, out): output[0] = np.asarray(out, dtype=out_dtype) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes axis = self.axis if axis is None: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 61a8003cfd..34542e9bdb 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -150,7 +150,7 @@ def make_node(self, x, v, sorter=None): raise TypeError("sorter must be an integer vector", sorter.type) return Apply(self, [x, v, sorter], [out_type()]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] def perform(self, node, inputs, output_storage): @@ -340,7 +340,7 @@ def pullback(self, inputs, outputs, output_gradients): f'{type(self).__name__}: unknown gradient for mode "{self.mode}"' ) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes def c_code(self, node, name, inames, onames, sub): @@ -717,7 +717,7 @@ def pullback(self, inputs, outputs, gout): return [gx, disconnected_type()] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) @@ -849,7 +849,7 @@ def perform(self, node, inputs, out_): (out,) = out_ out[0] = np.bartlett(M) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): temp = node.inputs[0] M = ptb.switch(lt(temp, 0), ptb.cast(0, temp.dtype), temp) return [[M]] @@ -892,7 +892,7 @@ class FillDiagonal(Op): # See function fill_diagonal for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val): @@ -993,7 +993,7 @@ class FillDiagonalOffset(Op): # See function fill_diagonal_offset for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val, offset): @@ -1240,7 +1240,7 @@ def perform(self, node, inputs, output_storage): else: output_storage[0][0] = outs - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): [x_shape] = i0_shapes shape0_op = Shape_i(0) out_shapes = [(shape0_op(out),) for out in node.outputs] @@ -1310,7 +1310,7 @@ def make_node(self, indices, dims): [out_type() for _i in range(ptb.get_vector_length(dims))], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] * len(node.outputs) def perform(self, node, inp, out): @@ -1387,7 +1387,7 @@ def make_node(self, *inp): [out_type()], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def perform(self, node, inp, out): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 9d35955c6f..260fc2dc09 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -100,7 +100,7 @@ def make_node(self, a, n, axis): ], ) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): shape_a = in_shapes[0] n = node.inputs[1] axis = node.inputs[2] diff --git a/pytensor/tensor/linalg/constructors.py b/pytensor/tensor/linalg/constructors.py index 96ced60f9e..fbcc6d486a 100644 --- a/pytensor/tensor/linalg/constructors.py +++ b/pytensor/tensor/linalg/constructors.py @@ -34,7 +34,7 @@ def pullback(self, inputs, outputs, gout): ] return [gout[0][slc] for slc in slices] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): first, second = unzip(shapes, n=2, strict=True) return [(pt.add(*first), pt.add(*second))] diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 0fa7a34b3b..556ead9bf2 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -33,7 +33,7 @@ def __init__( if self.overwrite_a: self.destroy_map = {0: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def make_node(self, x): diff --git a/pytensor/tensor/linalg/decomposition/eigen.py b/pytensor/tensor/linalg/decomposition/eigen.py index f2afe7cf39..ba617d11f2 100644 --- a/pytensor/tensor/linalg/decomposition/eigen.py +++ b/pytensor/tensor/linalg/decomposition/eigen.py @@ -65,7 +65,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = w.astype(dtype, copy=False) outputs[1][0] = v.astype(dtype, copy=False) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shapes,) = shapes n, _ = x_shapes @@ -206,7 +206,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self return type(self)(**new_props) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n,), (n, n)] @@ -416,7 +416,7 @@ def make_node(self, a, b=None): w = vector(dtype=out_dtype, shape=(N,)) return Apply(self, inputs, [w]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [ (n,), diff --git a/pytensor/tensor/linalg/decomposition/lu.py b/pytensor/tensor/linalg/decomposition/lu.py index 2b4edb621b..db8ac60fe5 100644 --- a/pytensor/tensor/linalg/decomposition/lu.py +++ b/pytensor/tensor/linalg/decomposition/lu.py @@ -42,7 +42,7 @@ def __init__(self, *, permute_l=False, overwrite_a=False, p_indices=False): if self.overwrite_a: self.destroy_map = {0: [0]} if self.permute_l else {1: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] if self.permute_l: return [(n, n), (n, n)] @@ -258,7 +258,7 @@ def make_node(self, A): return Apply(self, [A], [LU, pivots]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n, n), (n,)] diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 9e9270259d..9d81e8eb7b 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -103,7 +103,7 @@ def make_node(self, x): return Apply(self, [x], outputs) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape diff --git a/pytensor/tensor/linalg/decomposition/schur.py b/pytensor/tensor/linalg/decomposition/schur.py index 737ca9dc1d..af3e2035ea 100644 --- a/pytensor/tensor/linalg/decomposition/schur.py +++ b/pytensor/tensor/linalg/decomposition/schur.py @@ -146,7 +146,7 @@ def perform(self, node, inputs, outputs): T_out[0] = T Z_out[0] = Z - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0], shapes[0]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": @@ -489,7 +489,7 @@ def perform(self, node, inputs, outputs): alpha_out[0] = alpha beta_out[0] = beta - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): A_shape, B_shape = shapes if self.return_eigenvalues: return [A_shape, B_shape, (A_shape[0],), (A_shape[0],), A_shape, B_shape] diff --git a/pytensor/tensor/linalg/decomposition/svd.py b/pytensor/tensor/linalg/decomposition/svd.py index 584eae1419..361b5cdd34 100644 --- a/pytensor/tensor/linalg/decomposition/svd.py +++ b/pytensor/tensor/linalg/decomposition/svd.py @@ -93,7 +93,7 @@ def perform(self, node, inputs, outputs): (s,) = outputs s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape K = ptm.minimum(M, N) diff --git a/pytensor/tensor/linalg/inverse.py b/pytensor/tensor/linalg/inverse.py index 82878c8777..6c205d6ae7 100644 --- a/pytensor/tensor/linalg/inverse.py +++ b/pytensor/tensor/linalg/inverse.py @@ -61,7 +61,7 @@ def pullback(self, inputs, outputs, g_outputs): ).T return [grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [list(reversed(shapes[0]))] @@ -159,7 +159,7 @@ def pushforward(self, inputs, outputs, eval_points): return [-matrix_dot(xi, ev, xi)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes @@ -187,7 +187,7 @@ def perform(self, node, inputs, outputs): (x,) = outputs x[0] = np.linalg.tensorinv(a, self.ind) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): sp = shapes[0][self.ind :] + shapes[0][: self.ind] return [sp] diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 42d49b10f4..542b174c4c 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -74,7 +74,7 @@ def pullback(self, inputs, outputs, output_grads): return [expm(aug)[..., :n, n:]] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] diff --git a/pytensor/tensor/linalg/solvers/core.py b/pytensor/tensor/linalg/solvers/core.py index 9805d86c8f..39a46e84a8 100644 --- a/pytensor/tensor/linalg/solvers/core.py +++ b/pytensor/tensor/linalg/solvers/core.py @@ -71,7 +71,7 @@ def make_node(self, A, b): x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): Ashape, Bshape = shapes rows = Ashape[1] if len(Bshape) == 1: diff --git a/pytensor/tensor/linalg/solvers/linear_control.py b/pytensor/tensor/linalg/solvers/linear_control.py index 14478f8e03..2fb47fac34 100644 --- a/pytensor/tensor/linalg/solvers/linear_control.py +++ b/pytensor/tensor/linalg/solvers/linear_control.py @@ -82,7 +82,7 @@ def perform(self, node, inputs, outputs_storage): Y *= scale X[0] = Y - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[2]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": diff --git a/pytensor/tensor/linalg/summary.py b/pytensor/tensor/linalg/summary.py index bba599e17f..c76753885f 100644 --- a/pytensor/tensor/linalg/summary.py +++ b/pytensor/tensor/linalg/summary.py @@ -71,7 +71,7 @@ def pullback(self, inputs, outputs, g_outputs): (x,) = inputs return [gz * self(x) * matrix_inverse(x).T] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def __str__(self): @@ -106,7 +106,7 @@ def perform(self, node, inputs, outputs): except Exception as e: raise ValueError("Failed to compute determinant", x) from e - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(), ()] def __str__(self): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..7e93db8c01 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -254,7 +254,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (3,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes if self.axis is None: return [()] @@ -3106,7 +3106,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [[xshp[0], yshp[1]]] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 34479e142c..7201d631e8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -306,7 +306,7 @@ def extract_batch_shape(p, ps, n): return shape - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): _, size, *dist_params = node.inputs _, _, *param_shapes = input_shapes diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index b658cc98cb..ac01e83ef6 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -274,7 +274,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # Use shape_feature to facilitate inferring final shape. # Check that neither the RV nor the old Subtensor are in the shape graph. - output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None) + output_shape = shape_feature.shape_tuple(indexed_rv) if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)): return None diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 8d128ec698..0458929ee1 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -1,6 +1,7 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter +from pytensor.graph.traversal import applys_between from pytensor.tensor import as_tensor, constant from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.rewriting.numba import simplify_core_shape_graphs @@ -15,9 +16,10 @@ def introduce_explicit_core_shape_rv(fgraph, node): that has an extra "non-functional" input that represents the core shape of the random variable. This core_shape is used by the numba backend to pre-allocate the output array. - If available, the core shape is extracted from the shape feature of the graph, - which has a higher chance of having been simplified, optimized, constant-folded. - If missing, we fall back to the op._supp_shape_from_params method. + The core shape is built from ``ShapeFeature.get_non_recursive_shape``, whose + expressions read only the node's own inputs and can therefore be introduced + after inplacing without conflicting with the destroyers in the surrounding + graph. This rewrite is required for the numba backend implementation of RandomVariable. @@ -58,19 +60,27 @@ def introduce_explicit_core_shape_rv(fgraph, node): _next_rng, rv = node.outputs shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) - if shape_feature: - core_shape = [ - shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) - ] - else: - core_shape = op._supp_shape_from_params(op.dist_params(node)) + if shape_feature is None: + shape_feature = ShapeFeature() + + core_shape = [ + shape_feature.get_non_recursive_shape(rv, i) + for i in range(rv.type.ndim - op.ndim_supp, rv.type.ndim) + ] if len(core_shape) == 0: core_shape = constant([], dtype="int64") else: core_shape = as_tensor(core_shape) - [core_shape] = simplify_core_shape_graphs([core_shape]) + if any( + isinstance(node.op, RandomVariable) + for node in applys_between(node.inputs, [core_shape]) + ): + # If the RandomVariable shows up in the shape graph we can't introduce the core shape + return None + + [core_shape] = simplify_core_shape_graphs([core_shape], fgraph) new_outs = ( RandomVariableWithCoreShape( diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index aab187a4ef..3a88a0a39f 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -64,7 +64,7 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override] return Apply(self, [x], [output_type]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape] = shapes joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int) return [self.output_shapes(input_shape, joined_shape)] @@ -185,7 +185,7 @@ def make_node(self, x, shape): ) return Apply(self, [x, shape], [output]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape, _] = shapes _, shape = node.inputs output_shapes = list(input_shape) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 99522237ce..48ef217c73 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -124,8 +124,8 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: def get_simplified_shape(x: TensorVariable, *, fgraph) -> tuple: """Return a simplified shape tuple for ``x``: shape_feature → static → ``x.shape``.""" try: - return fgraph.shape_feature.shape_of[x] - except (AttributeError, KeyError): + return fgraph.shape_feature.shape_tuple(x) + except AttributeError: pass static_shape = x.type.shape diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index ecb3435030..af0529869e 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -2,38 +2,19 @@ from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter -from pytensor.graph.rewriting.utils import rewrite_graph +from pytensor.graph.rewriting.utils import rewrite_subgraph from pytensor.graph.traversal import ancestors, applys_between from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.rewriting.shape import ShapeFeature -from pytensor.tensor.shape import Shape, Shape_i -def simplify_core_shape_graphs(core_shapes): - """Simplify core shape expressions by canonicalizing shape arithmetic. - - Temporarily detaches Shape/Shape_i outputs from their owners so that - rewrite_graph operates only on the arithmetic above them (e.g. - constant-folding static shapes, eliminating dead Switch branches). - """ - shape_boundary = [ - var - for var in ancestors(core_shapes) - if var.owner is not None and isinstance(var.owner.op, (Shape, Shape_i)) - ] - saved_owners = [(v, v.owner, v.index) for v in shape_boundary] - for v, _, _ in saved_owners: - v.owner = None - try: - core_shapes = list( - rewrite_graph(core_shapes, include=("canonicalize",), clone=False) - ) - finally: - for v, owner, idx in saved_owners: - v.owner = owner - v.index = idx - return core_shapes +def simplify_core_shape_graphs(core_shapes, fgraph): + """Canonicalize the fresh shape arithmetic built by ``get_non_recursive_shape``.""" + graph_frontier = fgraph.variables.intersection( + ancestors(core_shapes, blockers=fgraph.variables) + ) + return rewrite_subgraph(core_shapes, graph_frontier, include=("canonicalize",)) @node_rewriter([Blockwise]) @@ -44,9 +25,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): that has an extra "non-functional" input that represents the core shape of the Blockwise variable. This core_shape is used by the numba backend to pre-allocate the output array. - If available, the core shape is extracted from the shape feature of the graph, - which has a higher change of having been simplified, optimized, constant-folded. - If missing, we fall back to the op._supp_shape_from_params method. + The core shape is built from ``ShapeFeature.get_non_recursive_shape``, whose + expressions read only the node's own inputs and can therefore be introduced + after inplacing without conflicting with the destroyers in the surrounding + graph. This rewrite is required for the numba backend implementation of Blockwise. @@ -98,17 +80,16 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): batch_ndim = op.batch_ndim(node) shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) - if shape_feature: - core_shapes = [ - [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] - for out in node.outputs - ] - else: - input_shapes = [tuple(inp.shape) for inp in node.inputs] - core_shapes = [ - out_shape[batch_ndim:] - for out_shape in op.infer_shape(None, node, input_shapes) + if shape_feature is None: + shape_feature = ShapeFeature() + + core_shapes = [ + [ + shape_feature.get_non_recursive_shape(out, i) + for i in range(batch_ndim, out.type.ndim) ] + for out in node.outputs + ] core_shapes = [ as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64") @@ -122,7 +103,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): # If Blockwise shows up in the shape graph we can't introduce the core shape return None - core_shapes = simplify_core_shape_graphs(core_shapes) + core_shapes = simplify_core_shape_graphs(core_shapes, fgraph) return BlockwiseWithCoreShape( [*node.inputs, *core_shapes], diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 2b2060c3e7..83dc920a15 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -1,6 +1,4 @@ -import traceback -from io import StringIO -from typing import cast as type_cast +from collections import deque from warnings import warn import numpy as np @@ -15,8 +13,7 @@ copy_stack_trace, node_rewriter, ) -from pytensor.graph.traversal import ancestors -from pytensor.graph.utils import InconsistencyError, get_variable_trace_string +from pytensor.graph.type import HasShape from pytensor.tensor.basic import ( Alloc, MakeVector, @@ -30,13 +27,12 @@ stack, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError, ShapeError +from pytensor.tensor.exceptions import ShapeError from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, register_useless, - topo_constant_folding, ) from pytensor.tensor.shape import ( Reshape, @@ -50,616 +46,343 @@ AdvancedIncSubtensor1, IncSubtensor, Subtensor, - get_idx_list, ) -from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -class ShapeFeature(Feature): - r"""A `Feature` that tracks shape information in a graph. - - This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with - `Shape_i` and `MakeVector` `Op`\s. - - This `Feature` and its associated rewrites have several goals: - - 1. to "lift" `Shape`\s to as close to the inputs as possible, - 2. to infer the shape of every node in the graph in terms of the - input shapes, and - 3. remove fill `Op`\s (e.g. `Second`) from the graph. - - Lifting shapes as close to the inputs as possible is important for - canonicalization because it is very bad form to have to compute - something just to know how big it will be. Firstly, it is a waste - of time to compute such outputs. But it is important to get rid - of these outputs as early as possible in the compilation process - because the extra computations make it appear as if many internal - graph nodes have multiple clients. Many rewrites refuse to - work on nodes with multiple clients. - - Lifting is done by using an `.infer_shape` function if one is - present, or else using a conservative default. An Op that - supports shape-lifting should define a infer_shape(self, fgraph, node, - input_shapes) function. The argument input_shapes is a tuple of - tuples... there is an interior tuple for each input to the node. - The tuple has as many elements as dimensions. The element in - position i of tuple j represents the i'th shape component of the - j'th input. The function should return a tuple of tuples. One - output tuple for each node.output. Again, the i'th element of the - j'th output tuple represents the output[j].shape[i] of the - function. If an output is not a TensorType, then None should be - returned instead of a tuple for that output. - - For example the infer_shape for a matrix-matrix product would accept - input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). - - Inferring the shape of internal nodes in the graph is important - for doing size-driven rewrites. If we know how big various - intermediate results will be, we can estimate the cost of many Ops - accurately, and generate c-code that is specific [e.g. unrolled] - to particular sizes. - - In cases where you cannot figure out the shape, raise a ShapeError. +class _ShapeOfProxy: + """Dict-like proxy so ``shape_feature.shape_of[var]`` keeps working.""" - Notes - ----- - To use this shape information in rewrites, use the - ``shape_of`` dictionary. + def __init__(self, feature): + self._feature = feature - For example: + def __getitem__(self, var): + result = self._feature.shape_tuple(var) + if result is None: + raise KeyError(var) + return result - .. code-block:: python + def __contains__(self, var): + return isinstance(var.type, HasShape) - try: - shape_of = fgraph.shape_feature.shape_of - except AttributeError: - # This can happen when the mode doesn't include the ShapeFeature. - return - shape_of_output_zero = shape_of[node.output[0]] +class ShapeFeature(Feature): + r"""Lazy `Feature` that provides shape information for the graph it tracks. - The ``shape_of_output_zero`` symbol will contain a tuple, whose - elements are either integers or symbolic integers. + Shapes are derived on demand by calling each ``Op.infer_shape`` on the + current (live) node inputs, recursing toward the graph inputs. Use: - TODO: check to see if the symbols are necessarily - non-constant... or are integer literals sometimes PyTensor - constants?? That would be confusing. + - ``get_shape(var, i)`` / ``shape_tuple(var)`` for the shape expressed in + terms of the graph inputs (recursing through the ancestors of ``var``); + - ``get_non_recursive_shape(var, i)`` for the shape expressed in terms of + ``var.owner``'s direct inputs only; + - ``same_shape(x, y)`` to statically compare two shapes. + Inferred shapes are cached per node, but the cache is invalidated whenever + an ancestor input changes (``on_change_input``), so the expressions handed + out always reference live graph variables. """ - def get_node_infer_shape(self, node): - try: - shape_infer = node.op.infer_shape - except AttributeError: - shape_infer = self.default_infer_shape - - try: - o_shapes = shape_infer( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except ShapeError: - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except NotImplementedError as e: - raise NotImplementedError( - "Code called by infer_shape failed raising a " - "NotImplementedError. Raising NotImplementedError to " - "indicate that a shape cannot be computed is no longer " - "supported, and one should now use ShapeError " - f"instead. The original exception message is: {e}" - ).with_traceback(e.__traceback__) - except Exception as e: - msg = ( - f"Failed to infer_shape from Op {node.op}.\nInput shapes: " - f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " - f"{type(e)}\nException message: {e!s}\nTraceback: {traceback.format_exc()}" - ) - if config.on_shape_error == "raise": - raise Exception(msg).with_traceback(e.__traceback__) - else: - warn(msg) - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - - return o_shapes - - def get_shape(self, var, idx): - """Rewrites can call this to get a `Shape_i`. - - It is better to call this then use directly ``shape_of[var][idx]`` - as this method should update `shape_of` if needed. - - TODO: Up to now, we don't update it in all cases. Update in all cases. - """ - r = self.shape_of[var][idx] - if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables - ): - assert var.owner - node = var.owner - # recur on inputs - for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) - assert len(o_shapes) == len(node.outputs) - - # Only change the variables and dimensions that would introduce - # extra computation - for new_shps, out in zip(o_shapes, node.outputs, strict=True): - if not hasattr(out.type, "ndim"): - continue - - merged_shps = list(self.shape_of[out]) - changed = False - for i in range(out.type.ndim): - n_r = merged_shps[i] - if ( - n_r.owner - and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables - ): - changed = True - merged_shps[i] = new_shps[i] - if changed: - self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") + def __init__(self): + self.fgraph: FunctionGraph | None = None + # node -> tuple of (tuple of shape vars) per output, recursive shapes only, + # lazily populated. Non-recursive shapes read only the node's own inputs, so + # they are cheap and recomputed on demand rather than cached. + self._cache: dict = {} + # Nodes whose input changed since the last query. The recursive shape of a node + # embeds the shapes of its whole ancestor cone, so a change anywhere above a + # cached node invalidates it. We defer that to the next query and walk + # ``fgraph.clients`` downstream from these nodes then — the live graph already + # encodes the dependencies, so no reverse-dependency map is kept. + self._stale: set = set() + # var -> {i: Shape_i(i)(v)}, ensures Apply identity for leaves + self._shape_i_cache: dict = {} + self.lscalar_one = constant(1, dtype="int64") + # Compat: scheduled replacements for local_track_shape_i + self.scheduled: dict = {} + + def _shape_i_var(self, v, i): + per_dim = self._shape_i_cache.get(v) + if per_dim is not None: + cached = per_dim.get(i) + if cached is not None: + return cached else: - s = Shape_i(i)(r) - try: - s = get_scalar_constant_value(s) - except NotScalarConstantError: - pass - return s - - def shape_tuple(self, r): - """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): - # This happen for NoneConst. - return None - return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - - def default_infer_shape(self, fgraph, node, i_shapes): - """Return a list of shape tuple or None for the outputs of node. + per_dim = {} + self._shape_i_cache[v] = per_dim + if isinstance(v.type, HasShape) and v.type.shape[i] is not None: + res = constant(v.type.shape[i], dtype="int64") + else: + res = Shape_i(i)(v) + per_dim[i] = res + return res - This function is used for Ops that don't implement infer_shape. - Ops that do implement infer_shape should use the i_shapes parameter, - but this default implementation ignores it. + @staticmethod + def _fresh_shape_i(v, i): + """Like ``_shape_i_var``, but never reusing a cached variable. + The cached variable may live in the graph with its own constraints + (e.g. its buffer destroyed by an inplace Op); a fresh read carries no + constraints beyond reading ``v``, which non-recursive consumers rely on. """ - rval = [] - for r in node.outputs: - try: - rval.append(self.shape_tuple(r)) - except AttributeError: - rval.append(None) - return rval - - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. - - The s_i argument was produced by the infer_shape() of an Op subclass. + if isinstance(v.type, HasShape) and v.type.shape[i] is not None: + return constant(v.type.shape[i], dtype="int64") + return Shape_i(i)(v) + + def _coerce_shape_element(self, element, node): + """Validate and normalize a single shape element from infer_shape.""" + if isinstance(element, np.ndarray): + if element.ndim != 0: + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"ndarray for shape element: {element!r}" + ) + element = element.item() + if isinstance(element, Variable): + if element.type.dtype not in integer_dtypes: + raise TypeError( + f"infer_shape for {node.op} returned a non-integer " + f"Variable for shape element: {element!r}" + ) + if getattr(element.type, "ndim", 0): + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"Variable for shape element: {element!r}" + ) + if element.type.dtype != "int64": + if isinstance(element, Constant): + return constant(int(element.data), dtype="int64") + return cast(element, "int64") + return element + if isinstance(element, int | np.integer): + if int(element) < 0: + raise ValueError( + f"infer_shape for {node.op} returned a negative shape: {int(element)}" + ) + return constant(int(element), dtype="int64") + raise TypeError( + f"infer_shape for {node.op} returned an unsupported " + f"shape element of type {type(element).__name__}: {element!r}" + ) - var: the variable that correspond to s_i. This is just for - error reporting. + def _get_node_shapes(self, node, recursive=True): + """Return validated per-output shape tuples for ``node``. + With ``recursive=False``, input shapes are fresh constant/``Shape_i`` + reads of the node's own inputs instead of inferred expressions, so the + result references no variables beyond those inputs (and nothing is + cached). """ - assert s_i is not None - - if s_i == 1: - return self.lscalar_one - if isinstance(s_i, float) and int(s_i) == s_i: - s_i = int(s_i) - if isinstance(s_i, np.integer | int) or ( - isinstance(s_i, np.ndarray) and s_i.ndim == 0 - ): - # this shape is a constant - if s_i < 0: - msg = "There is a negative shape in the graph!" - msg += get_variable_trace_string(var) - # The rest of the pipeline don't handle correctly this - # case. So we have 2 choices, stop compilation or - # consider the shape as unknown. As we have more - # chance to give the stack trace here then later, I - # choose that options as it would give better error - # message. - raise AssertionError(msg) - return constant(s_i, dtype="int64") - if isinstance(s_i, tuple | list): - # this dimension is the same as many of the inputs - # which tells us that if one of the inputs is known, - # the others all become known. - # TODO: should be implemented in Elemwise, and Dot - # - # worst case, we loop over shape_of and replace things - raise NotImplementedError(s_i) - - # s_i is x.shape[i] for some x, we change it to shape_of[x][i] - if ( - s_i.owner - and isinstance(s_i.owner.op, Subtensor) - and s_i.owner.inputs[0].owner - and isinstance(s_i.owner.inputs[0].owner.op, Shape) - ): - assert s_i.type.ndim == 0 - assert len(s_i.owner.op.idx_list) == 1 - - # The current Subtensor always put constant index in the graph. - # This was not True in the past. So call the Subtensor function - # that will return the right index. - idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) - assert len(idx) == 1 - idx = idx[0] - try: - i = get_scalar_constant_value(idx) - except NotScalarConstantError: - pass - else: - # Executed only if no exception was raised - x = s_i.owner.inputs[0].owner.inputs[0] - # x should already have been imported, and should be in shape_of. - s_i = self.shape_of[x][i] - - if s_i.type.dtype in integer_dtypes: - if getattr(s_i.type, "ndim", 0): - raise TypeError("Shape element must be scalar", s_i) - return s_i - else: - raise TypeError( - "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) - ) + if self._stale: + self._flush_stale() + + if not recursive: + return self._compute_node_shapes(node, recursive=False) + + cache = self._cache + cached = cache.get(node) + if cached is not None: + return cached + + # Fill the cache for ``node`` and its ancestor cone bottom-up with an + # explicit stack. Recursing through ``get_shape`` would overflow the + # Python stack on deep graphs; here a node is computed only once every + # input-producing node it needs is cached, so ``_compute_node_shapes`` + # reads its input shapes straight from the cache instead of recursing. + stack = [node] + while stack: + top = stack[-1] + if top in cache: + stack.pop() + continue + ready = True + for inp in top.inputs: + inp_node = inp.owner + if ( + inp_node is not None + and inp_node not in cache + and isinstance(inp.type, HasShape) + # A fully-static input shape is read without touching its + # owner (see ``get_shape``), so don't bother caching it. + and any(s is None for s in inp.type.shape) + ): + stack.append(inp_node) + ready = False + if ready: + stack.pop() + cache[top] = self._compute_node_shapes(top, recursive=True) - def set_shape(self, r, s, override=False): - """Assign the shape `s` to previously un-shaped variable `r`. + return cache[node] - Parameters - ---------- - r : a variable - s : None or a tuple of symbolic integers - override : If False, it mean r is a new object in the fgraph. - If True, it mean r is already in the fgraph and we want to - override its shape. + def _compute_node_shapes(self, node, recursive): + """Call ``infer_shape`` for ``node`` and validate the result. + Input shapes are taken from ``get_shape`` (recursive) or from fresh + ``Shape_i`` reads of the node's own inputs (non-recursive). In the + recursive case the caller must have cached the input-producing nodes + already, so ``get_shape`` resolves them from the cache without recursing. """ - if not override: - assert r not in self.shape_of, "r already in shape_of" - if s is None: - self.shape_of[r] = s - else: - if not isinstance(s, tuple | list): - raise TypeError("shapes must be tuple/list", (r, s)) - - if r.type.ndim != len(s): - sio = StringIO() - pytensor.printing.debugprint(r, file=sio, print_type=True) - raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" + input_shape_of = self.get_shape if recursive else self._fresh_shape_i + shape_i_var = self._shape_i_var if recursive else self._fresh_shape_i + + input_shapes = [] + for inp in node.inputs: + if isinstance(inp.type, HasShape): + input_shapes.append( + tuple(input_shape_of(inp, j) for j in range(inp.type.ndim)) ) + else: + input_shapes.append(None) - shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals( - get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + output_shapes = None + shape_infer = getattr(node.op, "infer_shape", None) + if shape_infer is not None: + try: + output_shapes = shape_infer(node, input_shapes) + except ShapeError: + pass + except NotImplementedError: + pass + except Exception as exc: + if config.on_shape_error == "raise": + raise + warn( + f"Failed to infer_shape from Op {node.op}: " + f"{type(exc).__name__}: {exc}" ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. + result = [] + for k, out in enumerate(node.outputs): + if not isinstance(out.type, HasShape): + result.append(None) + continue + sh = None + if output_shapes is not None and k < len(output_shapes): + sh = output_shapes[k] + if sh is None or not isinstance(sh, list | tuple): + result.append(tuple(shape_i_var(out, j) for j in range(out.type.ndim))) + continue + coerced = [] + for j, s in enumerate(sh): + coerced.append(self._coerce_shape_element(s, node)) + result.append(tuple(coerced)) - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. + return tuple(result) + def get_shape(self, var, idx): + """Return a symbolic expression for ``var.shape[idx]``.""" + if isinstance(var.type, HasShape) and var.type.shape[idx] is not None: + return constant(var.type.shape[idx], dtype="int64") + + node = var.owner + if node is None: + return self._shape_i_var(var, idx) + + node_shapes = self._get_node_shapes(node) + out_idx = node.outputs.index(var) + sh = node_shapes[out_idx] + if sh is not None: + return sh[idx] + return self._shape_i_var(var, idx) + + def get_non_recursive_shape(self, var, idx): + """Return an expression for ``var.shape[idx]`` reading only ``var.owner``'s inputs. + + Unlike ``get_shape``, input shapes are not recursively expanded: the + expression references the direct inputs of ``var.owner`` (through + constants and fresh ``Shape_i`` reads) and nothing else. It can + therefore be introduced next to any consumer of those inputs no matter + what the destroy maps in the surrounding graph are, whereas the + recursion of ``get_shape`` may surface variables that an inplace Op + destroys. + + Works on an unattached feature. """ - # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] - - # If other_shape has no information, call is pointless. - if other_shape is None: - return - - if r in self.shape_of: - r_shape = self.shape_of[r] - else: - # If no info is known on r's shape, use other_shape - self.set_shape(r, other_shape) - return - if ( - other_r.owner - and r.owner - and other_r.owner.inputs == r.owner.inputs - and other_r.owner.op == r.owner.op - ): - # We are doing a merge, so the two shape graphs will be the - # same. This is only done so that we call `ancestors` less - # frequently. - return - - # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] - for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner - and isinstance(ps.owner.op, Shape_i) - and ps.owner.op.i == i - and ps.owner.inputs[0] in (r, other_r) - ): - # If other_shape[i] is uninformative, use r_shape[i]. - # For now, we consider 2 cases of uninformative other_shape[i]: - # - Shape_i(i)(other_r); - # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(other_shape[i]) - elif other_shape[i] == r_shape[i]: - # This mean the shape is equivalent - # We do not want to do the ancestor check in those cases - merged_shape.append(r_shape[i]) - elif any( - ( - r_shape[i] == anc - or ( - anc.owner - and isinstance(anc.owner.op, Shape) - and anc.owner.inputs[0] == r - ) - ) - for anc in ancestors([other_shape[i]]) - ): - # Another case where we want to use r_shape[i] is when - # other_shape[i] actually depends on r_shape[i]. In that case, - # we do not want to substitute an expression with another that - # is strictly more complex. Such a substitution could also lead - # to cycles: if (in the future) r_shape[i] gets replaced by an - # expression of other_shape[i], other_shape[i] may end up - # depending on itself. - merged_shape.append(r_shape[i]) - else: - merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "shape") - or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1) - ) - or self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - get_scalar_constant_value( - merged_shape[i], - only_process_constants=True, - raise_not_constant=False, - ) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def set_shape_i(self, r, i, s_i): - """Replace element i of shape_of[r] by s_i""" - assert r in self.shape_of - prev_shape = self.shape_of[r] - # prev_shape is a tuple, so we cannot change it inplace, - # so we build another one. - new_shape = [] - for j, s_j in enumerate(prev_shape): - if j == i: - new_shape.append(self.unpack(s_i, r)) - else: - new_shape.append(s_j) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[idx] != 1 - or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals( - get_scalar_constant_value(new_shape[idx], raise_not_constant=False) - ) - for idx in range(r.type.ndim) + if isinstance(var.type, HasShape) and var.type.shape[idx] is not None: + return constant(var.type.shape[idx], dtype="int64") + + node = var.owner + if node is None: + return self._fresh_shape_i(var, idx) + + node_shapes = self._get_node_shapes(node, recursive=False) + out_idx = node.outputs.index(var) + sh = node_shapes[out_idx] + if sh is not None: + return sh[idx] + return self._fresh_shape_i(var, idx) + + def shape_tuple(self, var): + if not isinstance(var.type, HasShape): + return None + return tuple(self.get_shape(var, i) for i in range(var.type.ndim)) + + @property + def shape_of(self): + """Deprecated back-compat shim. Use ``shape_tuple(var)`` instead.""" + warn( + "ShapeFeature.shape_of is deprecated; use shape_tuple(var) instead.", + DeprecationWarning, + stacklevel=2, ) - self.shape_of[r] = tuple(new_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def init_r(self, r): - """Register r's shape in the shape_of dictionary.""" - if r not in self.shape_of: - self.set_shape(r, self.shape_tuple(r)) - - def make_vector_shape(self, r): - return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") + return _ShapeOfProxy(self) def on_attach(self, fgraph): if hasattr(fgraph, "shape_feature"): raise AlreadyThere("This FunctionGraph already has a ShapeFeature") - - if hasattr(self, "fgraph") and self.fgraph != fgraph: + if self.fgraph is not None and self.fgraph is not fgraph: raise Exception("This ShapeFeature is already attached to a graph") - self.fgraph = fgraph - fgraph.shape_feature = self - # Must be local to the object as otherwise we reuse the same - # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") - assert self.lscalar_one.type.dtype == "int64" - - self.fgraph = fgraph - # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} - # Variable -> - self.scheduled = {} - # shape var -> graph v - self.shape_of_reverse_index = {} - - for node in fgraph.toposort(): - self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): - self.shape_of = {} - self.scheduled = {} - self.shape_of_reverse_index = {} + self._cache.clear() + self._stale.clear() + self._shape_i_cache.clear() + self.scheduled.clear() self.fgraph = None - del fgraph.shape_feature - - def on_import(self, fgraph, node, reason): - if node.outputs[0] in self.shape_of: - # this is a revert, not really an import - for r in node.outputs + node.inputs: - assert r in self.shape_of - return - - for i, r in enumerate(node.inputs): - # make sure we have shapes for the inputs - self.init_r(r) - - o_shapes = self.get_node_infer_shape(node) + if hasattr(fgraph, "shape_feature"): + del fgraph.shape_feature - # this is packed information - # an element of o_shapes is either None or a tuple - # elements of the tuple can be either strings, or ints - if len(o_shapes) != len(node.outputs): - raise Exception( - f'The infer_shape method for the Op "{node.op}" returned a list ' - f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " - f" != len(node.outputs) = {len(node.outputs)}" - ) + def _flush_stale(self): + """Drop cached shapes for the changed nodes and everything downstream of them. - # Ensure shapes are in 'int64'. This is to make sure the assert - # found in the `local_useless_subtensor` rewrite does not fail. - for sh_idx, sh in enumerate(o_shapes): - if sh is None: - continue - if not isinstance(sh, list | tuple): - raise ValueError( - f"infer_shape of {node} didn't return a list of" - f" list. It returned '{o_shapes}'" - ) - new_shape = [] - for i, d in enumerate(sh): - # Note: we ignore any shape element that is not typed (i.e., - # does not have a 'dtype' attribute). This means there may - # still remain int elements that are int32 on 32-bit platforms, - # but this works with `local_useless_subtensor`, so for now we - # keep it this way. See #266 for a better long-term fix. - if getattr(d, "dtype", "int64") != "int64": - assert d.dtype in discrete_dtypes, (node, d.dtype) - assert str(d.dtype) != "uint64", node - new_shape += sh[len(new_shape) : i + 1] - if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") - else: - casted_d = cast(d, "int64") - new_shape[i] = casted_d - if new_shape: - # We replace the shape with wrong dtype by the one with - # 'int64'. - new_shape += sh[len(new_shape) :] - o_shapes[sh_idx] = tuple(new_shape) - - for r, s in zip(node.outputs, o_shapes, strict=True): - self.set_shape(r, s) + Walks ``fgraph.clients`` from each stale node, so a change above a cached node + evicts it no matter how many levels down it sits. + """ + cache = self._cache + clients = self.fgraph.clients + queue = deque(self._stale) + seen = set(self._stale) + self._stale = set() + while queue: + node = queue.popleft() + cache.pop(node, None) + for out in node.outputs: + for client_node, _ in clients.get(out, ()): + if client_node not in seen: + seen.add(client_node) + queue.append(client_node) + + def on_prune(self, fgraph, node, reason): + self._cache.pop(node, None) + self._stale.discard(node) + for out in node.outputs: + self._shape_i_cache.pop(out, None) def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) - - # This tells us that r and new_r must have the same shape if - # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) - - # change_input happens in two cases: - # 1) we are trying to get rid of r, or - # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - # This schedule is processed by `local_track_shape_i`. - for shpnode, idx in fgraph.clients[r] + [(node, i)]: - if isinstance(shpnode.op, Shape_i): - idx = shpnode.op.i - repl = self.shape_of[new_r][idx] - if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. - continue - if ( - repl.owner - and repl.owner.inputs[0] is shpnode.inputs[0] - and isinstance(repl.owner.op, Shape_i) - and repl.owner.op.i == shpnode.op.i - ): - # The replacement is a shape_i of the same - # input. So no need to do this equivalent - # replacement. - continue + if r is new_r: + return + # Defer invalidation: the next query flushes this node and its downstream cone. + self._stale.add(node) - if shpnode.outputs[0] in ancestors([repl]): - raise InconsistencyError( - "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" - ) - - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, - # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] - for k in unscheduled: - del self.scheduled[k] - - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): - # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + # Schedule Shape_i(r) replacements for local_track_shape_i + if isinstance(r.type, HasShape): + for shpnode, _idx in fgraph.clients.get(r, []): + if isinstance(getattr(shpnode, "op", None), Shape_i): + self.scheduled[shpnode] = new_r def same_shape( self, @@ -668,63 +391,27 @@ def same_shape( dim_x: int | None = None, dim_y: int | None = None, ) -> bool: - """Return ``True`` if `x` and `y` have the same shape. - - Parameters - ========== - x - The `Variable` for which its shape is to be compared with `y`'s shape. - y - The `Variable` for which its shape is to be compared with `x`'s shape. - dim_x - If non ``None``, compare only the dimension of `x` equal to - `dim_x`. - dim_y - If non ``None``, compare only the dimension of `y` equal to - `dim_y`. - - """ - sx = self.shape_of[x] - sy = self.shape_of[y] - + """Return True if we can statically prove x and y have the same shape.""" + sx = self.shape_tuple(x) + sy = self.shape_tuple(y) if sx is None or sy is None: return False - if dim_x is not None: - sx = [sx[dim_x]] - + sx = (sx[dim_x],) if dim_y is not None: - sy = [sy[dim_y]] - + sy = (sy[dim_y],) if len(sx) != len(sy): return False - - # Canonicalize the graphs so that comparisons are reasonable - # TODO FIXME: This should *not* need to be performed manually here. - # Instead, the shape information in `self.shape_of` should be operated - # upon alongside all the other elements in a `FunctionGraph` (e.g. as - # if `self.shape_of.values()` were additional outputs). - shapes_fg = FunctionGraph( - outputs=sx + sy, - # features=[self], - clone=True, - # copy_inputs=False, - ) - from pytensor.graph.rewriting.utils import rewrite_graph - - canon_shapes_fg = type_cast( - FunctionGraph, - rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), - ) - canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - for dx, dy in zip(sx, sy, strict=True): - if not equal_computations([dx], [dy]): + if dx is dy: + continue + if isinstance(dx, Constant) and isinstance(dy, Constant): + if dx.data == dy.data: + continue return False - + if equal_computations([dx], [dy]): + continue + return False return True def clone(self): @@ -1302,7 +989,7 @@ def local_shape_to_shape_i(fgraph, node): if not hasattr(fgraph, "shape_feature"): return shape_feature = fgraph.shape_feature - ret = shape_feature.make_vector_shape(node.inputs[0]) + ret = as_tensor_variable(shape_feature.shape_tuple(node.inputs[0]), dtype="int64") # We need to copy over stack trace from input to output copy_stack_trace(node.outputs[0], ret) @@ -1314,44 +1001,40 @@ def local_shape_to_shape_i(fgraph, node): @register_canonicalize @node_rewriter([Shape_i]) def local_track_shape_i(fgraph, node): - """ - Update `Shape_i` nodes to match `ShapeFeature`'s internal state. - - This rewrite is essential for propagating shape information during graph - transformations (like lowering). When a node is replaced or updated, - `ShapeFeature` calculates the shape of the new node and "schedules" - dependent `Shape_i` nodes for update, so they use the latest inferred graph. - - If we start with an fgraph containing the two nodes below: - >> out = OpWithoutInferShape(a, b) - >> out_shape_i = Shape_i(out) + """Replace ``Shape_i(v, i)`` with the inferred shape expression. - And then rewrite - >> new_out = OpWithInferShape(a, b) - >> fgraph.replace(out, new_out) - - We end up with - >> out_shape_i == Shape_i(new_out) + When ``v.owner.op`` has ``infer_shape``, ``get_shape(v, i)`` returns + a non-``Shape_i`` expression. Rewriting the literal ``Shape_i(v, i)`` + with that expression lets downstream rewrites see the inferred form + and typically lets the original producer of ``v`` be pruned when only + its shape is consumed. + """ + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is None: + return False - If installed, ShapeFeature will do this work in the background - >> new_out_shape = infer_shape(new_out) # Usually some f(a, b) - >> fgraph.shape_feature.scheduled[out_shape_i.owner] = new_out_shape + # Handle scheduled replacements from on_change_input + replacement = shape_feature.scheduled.pop(node, None) + if replacement is not None: + return [shape_feature.get_shape(replacement, node.op.i)] - And this rewrite will ultimately propagate the inference back to the fgraph - >> new_out_shape_i = fgraph.shape_feature.scheduled[out_shape_i.owner][i] - >> fgraph.replace(out_shape_i, new_out_shape_i) + [v] = node.inputs + if v.owner is None: + return False - """ - try: - shape_feature = fgraph.shape_feature - except AttributeError: + i = node.op.i + new_shape = shape_feature.get_shape(v, i) + if new_shape is None: return False - if node not in shape_feature.scheduled: + # Avoid replacing Shape_i(v, i) with itself + if new_shape.owner is node or ( + isinstance(new_shape, Variable) + and new_shape.owner is not None + and isinstance(new_shape.owner.op, Shape_i) + and new_shape.owner.op.i == i + and new_shape.owner.inputs[0] is v + ): return False - # Don't unschedule node as it could be reinserted in the - # fgraph as we don't change it in the shapefeature internal - # structure. - replacement = shape_feature.scheduled[node] - return [shape_feature.shape_of[replacement][node.op.i]] + return [new_shape] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index ef51cae306..e045648816 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -883,12 +883,12 @@ def _local_subtensor_merge_rewrite(fgraph, node, *, merge_integer_index): indices_outer = unflatten_index_variables(outer_index_vars, node.op.idx_list) try: - xshape = fgraph.shape_feature.shape_of[x] + xshape = fgraph.shape_feature.shape_tuple(x) except AttributeError: xshape = tuple(x.shape) try: - ushape = fgraph.shape_feature.shape_of[u] + ushape = fgraph.shape_feature.shape_tuple(u) except AttributeError: ushape = tuple(u.shape) @@ -1201,7 +1201,7 @@ def local_useless_subtensor(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature cdata = get_constant_idx( node.op.idx_list, @@ -1223,7 +1223,7 @@ def local_useless_subtensor(fgraph, node): # is not a useless subtensor return False - length_pos = shape_of[node.inputs[0]][pos] + length_pos = shape_feature.get_shape(node.inputs[0], pos) if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize @@ -1327,12 +1327,12 @@ def local_useless_AdvancedSubtensor1(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature # get length of the indexed tensor along the first axis try: length = get_scalar_constant_value( - shape_of[node.inputs[0]][0], only_process_constants=True + shape_feature.get_shape(node.inputs[0], 0), only_process_constants=True ) except NotScalarConstantError: return False @@ -2417,7 +2417,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): # need it for this optimization, so don't continue. return False - shape_of = shape_feature.shape_of same_shape = shape_feature.same_shape # Get the subtensor of `x` indexed by `i` in order to compare @@ -2431,22 +2430,12 @@ def local_useless_inc_subtensor_alloc(fgraph, node): else: raise Exception("Should never happen!") - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - # `xi` may have more dimensions than `y` since the subtensor ops # do automatic broadcasting of the increment internally. Thus, we # need to make the leading implicitly broadcasted dimensions # explicit for shape comparison later. if xi.ndim > y.ndim: y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") # Build `z_broad` explicitly to include extra implicit dimensions. z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable @@ -2479,7 +2468,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): if ( z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 + and shape_feature.get_shape(y, k) != 1 ) ] diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index e170c2c61b..59dc8eb7dc 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1104,12 +1104,6 @@ def local_subtensor_shape_constant(fgraph, node): TensorConstant{1} - TODO: Something like `local_shape_to_shape_i` should be a general - canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were - the case, we could change this to only operate on `Shape_i`\s. - Currently, we're not handling them because they should only appear when - `ShapeFeature` is present, and it will also simplify/remove them. - """ shape = node.inputs[0] @@ -1130,7 +1124,7 @@ def local_subtensor_shape_constant(fgraph, node): return False try: - shape_parts = shape_arg.type.broadcastable[idx_val] + shape_parts = shape_arg.type.shape[idx_val] except IndexError: # An out-of-bounds index here is an error in the source graph # (e.g. ``scalar.shape[0]``), but it should fail at runtime rather @@ -1138,10 +1132,10 @@ def local_subtensor_shape_constant(fgraph, node): return False if isinstance(shape_parts, Iterable): - if all(shape_parts): - return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] - elif shape_parts: - return [as_tensor(1, dtype=np.int64)] + if all(s is not None for s in shape_parts): + return [as_tensor(list(shape_parts), dtype=np.int64, ndim=1)] + elif shape_parts is not None: + return [as_tensor(shape_parts, dtype=np.int64)] @node_rewriter([Subtensor]) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..1b43eaec6e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -81,7 +81,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(np.shape(x), dtype="int64") - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [[len(in_shapes[0])]] def connection_pattern(self, node): @@ -297,7 +297,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [()] def connection_pattern(self, node): @@ -339,21 +339,7 @@ def shape_i(var, i, fgraph=None): """ if fgraph and hasattr(fgraph, "shape_feature"): - shape_feature = fgraph.shape_feature - shape_of = shape_feature.shape_of - - def recur(node): - if node.outputs[0] not in shape_of: - for inp in node.inputs: - if inp.owner: - recur(inp.owner) - # If the output var isn't marked as being in the graph, - # we need to add it in the ShapeFeature. - shape_feature.on_import(fgraph, node, "graph.ops.shape_i") - - if var not in shape_of: - recur(var.owner) - return shape_of[var][i] + return fgraph.shape_feature.get_shape(var, i) # If we are not able to use the shape feature, we should not put # Shape_i in the graph. Otherwise, the shape feature optimization @@ -452,7 +438,7 @@ def perform(self, node, inp, out_): ) out[0] = x - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshape, *_ = shapes shape = node.inputs[1:] # Use x shape if specified dim is None, otherwise the specified shape @@ -727,7 +713,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self(eval_points[0], *inputs[1:], return_list=True) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, maximum, mul # inputs[1] can contain at most one value of '-1', meaning the actual diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index c3133e3c15..51dd796a52 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -82,7 +82,7 @@ def make_node(self, in1, in2, full_mode): out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [in1, in2, full_mode], [out]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): _, _, full_mode = node.inputs in1_shape, in2_shape, _ = shapes out_shape = [ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index af695d9e42..c911be988d 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -54,7 +54,7 @@ def perform(self, node, inputs, output_storage): z = output_storage[0] z[0] = np.sort(a, axis, self.kind) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] @@ -185,7 +185,7 @@ def perform(self, node, inputs, output_storage): dtype=node.outputs[0].dtype, ) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 403331cef6..4638fa3e70 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -922,7 +922,7 @@ def perform(self, node, inputs, out_): cdata = unflatten_index_variables(index_variables, self.idx_list) out[0] = np.asarray(x.__getitem__(tuple(cdata))) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): def _is_constant(const, x): return isinstance(const, Constant) and const.data.item() == x @@ -1767,7 +1767,7 @@ def add_to_zview(self, name, x, fail): {fail}; }}""" - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def pushforward(self, inputs, outputs, eval_points): @@ -1945,7 +1945,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] @@ -2295,7 +2295,7 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = x - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, _y, _ilist = ishapes return [x] @@ -2469,7 +2469,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): def is_bool_index(idx): return ( isinstance(idx, np.bool_ | bool) @@ -2695,7 +2695,7 @@ def perform(self, node, inputs, out_): else: np.add.at(out[0], tuple(full_indices), y) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [ishapes[0]] def connection_pattern(self, node): diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 1a7e681c22..8662faded3 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -83,11 +83,16 @@ def shape_of_variables( shape_feature = fgraph.shape_feature input_dims = [ - dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp] + dimension + for inp in fgraph.inputs + for dimension in shape_feature.shape_tuple(inp) ] output_dims = [ - dimension for shape in shape_feature.shape_of.values() for dimension in shape + dimension + for var in fgraph.variables + if hasattr(var.type, "ndim") + for dimension in shape_feature.shape_tuple(var) ] compute_shapes = pytensor.function(input_dims, output_dims) @@ -105,8 +110,10 @@ def shape_of_variables( sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) l = {} - for var in shape_feature.shape_of: - l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var]) + for var in fgraph.variables: + shape = shape_feature.shape_tuple(var) + if shape is not None: + l[var] = tuple(sym_to_num_dict[sym] for sym in shape) return l diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 3e02c75ce9..09a8d8fe1f 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -30,7 +30,7 @@ class XTypeCastOp(TypeCastingOp): This is like a `ViewOp` but without the expectation the input and output have identical types. """ - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def vectorize_node( diff --git a/tests/benchmarks/test_rewriting.py b/tests/benchmarks/test_rewriting.py index 5b0b86ba6f..7303d230de 100644 --- a/tests/benchmarks/test_rewriting.py +++ b/tests/benchmarks/test_rewriting.py @@ -1,9 +1,13 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt +import pytensor.xtensor as px from pytensor import config from pytensor.graph import FunctionGraph +from pytensor.graph.rewriting import rewrite_graph +from pytensor.xtensor.shape import stack as xstack def _large_fuseable_graph(n): @@ -66,3 +70,50 @@ def rewrite_func(): assert rewrite_func() == expected_n_repl benchmark.pedantic(rewrite_func, rounds=7, iterations=5) + + +def _xtensor_attention_graph(n_layers): + B, T, E, H, HD = 4, 32, 64, 4, 16 + rng = np.random.default_rng(0) + + def attn(x): + Wqkv = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, 3, H, HD))), + dims=("embd", "qkv", "head", "hd"), + ) + Wproj = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, E))), + dims=("embd", "embd_out"), + ) + qkv = px.dot(x, Wqkv, dim="embd") + q = qkv.isel(qkv=0).rename(time="time_q") + k = qkv.isel(qkv=1).rename(time="time_k") + v = qkv.isel(qkv=2).rename(time="time_k") + s = px.dot(q, k, dim="hd") / np.sqrt(HD) + mask = px.as_xtensor( + pt.tril(pt.ones((T, T), dtype="bool")), + dims=("time_q", "time_k"), + ) + a = px.math.softmax(px.where(mask, s, np.float64(-1e9)), dim="time_k") + o = xstack(px.dot(a, v, dim="time_k"), embd=("head", "hd")) + return px.dot(o, Wproj, dim="embd").rename(time_q="time", embd_out="embd") + + x_t = pt.tensor("x", shape=(B, T, E)) + x = px.as_xtensor(x_t, dims=("batch", "time", "embd")) + for _ in range(n_layers): + x = attn(x) + return x_t, x.values.sum() + + +@pytest.mark.parametrize("n_layers", [2, 3, 4]) +def test_xtensor_attention_rewrite_benchmark(n_layers, benchmark): + x_t, loss = _xtensor_attention_graph(n_layers) + + def rewrite_once(): + lowered = rewrite_graph(loss, include=("lower_xtensor",), clone=True) + grad = pt.grad(lowered, x_t) + return rewrite_graph( + [lowered, grad], include=("fast_run",), exclude=("inplace",), clone=True + ) + + benchmark(rewrite_once) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index f8180773cc..c6f1f1e74c 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -460,8 +460,8 @@ def test_infer_shape(self): fg = FunctionGraph(outputs=[op_var[1]], clone=False) opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer()) - assert opt_res.shape_feature.shape_of[x] is None - assert opt_res.shape_feature.shape_of[z][0].data == 2 + assert opt_res.shape_feature.shape_tuple(x) is None + assert opt_res.shape_feature.shape_tuple(z)[0].data == 2 def test_make_node_shared(self): """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`.""" diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index a30ed6475d..1954c6cbb6 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -65,7 +65,7 @@ def test_infer_shape(self): x = dmatrix("x") y = dvector("y") - def infer_shape(fgraph, node, shapes): + def infer_shape(node, shapes): _x, y = shapes return [y] diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 2d33112373..4dad31763e 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -4,7 +4,7 @@ import pytest from pytensor.configdefaults import config -from pytensor.graph.basic import NominalVariable +from pytensor.graph.basic import NominalVariable, equal_computations from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph, Output from pytensor.graph.utils import MissingInputError from pytensor.printing import debugprint @@ -988,3 +988,77 @@ def test_value_dependent_output_type_collision(self): # ``bind`` reproduces each graph's own output type, not a collided one assert ffg1.bind([x, s1])[0].type.shape == (2, 3) assert ffg2.bind([x, s2])[0].type.shape == (3, 2) + + def test_bind_constant_output(self): + """bind must handle constants that appear directly as outputs.""" + x = float64("x") + c = ScalarConstant(float64, 42.0) + ffg = FunctionGraph([x], [add(x, c), c]).freeze() + + y = float64("y") + bound = ffg.bind({ffg.inputs[0]: y}) + assert len(bound) == 2 + assert bound[1] is c + + def test_from_structural_inputs_only_root_inputs(self): + """All inputs are roots: behaves like the plain constructor.""" + x, y = float64("x"), float64("y") + out = add(x, y) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y], [out]) + assert len(ffg.inputs) == 2 + + a, b = float64("a"), float64("b") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b], strict=True))) + assert equal_computations([res], [add(a, b)]) + + def test_from_structural_inputs_only_intermediate_inputs(self): + """Inputs may be only intermediate expressions; roots are found automatically.""" + x, y = float64("x"), float64("y") + # out depends on x, y only through the product. + out = add(mul(x, y), mul(x, y)) + + # The passed expression is matched by structure, not identity. + prod = mul(x, y) + assert prod is not out.owner.inputs[0] + + ffg = FrozenFunctionGraph.from_structural_inputs([prod], [out]) + assert len(ffg.inputs) == 1 + + p = float64("p") + [res] = ffg.bind({ffg.inputs[0]: p}) + # Both occurrences rewire to the single input. + assert equal_computations([res], [add(p, p)]) + + def test_from_structural_inputs_mixed_inputs(self): + """A root input and an intermediate input, both live.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, mul(x, y)], [out]) + assert len(ffg.inputs) == 2 + + a, p = float64("a"), float64("p") + # x is used directly; root y is dropped (it feeds only the lifted product). + [res] = ffg.bind(dict(zip(ffg.inputs, [a, p], strict=True))) + assert equal_computations([res], [add(p, a)]) + + def test_from_structural_inputs_dead_inputs(self): + """A dead root input and a dead intermediate input are retained but ignored.""" + x, y = float64("x"), float64("y") + out = add(x, x) # uses neither y nor the product + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y, mul(x, y)], [out]) + assert len(ffg.inputs) == 3 + + a, b, p = float64("a"), float64("b"), float64("p") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b, p], strict=True))) + assert equal_computations([res], [add(a, a)]) + + def test_from_structural_inputs_unreachable_output_raises(self): + """Outputs needing a root absent from the inputs cannot be expressed.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) # needs x directly, not only via the product + + with pytest.raises(ValueError): + FrozenFunctionGraph.from_structural_inputs([mul(x, y)], [out]) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 741d5f01a5..b3c663b5f0 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -310,7 +310,7 @@ def grad(self, inputs, gout): else: return (gz,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def test_grad_fail(self): diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..96cecdc333 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -284,7 +284,7 @@ def test_normal_ShapeFeature(): clone=False, features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt, sd_pt], [s1, s2, d_rv], mode=py_mode, on_unused_input="ignore") s1_val, s2_val, d_rv_val = f(3, np.array(1.0, dtype=config.floatX)) @@ -657,7 +657,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt], [s1, s2], mode=py_mode) s1_val, s2_val = f(2) @@ -679,7 +679,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2, s3, s4 = fg.shape_feature.shape_of[d_rv] + s1, s2, s3, s4 = fg.shape_feature.shape_tuple(d_rv) mean_val = np.array([[0, 1, 2]], dtype=config.floatX) f = function([mean, cov], [s1, s2, s3, s4], mode=py_mode, on_unused_input="ignore") @@ -810,7 +810,7 @@ def test_dirichlet_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) assert M_pt in graph_inputs([s1]) assert N_pt in graph_inputs([s2]) diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py index 4e40507158..c511e21f66 100644 --- a/tests/tensor/rewriting/test_blockwise.py +++ b/tests/tensor/rewriting/test_blockwise.py @@ -9,15 +9,11 @@ from pytensor.graph.traversal import apply_ancestors from pytensor.scalar import log as scalar_log from pytensor.tensor import add, alloc, iscalar, matrix, scalar, tensor, tensor3 -from pytensor.tensor.basic import MakeVector, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.linalg.inverse import MatrixPinv -from pytensor.tensor.math import add as tensor_add -from pytensor.tensor.math import maximum, minimum from pytensor.tensor.rewriting.blockwise import local_useless_blockwise -from pytensor.tensor.shape import Reshape, Shape_i -from pytensor.tensor.signal.conv import Convolve1d +from pytensor.tensor.shape import Reshape def test_useless_blockwise_of_elemwise(): @@ -190,47 +186,3 @@ def test_blockwise_reshape(): new_y.eval({"x": test_x}, mode=no_rewrites), rewritten_y.eval({"x": test_x}, mode=no_rewrites), ) - - -@pytest.mark.parametrize( - "mode, x_shape, k_shape", - [ - ("valid", (3, 10), (3, 4)), - ("full", (3, 10), (3, 4)), - ("valid", (None, None), (None, None)), - ("full", (None, None), (None, None)), - ], -) -def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): - x = tensor("x", shape=x_shape) - k = tensor("k", shape=k_shape) - out = Blockwise(Convolve1d())( - x, k, constant(mode == "full", dtype="bool"), return_list=True - ) - fn = function([x, k], out, mode="NUMBA") - - [bwcs_node] = [ - n - for n in fn.maker.fgraph.apply_nodes - if isinstance(n.op, BlockwiseWithCoreShape) - ] - core_shape = bwcs_node.inputs[-1] - - static = all(s is not None for s in x_shape + k_shape) - if static: - n, kk = x_shape[1], k_shape[1] - expected_len = (n + kk - 1) if mode == "full" else (n - kk + 1) - expected = constant(np.array([expected_len])) - assert equal_computations([core_shape], [expected]) - else: - n = Shape_i(1)(x) - kk = Shape_i(1)(k) - if mode == "full": - expected = MakeVector("int64")( - tensor_add(constant(-1, dtype="int64"), n, kk) - ) - else: - expected = MakeVector("int64")( - constant(1, dtype="int64") + maximum(n, kk) - minimum(n, kk) - ) - assert equal_computations([core_shape], [expected], in_xs=[x, k], in_ys=[x, k]) diff --git a/tests/tensor/rewriting/test_numba.py b/tests/tensor/rewriting/test_numba.py new file mode 100644 index 0000000000..b210470ce0 --- /dev/null +++ b/tests/tensor/rewriting/test_numba.py @@ -0,0 +1,120 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import function +from pytensor.compile import optdb +from pytensor.compile.aliasing import add_supervisor_to_fgraph +from pytensor.compile.io import In +from pytensor.compile.mode import get_mode +from pytensor.graph.basic import equal_computations +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.basic import MakeVector, alloc, constant +from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape +from pytensor.tensor.math import add as tensor_add +from pytensor.tensor.math import maximum, minimum +from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.rewriting.shape import ShapeFeature +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.signal import convolve1d +from pytensor.tensor.signal.conv import Convolve1d + + +def count_ops(fgraph, op_type): + return sum(isinstance(node.op, op_type) for node in fgraph.apply_nodes) + + +def rewrite_for_numba(inputs, outputs): + fg = FunctionGraph(inputs, outputs) + add_supervisor_to_fgraph(fg, [In(inp) for inp in fg.inputs]) + get_mode("NUMBA").optimizer.rewrite(fg) + return fg + + +def core_shape_of(fgraph): + [bwcs_node] = [ + n for n in fgraph.apply_nodes if isinstance(n.op, BlockwiseWithCoreShape) + ] + *functional_inputs, core_shape = bwcs_node.inputs + return core_shape, functional_inputs + + +@pytest.mark.parametrize( + "mode, x_shape, k_shape", + [ + ("valid", (3, 10), (3, 4)), + ("full", (3, 10), (3, 4)), + ("valid", (None, None), (None, None)), + ("full", (None, None), (None, None)), + ], +) +def test_blockwise_core_shape_simplified(mode, x_shape, k_shape): + x = pt.tensor("x", shape=x_shape) + k = pt.tensor("k", shape=k_shape) + out = Blockwise(Convolve1d())( + x, k, constant(mode == "full", dtype="bool"), return_list=True + ) + fn = function([x, k], out, mode="NUMBA") + + core_shape, _ = core_shape_of(fn.maker.fgraph) + + static = all(s is not None for s in x_shape + k_shape) + if static: + n, kk = x_shape[1], k_shape[1] + expected_len = (n + kk - 1) if mode == "full" else (n - kk + 1) + expected = constant(np.array([expected_len])) + assert equal_computations([core_shape], [expected]) + else: + n = Shape_i(1)(x) + kk = Shape_i(1)(k) + if mode == "full": + expected = MakeVector("int64")( + tensor_add(constant(-1, dtype="int64"), n, kk) + ) + else: + expected = MakeVector("int64")( + constant(1, dtype="int64") + maximum(n, kk) - minimum(n, kk) + ) + assert equal_computations([core_shape], [expected], in_xs=[x, k], in_ys=[x, k]) + + +def test_introduce_core_shape_aliasing(): + """Graphs whose shape arithmetic gets inplaced, destroying variables that + recursive core shape derivations used to read; they must simply lower. + """ + larger = pt.matrix("larger", shape=(8, None)) + smaller = pt.matrix("smaller", shape=(8, None)) + a = alloc(pt.zeros((1, 1)), 1, larger.shape[1] + smaller.shape[1] - 1) + out = convolve1d(a, larger[:, ::-1], mode="full") + + fg = rewrite_for_numba([larger, smaller], [out]) + assert count_ops(fg, BlockwiseWithCoreShape) == 1 + fg.toposort() + + # Crossed variant, where each core shape mixes dimensions of both inputs + x1 = pt.matrix("x1", shape=(8, None)) + x2 = pt.matrix("x2", shape=(8, None)) + a1 = alloc(pt.zeros((1, 1)), 1, x1.shape[1] + 3) + a2 = alloc(pt.zeros((1, 1)), 1, x2.shape[1] + 5) + convA = convolve1d(a1, x2[:, ::-1], mode="full") + convB = convolve1d(a2, x1[:, ::-1], mode="full") + + fg = rewrite_for_numba([x1, x2], [convA, convB]) + assert count_ops(fg, BlockwiseWithCoreShape) == 2 + fg.toposort() + + +def test_core_shape_simplify_keeps_fgraph_intact(): + """simplify_core_shape_graphs must not rewrite the fgraph's own applies, + like the non-canonical chain feeding the alloc dim here. + """ + x = pt.matrix("x", shape=(8, None)) + a = alloc(pt.zeros((1, 1)), 1, (x.shape[1] + 1) - 1) + out = convolve1d(a, x[:, ::-1], mode="full") + + fg = FunctionGraph([x], [out], clone=False) + fg.attach_feature(ShapeFeature()) + InplaceElemwiseOptimizer().rewrite(fg) + assert any(node.op.destroy_map for node in fg.apply_nodes) + optdb.query("+introduce_explicit_core_shape_blockwise").rewrite(fg) + fg.toposort() diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 9ab4e591f1..1b8696dd66 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -17,7 +17,7 @@ from pytensor.graph.type import Type from pytensor.tensor.basic import alloc, as_tensor_variable from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import add, exp, maximum +from pytensor.tensor.math import add, cos, exp, maximum, sin from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.shape import ( ShapeFeature, @@ -162,7 +162,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - # def infer_shape(self, fgraph, node, (xshp,)): + # def infer_shape(self, node, (xshp,)): # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] identity_noshape = IdentityNoShape() @@ -179,7 +179,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - def infer_shape(self, fgraph, node, xshp_): + def infer_shape(self, node, xshp_): # Could also just return. (xshp,) = xshp_ return (xshp,) @@ -613,6 +613,30 @@ def test_vector_dim_err(self): shape_feature.same_shape(x, o, 0, 1) +def test_get_shape_resolves_through_chain(): + """get_shape should resolve to the deepest input, not intermediate ops.""" + x = matrix("x") + w = matrix("w") + inner = cos(x) + y = sin(exp(inner.T)) + + fg = FunctionGraph([x, w], [y], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + + s = sf.get_shape(y, 0) + utt.assert_equal_computations([s], [Shape_i(1)(x)]) + + # Changing an input invalidates the cached shape of the changed node and of + # every node downstream of it, not just the node whose input changed. Here + # ``inner`` feeds the transpose, with ``exp`` and ``sin`` further downstream, + # so the re-query must resolve through ``w`` rather than returning the stale + # ``x``-based shape held by the downstream nodes' caches. + fg.replace(inner, cos(w)) + s = sf.get_shape(y, 0) + utt.assert_equal_computations([s], [Shape_i(1)(w)]) + + def test_useless_specify_shape(): x = tensor("x", shape=(None, 5, 3)) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index eb0a34f0e2..4f7503d1d7 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -204,11 +204,12 @@ def test_local_useless_inc_subtensor_no_opt(): result.assert_eval([[1, 2], [3, 4]], [[10, 20], [30, 40]]) # Increment with a non-zero constant target array, same collapse to x + y. + # ``ones`` has a static shape, so ``ones.shape`` folds to the constants (2, 2). ones = pt.ones((2, 2)) s = ones[:, :] o_shape = inc_subtensor(s, specify_shape(y, s.shape)) result = utt.RewriteTester([y], [o_shape]) - result.assert_graph(ones + specify_shape(y, ones.shape)) + result.assert_graph(ones + specify_shape(y, (2, 2))) result.assert_eval([[10, 20], [30, 40]]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 56dd04a2b9..b245f53237 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -1060,6 +1060,17 @@ def test_local_subtensor_shape_constant(): assert isinstance(res, Constant) assert np.array_equal(res.data, [1, 1]) + # Any static dim folds, not just broadcastable ones + x = tensor(dtype=np.float64, shape=(7, None)).shape[0] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert res.data == 7 + + x = _shape(tensor(dtype=np.float64, shape=(None, 3, 7)))[1:] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert np.array_equal(res.data, [3, 7]) + @pytest.mark.parametrize( "original_fn, supported", diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index dfded3fbc3..c4023b3ea5 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -310,7 +310,7 @@ def perform(self, node, inputs, outputs): c[0] = np.arange(a.size + b.size, dtype=config.floatX) d[0] = np.arange(a.sum() + b.sum(), dtype=config.floatX) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # First output shape depends only on input_shapes # Second output shape depends on input values a_identity, b_identity = node.inputs @@ -362,7 +362,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): y = node.outputs[0] # Apparently it's valid to return integers in infer_shape. # DimShuffle does this. Modify test if that is no longer allowed. diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 913a1036ff..6b72864763 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -873,7 +873,7 @@ def test_partial_static_shape_info(self): x_inferred_shape = (ps.constant(1), ps.constant(1)) res_shape = z.owner.op.infer_shape( - None, z.owner, [x_inferred_shape, x_inferred_shape] + z.owner, [x_inferred_shape, x_inferred_shape] ) assert len(res_shape) == 1 @@ -902,7 +902,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), ) in_1_shape = (ps.constant(1), ps.constant(1)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_1_shape]) for out in outs: assert out[0].eval() == 1 assert out[1].eval() == 1 @@ -911,7 +911,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3)) ) in_2_shape = (ps.constant(3), ps.constant(3)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_2_shape]) for out in outs: assert out[0].eval() == 3 assert out[1].eval() == 3 @@ -924,7 +924,7 @@ def test_shape_types(self): assert isinstance(z.owner.op, Elemwise) - (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) + (out_shape,) = z.owner.op.infer_shape(z.owner, [(lscalar(), 1), (50, 10)]) assert all(isinstance(v.type, TensorType) for v in out_shape) diff --git a/tests/xtensor/test_rewriting.py b/tests/xtensor/test_rewriting.py index da076b1824..2da37fa919 100644 --- a/tests/xtensor/test_rewriting.py +++ b/tests/xtensor/test_rewriting.py @@ -17,8 +17,21 @@ def test_infer_shape_db_handles_xtensor_lowering(): [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [(x.values.sum(0)).shape[0]]) - # With ShapeFeature - fgraph = FunctionGraph([x], [shape_y], features=[ShapeFeature()], copy_inputs=False) + # With ShapeFeature — force caching shape of XRV output before lowering + sf = ShapeFeature() + fgraph = FunctionGraph([x], [shape_y], features=[sf], copy_inputs=False) + # Force get_shape on the XRV sum output (y) before any rewriting lowers it. + # This caches a shape expression referencing the XRV variable. + y_in_graph = [ + v + for v in fgraph.variables + if hasattr(v.type, "ndim") and v.type.ndim == 1 and v is not x + ] + for v in y_in_graph: + try: + sf.get_shape(v, 0) + except Exception: + pass infer_shape_db.default_query.rewrite(fgraph) [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [Shape_i(1)(x)])