Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import copy
from functools import partial
from itertools import chain
from typing import Union, cast
from typing import Generic, Union, cast

from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
Expand All @@ -19,7 +19,13 @@
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.op import (
HasInnerGraph,
Op,
OpDefaultOutputType,
OpOutputsType,
io_connection_pattern,
)
from pytensor.graph.replace import clone_replace
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.utils import MissingInputError
Expand Down Expand Up @@ -154,7 +160,7 @@ def construct_nominal_fgraph(
return fgraph, implicit_shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
class OpFromGraph(Op, HasInnerGraph, Generic[OpOutputsType, OpDefaultOutputType]):
r"""
This creates an `Op` from inputs and outputs lists of variables.
The signature is similar to :func:`pytensor.function <pytensor.function>`
Expand Down Expand Up @@ -253,7 +259,7 @@ def rescale_dy(inps, outputs, out_grads):
def __init__(
self,
inputs: list[Variable],
outputs: list[Variable],
outputs: OpOutputsType,
*,
inline: bool = False,
lop_overrides: Union[Callable, "OpFromGraph", None] = None,
Expand Down Expand Up @@ -713,18 +719,27 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
self._rop_op_cache = wrapper
return wrapper

def L_op(self, inputs, outputs, output_grads):
def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[OpDefaultOutputType],
output_grads: Sequence[OpDefaultOutputType],
) -> list[OpDefaultOutputType]:
disconnected_output_grads = tuple(
isinstance(og.type, DisconnectedType) for og in output_grads
)
lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
return lop_op(*inputs, *outputs, *output_grads, return_list=True)

def R_op(self, inputs, eval_points):
def R_op(
self,
inputs: Sequence[Variable],
eval_points: OpDefaultOutputType | list[OpDefaultOutputType],
) -> list[OpDefaultOutputType]:
rop_op = self._build_and_cache_rop_op()
return rop_op(*inputs, *eval_points, return_list=True)

def __call__(self, *inputs, **kwargs):
def __call__(self, *inputs, **kwargs) -> OpOutputsType | OpOutputsType:
# The user interface doesn't expect the shared variable inputs of the
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
# dispatches to `Op.make_node`), we need to compensate here
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pickle
import warnings

from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.type import CType
Expand Down Expand Up @@ -221,7 +221,7 @@ def load_back(mod, name):
return obj


class FromFunctionOp(Op):
class FromFunctionOp(Op[tuple[Variable, ...], Variable]):
"""
Build a basic PyTensor Op around a function.

Expand Down
35 changes: 24 additions & 11 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True)
_TypeType = TypeVar("_TypeType", bound="Type")
_IdType = TypeVar("_IdType", bound=Hashable)
ApplyOutputsType = TypeVar("ApplyOutputsType", bound=tuple["Variable", ...])
ApplyDefaultOutputType = TypeVar("ApplyDefaultOutputType", bound="Variable")

