Skip to content
Merged
2 changes: 1 addition & 1 deletion pytensor/assumptions/specify.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make_node(self, x):
out = x.type()
return Apply(self, [x], [out])

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes

def pullback(
Expand Down
2 changes: 1 addition & 1 deletion pytensor/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage):
def pullback(self, inputs, outputs, output_gradients):
return [disconnected_type(), *output_gradients]

def infer_shape(self, fgraph, inputs, input_shapes):
def infer_shape(self, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
return input_shapes[1:]

Expand Down
145 changes: 73 additions & 72 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from itertools import chain
from typing import cast

from pytensor.compile.maker import function
Expand All @@ -24,70 +23,51 @@
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.shape import Shape_i


def infer_shape(outs, inputs, input_shapes):
"""
Compute the shape of the outputs given the shape of the inputs of an PyTensor
graph.

We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.

"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually

# TODO: ShapeFeature should live elsewhere
"""Compute the shape of the outputs given the shape of the inputs of a PyTensor graph."""
from pytensor.tensor.rewriting.shape import ShapeFeature

for inp, inp_shp in zip(inputs, input_shapes, strict=True):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
fgraph = FunctionGraph([], [], features=[shape_feature])
for v in chain.from_iterable(s for s in input_shapes if s is not None):
# Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
if (node := v.owner) is not None:
fgraph.import_node(node, import_missing=True)

# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp, override=True)

def local_traverse(out):
"""
Go back in the graph, from out, adding computable shapes to shape_of.

"""
if out in shape_feature.shape_of:
# Its shape is already known
return
elif out.owner is None:
# This is an input of the graph
shape_feature.init_r(out)
else:
# Recurse over inputs
for inp in out.owner.inputs:
if inp not in shape_feature.shape_of:
local_traverse(inp)
output_shapes = [shape_feature.shape_tuple(o) for o in outs]

# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
# Shape expressions for root inputs are Shape_i(inp, i).
# Replace those with the caller-provided input_shapes.
replacements = {}
for inp, shp in zip(inputs, input_shapes, strict=True):
if shp is None:
continue
per_dim = shape_feature._shape_i_cache.get(inp)
if per_dim is None:
continue
for i, s in enumerate(shp):
cached = per_dim.get(i)
if cached is not None:
replacements[cached] = s

if replacements:
flat = [s for tup in output_shapes if tup is not None for s in tup]
flat_replaced = graph_replace(flat, replacements, strict=False)
result = []
idx = 0
for tup in output_shapes:
if tup is None:
result.append(None)
else:
result.append(tuple(flat_replaced[idx : idx + len(tup)]))
idx += len(tup)
return result

ret = []
for o in outs:
local_traverse(o)
ret.append(shape_feature.shape_of[o])
return ret
return output_shapes


def construct_nominal_fgraph(
Expand Down Expand Up @@ -885,30 +865,51 @@ def connection_pattern(self, node):
self._connection_pattern = ret
return ret

def infer_shape(self, fgraph, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)

# Clone the output shape so that shape are computed from outer inputs.
# Note:
# Here we could do it more simply like:
# `ret = [pytensor.clone_replace(shp, replace=repl) for shp in out_shp]`
# But doing it multiple time could duplicate common subgraph between
# each shape call. PyTensor optimizer will clean this up later, but this
# will make extra work for the optimizer.

repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
def infer_shape(self, node, shapes):
try:
template = self._inner_shape_template
frozen = self._inner_shape_frozen
except AttributeError:
from pytensor.tensor.rewriting.shape import ShapeFeature

sf = ShapeFeature()
inner_inputs = self.inner_inputs
template = [sf.shape_tuple(o) for o in self.inner_outputs]
flat_shapes = [s for tup in template if tup is not None for s in tup]

# Express the inner-output shapes as a frozen function of the inner
# inputs plus each input's per-dim size. from_structural_inputs rewires
# every Shape_i(inner_input, dim) occurrence to the matching input, so
# bind can later swap in the caller's shapes. One slot per input dim:
# static or unused dims become dead inputs, keeping the layout positional.
shape_inputs = [
Shape_i(dim)(inp)
for inp in inner_inputs
for dim in range(getattr(inp.type, "ndim", 0))
]
frozen = FrozenFunctionGraph.from_structural_inputs(
[*inner_inputs, *shape_inputs], flat_shapes
)
self._inner_shape_template = template
self._inner_shape_frozen = frozen

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh actually we do freeze already, why not use this for equality checking above?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you meant? What equality checking above? This is freezing the shape graph, we never did this anywhere else?


# frozen.inputs is [*inner_inputs, *per-dim sizes]; mirror that layout.
replacements = list(node.inputs)
for shp in shapes:
if shp is not None:
replacements.extend(shp)

bound_shapes = frozen.bind(replacements)

ret = []
used = 0
for i, out_shape in enumerate(out_shapes):
if out_shape is None:
idx = 0
for tup in template:
if tup is None:
ret.append(None)
else:
nb = len(out_shape)
ret.append(cloned[used : used + nb])
used += nb
nb = len(tup)
ret.append(bound_shapes[idx : idx + nb])
idx += nb

return ret

Expand Down
10 changes: 5 additions & 5 deletions pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp):
def make_node(self, x):
return Apply(self, [x], [x.type()])

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes

def pullback(self, args, outputs, g_outs):
Expand Down Expand Up @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub):
# Else, no C code
raise NotImplementedError()

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes


Expand Down Expand Up @@ -251,8 +251,8 @@ def __reduce__(self):
)
return load_back, (mod, name)

def _infer_shape(self, fgraph, node, input_shapes):
return self.__infer_shape(fgraph, node, input_shapes)
def _infer_shape(self, node, input_shapes):
return self.__infer_shape(node, input_shapes)


def as_op(itypes, otypes, infer_shape=None):
Expand All @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None):
It takes an optional infer_shape parameter that should be a callable with
this signature:

def infer_shape(fgraph, node, input_shapes):
def infer_shape(node, input_shapes):
...
return output_shapes

Expand Down
33 changes: 32 additions & 1 deletion pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,37 @@ def _resolve_input(inp, memo=memo):
self._variables: frozenset[Variable] | None = None
self._clients: dict[Variable, list[ClientType]] | None = None

@classmethod
def from_structural_inputs(
    cls,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
) -> "FrozenFunctionGraph":
    """Build a frozen graph whose inputs may be *interior* expressions.

    This is the structural counterpart of `bind`: instead of substituting
    values for inputs, selected sub-expressions are promoted to inputs.

    The construction happens in two stages:

    1. A staging graph is frozen from the non-constant roots reachable from
       ``inputs`` and ``outputs``, with ``[*inputs, *outputs]`` as its
       outputs. Freezing both together is what lets an input produced by an
       `Apply` be interned onto the same object as its structurally equal
       occurrences inside ``outputs``.
    2. The staged outputs are split back into the leading ``len(inputs)``
       entries, which become the inputs of the returned graph, and the
       remainder, which become its outputs.

    Parameters
    ----------
    inputs
        Variables to expose as inputs, in signature order. They may be root
        variables or the results of an `Apply`. An intermediate input that
        does not occur in ``outputs`` ends up as a dead input.
    outputs
        Variables to freeze. They must be computable from ``inputs`` alone:
        any root they still depend on directly has to be listed in
        ``inputs``, otherwise the rewired graph is orphaned.

    Returns
    -------
    FrozenFunctionGraph
        A frozen graph whose inputs follow the order of ``inputs``.
    """
    # Only true (non-constant) roots may seed the staging freeze: intermediate
    # inputs have to be *built* during that freeze so they get interned
    # together with their occurrences in the outputs.
    non_constant_roots = []
    for var in graph_inputs([*inputs, *outputs]):
        if not isinstance(var, Constant):
            non_constant_roots.append(var)

    staged = cls(non_constant_roots, [*inputs, *outputs])

    # The staged outputs keep the [*inputs, *outputs] layout, so a positional
    # split recovers the interned inputs and the rewired outputs.
    split = len(inputs)
    staged_outputs = staged.outputs
    return cls(staged_outputs[:split], staged_outputs[split:])

def __reduce__(self):
return FrozenFunctionGraph, (self.inputs, self.outputs)

Expand Down Expand Up @@ -1099,7 +1130,7 @@ def bind(
[o.type() for o in node.outputs],
)
memo.update(zip(node.outputs, new_node.outputs))
return [memo[out] for out in self.outputs]
return [out if isinstance(out, Constant) else memo[out] for out in self.outputs]

def unfreeze(self) -> "FunctionGraph":
"""Return a mutable FunctionGraph with fresh mutable Apply nodes."""
Expand Down
19 changes: 19 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -120,6 +121,24 @@ class Op(MetaObject):
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""

def __init_subclass__(cls, **kwargs):
    """Adapt subclasses whose ``infer_shape`` still takes the legacy ``fgraph`` argument.

    Only an ``infer_shape`` defined directly on ``cls`` (i.e. present in
    ``cls.__dict__``) is inspected. When its signature has exactly four
    parameters it is treated as the legacy
    ``(self, fgraph, node, input_shapes)`` form: a `DeprecationWarning` is
    emitted and ``cls.infer_shape`` is replaced by a three-argument shim that
    forwards to the original method with ``None`` in place of ``fgraph``.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to the parent ``__init_subclass__``.
    """
    super().__init_subclass__(**kwargs)
    method = cls.__dict__.get("infer_shape")
    if method is None:
        return
    if len(inspect.signature(method).parameters) != 4:
        return

    warnings.warn(
        f"{cls.__module__}.{cls.__qualname__}.infer_shape takes a "
        "deprecated `fgraph` parameter; drop it from the signature. "
        "The parameter will be passed as None.",
        DeprecationWarning,
        stacklevel=2,
    )

    # The legacy method is captured by closure rather than through a default
    # argument, so the shim exposes exactly ``(self, node, input_shapes)``.
    # A shim with a fourth (default) parameter would itself satisfy the
    # four-parameter legacy test above and could be wrapped a second time if
    # a subclass re-exposes it in its own class body.
    def infer_shape(self, node, input_shapes):
        return method(self, None, node, input_shapes)

    # Keep the user's documentation and a meaningful name on the shim.
    infer_shape.__doc__ = method.__doc__
    infer_shape.__qualname__ = f"{cls.__qualname__}.infer_shape"
    cls.infer_shape = infer_shape

def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represents the application of this operation to the given inputs.

Expand Down
47 changes: 45 additions & 2 deletions pytensor/graph/rewriting/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Optional
from collections.abc import Generator, Iterable, Sequence
from typing import TYPE_CHECKING, Optional, cast

import pytensor
from pytensor.graph.basic import (
Expand Down Expand Up @@ -62,6 +62,49 @@ def rewrite_graph(
return fgraph.outputs


def rewrite_subgraph(
    outputs: Sequence[Variable],
    frontier: Iterable[Variable],
    include: Sequence[str] = ("canonicalize",),
    **kwargs,
) -> list[Variable]:
    """Rewrite, in isolation, the subgraph lying between ``frontier`` and ``outputs``.

    For the duration of the rewrite every ``frontier`` variable has its
    ``owner`` set to ``None``, so it looks like an input of the subgraph.
    Rewrites therefore cannot walk past the frontier nor alter the graph the
    frontier variables belong to. This makes it possible to simplify new
    expressions built on top of the variables of an existing `FunctionGraph`
    without mutating that graph behind its (and its features') back.

    The rewrite is done in place (``clone=False``), so ``outputs`` must not
    belong to a `FunctionGraph`.

    Parameters
    ----------
    outputs
        The outputs of the subgraph to rewrite.
    frontier
        Variables at which the subgraph stops; every path from ``outputs``
        into the surrounding graph must pass through one of them.
    include
        Rewrite query names, as in `rewrite_graph`.
    **kwargs
        Extra keyword arguments forwarded to `rewrite_graph`.

    Returns
    -------
    list of Variable
        The rewritten outputs.
    """
    # Record the complete original state first, and only then detach, so a
    # variable listed more than once is still restored to its real owner.
    snapshot = []
    for var in frontier:
        snapshot.append((var, var.owner, var.index))
    for var, _owner, _index in snapshot:
        var.owner = None

    try:
        result = rewrite_graph(
            list(outputs), include=include, clone=False, **kwargs
        )
        return list(cast(Sequence[Variable], result))
    finally:
        # Reattach the frontier whether or not the rewrite succeeded.
        for var, original_owner, original_index in snapshot:
            var.owner = original_owner
            var.index = original_index


def is_same_graph_with_merge(
var1: Variable,
var2: Variable,
Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __str__(self):
args.append("inplace")
return f"if{{{','.join(args)}}}"

def infer_shape(self, fgraph, node, inputs_shapes):
def infer_shape(self, node, inputs_shapes):
# By construction, corresponding then/else pairs have the same number
# of dimensions

Expand Down
2 changes: 1 addition & 1 deletion pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props):
def c_code_cache_version(self):
return (2,)

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]

def do_constant_folding(self, fgraph, node):
Expand Down
Loading
Loading