_MOVED_FUNCTIONS = {
"walk",
Expand Down Expand Up @@ -106,7 +108,7 @@ def dprint(self, **kwargs):
return debugprint(self, **kwargs)


class Apply(Node, Generic[OpType]):
class Apply(Node, Generic[OpType, ApplyOutputsType, ApplyDefaultOutputType]):
"""A `Node` representing the application of an operation to inputs.

Basically, an `Apply` instance is an object that represents the
Expand Down Expand Up @@ -145,7 +147,7 @@ def __init__(
self,
op: OpType,
inputs: Sequence["Variable"],
outputs: Sequence["Variable"],
outputs: ApplyOutputsType,
):
if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a sequence type")
Expand All @@ -165,7 +167,8 @@ def __init__(
raise TypeError(
f"The 'inputs' argument to Apply must contain Variable instances, not {input}"
)
self.outputs: list[Variable] = []
self.outputs: ApplyOutputsType
_outputs: list[Any] = []
# filter outputs to make sure each element is a Variable
for i, output in enumerate(outputs):
if isinstance(output, Variable):
Expand All @@ -176,11 +179,17 @@ def __init__(
raise ValueError(
"All output variables passed to Apply must belong to it."
)
self.outputs.append(output)
_outputs.append(output)
else:
raise TypeError(
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)
# The _outputs will be a list of Variables and we cannot type hint each separately.
# We could use cast(ApplyOutputsType, tuple(_outputs)) to attach the type hint
# information for each output entry, but that could introduce a call overhead
# to cast.
# Instead, we will just ignore the type in this assignment
self.outputs = tuple(_outputs) # type: ignore

def __getstate__(self):
d = self.__dict__
Expand All @@ -192,7 +201,7 @@ def __getstate__(self):
d["tag"] = t
return d

def default_output(self):
def default_output(self) -> ApplyDefaultOutputType:
"""
Returns the default output for this node.

Expand All @@ -210,12 +219,12 @@ def default_output(self):
do = getattr(self.op, "default_output", None)
if do is None:
if len(self.outputs) == 1:
return self.outputs[0]
return cast(ApplyDefaultOutputType, self.outputs[0])
else:
raise ValueError(
f"Multi-output Op {self.op} default_output not specified"
)
return self.outputs[do]
return cast(ApplyDefaultOutputType, self.outputs[do])

def __str__(self):
# FIXME: The called function is too complicated for this simple use case.
Expand All @@ -224,7 +233,9 @@ def __str__(self):
def __repr__(self):
return str(self)

def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]":
def clone(
self, clone_inner_graph: bool = False
) -> "Apply[OpType, ApplyOutputsType, ApplyDefaultOutputType]":
r"""Clone this `Apply` instance.

Parameters
Expand All @@ -249,14 +260,16 @@ def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]":
new_op = new_op.clone() # type: ignore

cp = self.__class__(
new_op, self.inputs, [output.clone() for output in self.outputs]
new_op,
self.inputs,
cast(ApplyOutputsType, tuple([output.clone() for output in self.outputs])),
)
cp.tag = copy(self.tag)
return cp

def clone_with_new_inputs(
self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False
) -> "Apply[OpType]":
) -> "Apply[OpType, ApplyOutputsType, ApplyDefaultOutputType]":
r"""Duplicate this `Apply` instance in a new graph.

Parameters
Expand Down Expand Up @@ -324,7 +337,7 @@ def get_parents(self):
return list(self.inputs)

@property
def out(self):
def out(self) -> ApplyDefaultOutputType:
"""An alias for `self.default_output`"""
return self.default_output()

Expand Down
46 changes: 29 additions & 17 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
Self,
TypeVar,
cast,
)
Expand Down Expand Up @@ -48,7 +50,11 @@ def is_thunk_type(thunk: ThunkCallableType) -> ThunkType:
return res


class Op(MetaObject):
OpOutputsType = TypeVar("OpOutputsType", bound=tuple[Variable, ...])
OpDefaultOutputType = TypeVar("OpDefaultOutputType", bound=Variable)


class Op(MetaObject, Generic[OpOutputsType, OpDefaultOutputType]):
"""A class that models and constructs operations in a graph.

A `Op` instance has several responsibilities:
Expand Down Expand Up @@ -119,7 +125,9 @@ class Op(MetaObject):
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""

def make_node(self, *inputs: Variable) -> Apply:
def make_node(
self, *inputs: Variable
) -> Apply[Self, OpOutputsType, OpDefaultOutputType]:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.

This must be implemented by sub-classes.
Expand Down Expand Up @@ -159,11 +167,13 @@ def make_node(self, *inputs: Variable) -> Apply:
if inp != out
)
)
return Apply(self, inputs, [o() for o in self.otypes])
return Apply(
self, inputs, cast(OpOutputsType, tuple([o() for o in self.otypes]))
)

def __call__(
self, *inputs: Any, name=None, return_list=False, **kwargs
) -> Variable | list[Variable]:
self, *inputs: Any, name=None, return_list: bool = False, **kwargs
) -> OpOutputsType | OpDefaultOutputType | tuple[OpDefaultOutputType]:
r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.

This method is just a wrapper around :meth:`Op.make_node`.
Expand Down Expand Up @@ -218,15 +228,15 @@ def __call__(
if self.default_output is not None:
rval = node.outputs[self.default_output]
if return_list:
return [rval]
return rval
return cast(tuple[OpDefaultOutputType], (rval,))
return cast(OpDefaultOutputType, rval)
else:
if return_list:
return list(node.outputs)
return cast(OpOutputsType, tuple(node.outputs))
elif len(node.outputs) == 1:
return node.outputs[0]
return cast(OpDefaultOutputType, node.outputs[0])
else:
return node.outputs
return cast(OpOutputsType, tuple(node.outputs))

def __ne__(self, other: Any) -> bool:
return not (self == other)
Expand All @@ -236,8 +246,8 @@ def __ne__(self, other: Any) -> bool:
add_tag_trace = staticmethod(add_tag_trace)

def grad(
self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
) -> list[Variable]:
self, inputs: Sequence[Variable], output_grads: Sequence[OpDefaultOutputType]
) -> list[OpDefaultOutputType]:
r"""Construct a graph for the gradient with respect to each input variable.

Each returned `Variable` represents the gradient with respect to that
Expand Down Expand Up @@ -283,9 +293,9 @@ def grad(
def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
outputs: Sequence[OpDefaultOutputType],
output_grads: Sequence[OpDefaultOutputType],
) -> list[OpDefaultOutputType]:
r"""Construct a graph for the L-operator.

The L-operator computes a row vector times the Jacobian.
Expand All @@ -310,8 +320,10 @@ def L_op(
return self.grad(inputs, output_grads)

def R_op(
self, inputs: list[Variable], eval_points: Variable | list[Variable]
) -> list[Variable]:
self,
inputs: list[Variable],
eval_points: OpDefaultOutputType | list[OpDefaultOutputType],
) -> list[OpDefaultOutputType]:
r"""Construct a graph for the R-operator.

This method is primarily used by `Rop`.
Expand Down
17 changes: 12 additions & 5 deletions pytensor/link/c/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
from collections.abc import Callable, Collection, Iterable
from pathlib import Path
from re import Pattern
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, cast

import numpy as np

from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType
from pytensor.graph.op import (
ComputeMapType,
Op,
OpDefaultOutputType,
OpOutputsType,
StorageMapType,
ThunkType,
)
from pytensor.graph.type import HasDataType
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.interface import CLinkerOp
Expand All @@ -32,7 +39,7 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType:
return res


class COp(Op, CLinkerOp):
class COp(Op, CLinkerOp, Generic[OpOutputsType, OpDefaultOutputType]):
"""An `Op` with a C implementation."""

def make_c_thunk(
Expand Down Expand Up @@ -133,7 +140,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
)


class OpenMPOp(COp):
class OpenMPOp(COp, Generic[OpOutputsType, OpDefaultOutputType]):
r"""Base class for `Op`\s using OpenMP.

This `Op` will check that the compiler support correctly OpenMP code.
Expand Down Expand Up @@ -254,7 +261,7 @@ def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]:
return define_all, undef_all


class ExternalCOp(COp):
class ExternalCOp(COp, Generic[OpOutputsType, OpDefaultOutputType]):
"""Class for an `Op` with an external C implementation.

One can inherit from this class, provide its constructor with a path to
Expand Down
3 changes: 2 additions & 1 deletion pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.link.c.type import Generic
from pytensor.scalar.basic import ScalarType, as_scalar
from pytensor.tensor.type import DenseTensorType
from pytensor.tensor.variable import TensorVariable


class ExceptionType(Generic):
Expand All @@ -23,7 +24,7 @@ def __hash__(self):
exception_type = ExceptionType()


class CheckAndRaise(COp):
class CheckAndRaise(COp[tuple[TensorVariable], TensorVariable]):
"""An `Op` that checks conditions and raises an exception if they fail.

This `Op` returns its "value" argument if its condition arguments are all
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ def _cast_to_promised_scalar_dtype(x, dtype):
return getattr(np, dtype)(x)


class ScalarOp(COp):
class ScalarOp(COp[tuple[ScalarVariable], ScalarVariable]):
nin = -1
nout = 1

Expand Down
Loading
Loading