From 093b244764214aa0d566d1d7ea6a3a11091e868e Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Tue, 24 Mar 2026 11:12:55 +0100 Subject: [PATCH 1/4] Make Op class an immutable generic --- pytensor/compile/builders.py | 23 +++++++---- pytensor/graph/basic.py | 17 ++++---- pytensor/graph/op.py | 27 +++++++----- pytensor/link/c/op.py | 16 +++++--- pytensor/raise_op.py | 3 +- pytensor/scalar/basic.py | 2 +- pytensor/scan/op.py | 2 +- pytensor/scan/rewriting.py | 7 +++- pytensor/sparse/basic.py | 41 ++++++++++--------- pytensor/sparse/math.py | 35 ++++++++-------- pytensor/sparse/rewriting.py | 5 ++- .../tensor/_linalg/solve/linear_control.py | 2 +- pytensor/tensor/_linalg/solve/tridiagonal.py | 4 +- pytensor/tensor/basic.py | 28 ++++++------- pytensor/tensor/blas.py | 9 ++-- pytensor/tensor/blas_c.py | 3 +- pytensor/tensor/blockwise.py | 2 +- pytensor/tensor/einsum.py | 2 +- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/extra_ops.py | 20 ++++----- pytensor/tensor/fft.py | 5 ++- pytensor/tensor/fourier.py | 4 +- pytensor/tensor/math.py | 4 +- pytensor/tensor/nlinalg.py | 28 ++++++------- pytensor/tensor/optimize.py | 2 +- pytensor/tensor/pad.py | 2 +- pytensor/tensor/random/op.py | 2 +- pytensor/tensor/reshape.py | 4 +- pytensor/tensor/shape.py | 8 ++-- pytensor/tensor/signal/conv.py | 4 +- pytensor/tensor/slinalg.py | 36 ++++++++-------- pytensor/tensor/sort.py | 5 ++- pytensor/tensor/special.py | 7 ++-- pytensor/tensor/subtensor.py | 12 +++--- pytensor/typed_list/basic.py | 21 +++++----- pytensor/xtensor/basic.py | 4 +- 36 files changed, 213 insertions(+), 185 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index e36023127f..9f64b45e26 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -5,7 +5,7 @@ from copy import copy from functools import partial from itertools import chain -from typing import Union, cast +from typing import Generic, Union, cast from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared @@ -19,7 +19,7 @@ ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType -from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern +from pytensor.graph.op import HasInnerGraph, Op, OpOutputType, io_connection_pattern from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError @@ -154,7 +154,7 @@ def construct_nominal_fgraph( return fgraph, implicit_shared_inputs, update_d, update_expr -class OpFromGraph(Op, HasInnerGraph): +class OpFromGraph(Op, HasInnerGraph, Generic[OpOutputType]): r""" This creates an `Op` from inputs and outputs lists of variables. The signature is similar to :func:`pytensor.function ` @@ -253,7 +253,7 @@ def rescale_dy(inps, outputs, out_grads): def __init__( self, inputs: list[Variable], - outputs: list[Variable], + outputs: list[OpOutputType], *, inline: bool = False, lop_overrides: Union[Callable, "OpFromGraph", None] = None, @@ -713,18 +713,27 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]: self._rop_op_cache = wrapper return wrapper - def L_op(self, inputs, outputs, output_grads): + def L_op( + self, + inputs: Sequence[Variable], + outputs: Sequence[OpOutputType], + output_grads: Sequence[OpOutputType], + ) -> list[OpOutputType]: disconnected_output_grads = tuple( isinstance(og.type, DisconnectedType) for og in output_grads ) lop_op = self._build_and_cache_lop_op(disconnected_output_grads) return lop_op(*inputs, *outputs, *output_grads, return_list=True) - def R_op(self, inputs, eval_points): + def R_op( + self, + inputs: Sequence[Variable], + eval_points: OpOutputType | list[OpOutputType], + ) -> list[OpOutputType]: rop_op = self._build_and_cache_rop_op() return rop_op(*inputs, *eval_points, return_list=True) - def __call__(self, *inputs, **kwargs): + def __call__(self, *inputs, **kwargs) -> OpOutputType | 
list[OpOutputType]: # The user interface doesn't expect the shared variable inputs of the # inner-graph, but, since `Op.make_node` does (and `Op.__call__` # dispatches to `Op.make_node`), we need to compensate here diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 6614f969d4..5d1d9f72be 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -37,6 +37,7 @@ OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True) _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) +ApplyOutType = TypeVar("ApplyOutType", bound="Variable") _MOVED_FUNCTIONS = { "walk", @@ -106,7 +107,7 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) -class Apply(Node, Generic[OpType]): +class Apply(Node, Generic[OpType, ApplyOutType]): """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -145,7 +146,7 @@ def __init__( self, op: OpType, inputs: Sequence["Variable"], - outputs: Sequence["Variable"], + outputs: Sequence[ApplyOutType], ): if not isinstance(inputs, Sequence): raise TypeError("The inputs of an Apply must be a sequence type") @@ -165,7 +166,7 @@ def __init__( raise TypeError( f"The 'inputs' argument to Apply must contain Variable instances, not {input}" ) - self.outputs: list[Variable] = [] + self.outputs: list[ApplyOutType] = [] # filter outputs to make sure each element is a Variable for i, output in enumerate(outputs): if isinstance(output, Variable): @@ -192,7 +193,7 @@ def __getstate__(self): d["tag"] = t return d - def default_output(self): + def default_output(self) -> ApplyOutType: """ Returns the default output for this node. 
@@ -215,7 +216,7 @@ def default_output(self): raise ValueError( f"Multi-output Op {self.op} default_output not specified" ) - return self.outputs[do] + return cast(ApplyOutType, self.outputs[do]) def __str__(self): # FIXME: The called function is too complicated for this simple use case. @@ -224,7 +225,7 @@ def __str__(self): def __repr__(self): return str(self) - def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]": + def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType, ApplyOutType]": r"""Clone this `Apply` instance. Parameters @@ -256,7 +257,7 @@ def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]": def clone_with_new_inputs( self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False - ) -> "Apply[OpType]": + ) -> "Apply[OpType, ApplyOutType]": r"""Duplicate this `Apply` instance in a new graph. Parameters @@ -324,7 +325,7 @@ def get_parents(self): return list(self.inputs) @property - def out(self): + def out(self) -> ApplyOutType: """An alias for `self.default_output`""" return self.default_output() diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 20f28e76fd..354506674f 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -3,7 +3,9 @@ from typing import ( TYPE_CHECKING, Any, + Generic, Protocol, + Self, TypeVar, cast, ) @@ -48,7 +50,10 @@ def is_thunk_type(thunk: ThunkCallableType) -> ThunkType: return res -class Op(MetaObject): +OpOutputType = TypeVar("OpOutputType", bound=Variable) + + +class Op(MetaObject, Generic[OpOutputType]): """A class that models and constructs operations in a graph. A `Op` instance has several responsibilities: @@ -119,7 +124,7 @@ class Op(MetaObject): as nodes with these Ops must be rebuilt even if the input types haven't changed. """ - def make_node(self, *inputs: Variable) -> Apply: + def make_node(self, *inputs: Variable) -> Apply[Self, OpOutputType]: """Construct an `Apply` node that represent the application of this operation to the given inputs. 
This must be implemented by sub-classes. @@ -159,11 +164,11 @@ def make_node(self, *inputs: Variable) -> Apply: if inp != out ) ) - return Apply(self, inputs, [o() for o in self.otypes]) + return Apply(self, inputs, [cast(OpOutputType, o()) for o in self.otypes]) def __call__( self, *inputs: Any, name=None, return_list=False, **kwargs - ) -> Variable | list[Variable]: + ) -> OpOutputType | list[OpOutputType]: r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs. This method is just a wrapper around :meth:`Op.make_node`. @@ -236,8 +241,8 @@ def __ne__(self, other: Any) -> bool: add_tag_trace = staticmethod(add_tag_trace) def grad( - self, inputs: Sequence[Variable], output_grads: Sequence[Variable] - ) -> list[Variable]: + self, inputs: Sequence[Variable], output_grads: Sequence[OpOutputType] + ) -> list[OpOutputType]: r"""Construct a graph for the gradient with respect to each input variable. Each returned `Variable` represents the gradient with respect to that @@ -283,9 +288,9 @@ def grad( def L_op( self, inputs: Sequence[Variable], - outputs: Sequence[Variable], - output_grads: Sequence[Variable], - ) -> list[Variable]: + outputs: Sequence[OpOutputType], + output_grads: Sequence[OpOutputType], + ) -> list[OpOutputType]: r"""Construct a graph for the L-operator. The L-operator computes a row vector times the Jacobian. @@ -310,8 +315,8 @@ def L_op( return self.grad(inputs, output_grads) def R_op( - self, inputs: list[Variable], eval_points: Variable | list[Variable] - ) -> list[Variable]: + self, inputs: list[Variable], eval_points: OpOutputType | list[OpOutputType] + ) -> list[OpOutputType]: r"""Construct a graph for the R-operator. This method is primarily used by `Rop`. 
diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 2a0170f98d..71374bd777 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -4,13 +4,19 @@ from collections.abc import Callable, Collection, Iterable from pathlib import Path from re import Pattern -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, cast import numpy as np from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable -from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType +from pytensor.graph.op import ( + ComputeMapType, + Op, + OpOutputType, + StorageMapType, + ThunkType, +) from pytensor.graph.type import HasDataType from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.interface import CLinkerOp @@ -32,7 +38,7 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType: return res -class COp(Op, CLinkerOp): +class COp(Op, CLinkerOp, Generic[OpOutputType]): """An `Op` with a C implementation.""" def make_c_thunk( @@ -133,7 +139,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): ) -class OpenMPOp(COp): +class OpenMPOp(COp, Generic[OpOutputType]): r"""Base class for `Op`\s using OpenMP. This `Op` will check that the compiler support correctly OpenMP code. @@ -254,7 +260,7 @@ def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]: return define_all, undef_all -class ExternalCOp(COp): +class ExternalCOp(COp, Generic[OpOutputType]): """Class for an `Op` with an external C implementation. 
One can inherit from this class, provide its constructor with a path to diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index 2f357148a1..01765b2742 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -10,6 +10,7 @@ from pytensor.link.c.type import Generic from pytensor.scalar.basic import ScalarType, as_scalar from pytensor.tensor.type import DenseTensorType +from pytensor.tensor.variable import TensorVariable class ExceptionType(Generic): @@ -23,7 +24,7 @@ def __hash__(self): exception_type = ExceptionType() -class CheckAndRaise(COp): +class CheckAndRaise(COp[TensorVariable]): """An `Op` that checks conditions and raises an exception if they fail. This `Op` returns its "value" argument if its condition arguments are all diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index c7c4c2e7f8..7431171d3e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1182,7 +1182,7 @@ def _cast_to_promised_scalar_dtype(x, dtype): return getattr(np, dtype)(x) -class ScalarOp(COp): +class ScalarOp(COp[ScalarVariable]): nin = -1 nout = 1 diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 7353e3b889..60837ae4e3 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -710,7 +710,7 @@ def validate_inner_graph(self): ) -class Scan(Op, ScanMethodsMixin, HasInnerGraph): +class Scan(Op[Variable], ScanMethodsMixin, HasInnerGraph): r"""An `Op` implementing `for` and `while` loops. 
This `Op` has an "inner-graph" that represents the steps performed during diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index ebd1636ed3..bb1f3a40f9 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -946,7 +946,10 @@ def add_requirements(self, fgraph): fgraph.attach_feature(DestroyHandler()) def attempt_scan_inplace( - self, fgraph: FunctionGraph, node: Apply[Scan], output_indices: list[int] + self, + fgraph: FunctionGraph, + node: Apply[Scan, Variable], + output_indices: list[int], ) -> Apply | None: """Attempt to replace a `Scan` node by one which computes the specified outputs inplace. @@ -1012,7 +1015,7 @@ def attempt_scan_inplace( k: v for k, v in new_op.view_map.items() if k not in destroy_map } - new_node: Apply = new_op.make_node(*inputs) + new_node: Apply[Scan, Variable] = new_op.make_node(*inputs) try: fgraph.replace_all_validate_remove( # type: ignore diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 54da0b43d4..25bddbf89c 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -30,6 +30,7 @@ from pytensor.tensor.type import TensorType, ivector, scalar, tensor, vector from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes +from pytensor.tensor.variable import TensorVariable sparse_formats = ["csc", "csr"] @@ -244,7 +245,7 @@ def bsr_matrix(name=None, dtype=None): discrete_dtypes = int_dtypes + uint_dtypes -class CSMProperties(Op): +class CSMProperties(Op[TensorVariable]): """Create arrays containing all the properties of a given sparse matrix. More specifically, this `Op` extracts the ``.data``, ``.indices``, @@ -360,7 +361,7 @@ def csm_shape(csm): return csm_properties(csm)[3] -class CSM(Op): +class CSM(Op[TensorVariable]): """Construct a CSM matrix from constituent parts. 
Notes @@ -504,7 +505,7 @@ def infer_shape(self, fgraph, node, shapes): CSR = CSM("csr") -class CSMGrad(Op): +class CSMGrad(Op[TensorVariable]): """Compute the gradient of a CSM. Note @@ -591,7 +592,7 @@ def infer_shape(self, fgraph, node, shapes): csm_grad = CSMGrad -class Cast(Op): +class Cast(Op[TensorVariable]): __props__ = ("out_type",) def __init__(self, out_type): @@ -669,7 +670,7 @@ def cast(variable, dtype): return Cast(dtype)(variable) -class DenseFromSparse(Op): +class DenseFromSparse(Op[TensorVariable]): """Convert a sparse matrix to a dense one. Notes @@ -749,7 +750,7 @@ def infer_shape(self, fgraph, node, shapes): dense_from_sparse = DenseFromSparse() -class SparseFromDense(Op): +class SparseFromDense(Op[TensorVariable]): """Convert a dense matrix to a sparse matrix.""" __props__ = () @@ -815,7 +816,7 @@ def infer_shape(self, fgraph, node, shapes): csc_from_dense = SparseFromDense("csc") -class GetItemList(Op): +class GetItemList(Op[TensorVariable]): """Select row of sparse matrix, returning them as a new sparse matrix.""" __props__ = () @@ -862,7 +863,7 @@ def grad(self, inputs, g_outputs): get_item_list = GetItemList() -class GetItemListGrad(Op): +class GetItemListGrad(Op[TensorVariable]): __props__ = () def infer_shape(self, fgraph, node, shapes): @@ -905,7 +906,7 @@ def perform(self, node, inp, outputs): get_item_list_grad = GetItemListGrad() -class GetItem2Lists(Op): +class GetItem2Lists(Op[TensorVariable]): """Select elements of sparse matrix, returning them in a vector.""" __props__ = () @@ -955,7 +956,7 @@ def grad(self, inputs, g_outputs): get_item_2lists = GetItem2Lists() -class GetItem2ListsGrad(Op): +class GetItem2ListsGrad(Op[TensorVariable]): __props__ = () def infer_shape(self, fgraph, node, shapes): @@ -996,7 +997,7 @@ def perform(self, node, inp, outputs): get_item_2lists_grad = GetItem2ListsGrad() -class GetItem2d(Op): +class GetItem2d(Op[TensorVariable]): """Implement a subtensor of sparse variable, returning a sparse matrix. 
If you want to take only one element of a sparse matrix see @@ -1125,7 +1126,7 @@ def perform(self, node, inputs, outputs): get_item_2d = GetItem2d() -class GetItemScalar(Op): +class GetItemScalar(Op[TensorVariable]): """Subtensor of a sparse variable that takes two scalars as index and returns a scalar. If you want to take a slice of a sparse matrix see `GetItem2d` that returns a @@ -1186,7 +1187,7 @@ def perform(self, node, inputs, outputs): get_item_scalar = GetItemScalar() -class Transpose(Op): +class Transpose(Op[TensorVariable]): """Transpose of a sparse matrix. Notes @@ -1246,7 +1247,7 @@ def infer_shape(self, fgraph, node, shapes): transpose = Transpose() -class ColScaleCSC(Op): +class ColScaleCSC(Op[TensorVariable]): # Scale each columns of a sparse matrix by the corresponding # element of a dense vector @@ -1292,7 +1293,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -class RowScaleCSC(Op): +class RowScaleCSC(Op[TensorVariable]): # Scale each row of a sparse matrix by the corresponding element of # a dense vector @@ -1400,7 +1401,7 @@ def row_scale(x, s): return col_scale(x.T, s).T -class Diag(Op): +class Diag(Op[TensorVariable]): """Extract the diagonal of a square sparse matrix as a dense vector. Notes @@ -1454,7 +1455,7 @@ def square_diagonal(diag): return CSC(data, indices, indptr, ptb.as_tensor((n, n))) -class EnsureSortedIndices(Op): +class EnsureSortedIndices(Op[TensorVariable]): """Re-sort indices of a sparse matrix. CSR column indices are not necessarily sorted. Likewise @@ -1539,7 +1540,7 @@ def clean(x): return ensure_sorted_indices(remove0(x)) -class Stack(Op): +class Stack(Op[TensorVariable]): __props__ = ("format", "dtype") def __init__(self, format=None, dtype=None): @@ -1750,7 +1751,7 @@ def vstack(blocks, format=None, dtype=None): return VStack(format=format, dtype=dtype)(*blocks) -class Remove0(Op): +class Remove0(Op[TensorVariable]): """Remove explicit zeros from a sparse matrix. 
Notes @@ -1807,7 +1808,7 @@ def infer_shape(self, fgraph, node, i0_shapes): remove0 = Remove0() -class ConstructSparseFromList(Op): +class ConstructSparseFromList(Op[TensorVariable]): """Constructs a sparse matrix out of a list of 2-D matrix rows. Notes diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index f561469b02..fa942b56e6 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -15,6 +15,7 @@ from pytensor.sparse.type import SparseTensorType from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor +from pytensor.tensor.variable import TensorVariable def structured_elemwise(tensor_op): @@ -254,7 +255,7 @@ def conjugate(x): structured_conjugate = conj = conjugate -class SpSum(Op): +class SpSum(Op[TensorVariable]): """ WARNING: judgement call... @@ -374,7 +375,7 @@ def sp_sum(x, axis=None, sparse_grad=False): return SpSum(axis, sparse_grad)(x) -class AddSS(Op): +class AddSS(Op[TensorVariable]): # add(sparse, sparse). # see the doc of add() for more detail. __props__ = () @@ -411,7 +412,7 @@ def infer_shape(self, fgraph, node, shapes): add_s_s = AddSS() -class AddSSData(Op): +class AddSSData(Op[TensorVariable]): """Add two sparse matrices assuming they have the same sparsity pattern. Notes @@ -472,7 +473,7 @@ def infer_shape(self, fgraph, node, ins_shapes): add_s_s_data = AddSSData() -class AddSD(Op): +class AddSD(Op[TensorVariable]): # add(sparse, sparse). # see the doc of add() for more detail. __props__ = () @@ -514,7 +515,7 @@ def infer_shape(self, fgraph, node, shapes): add_s_d = AddSD() -class StructuredAddSV(Op): +class StructuredAddSV(Op[TensorVariable]): """Structured addition of a sparse matrix and a dense vector. 
The elements of the vector are only added to the corresponding @@ -666,7 +667,7 @@ def sub(x, y): sub.__doc__ = subtract.__doc__ -class SparseSparseMultiply(Op): +class SparseSparseMultiply(Op[TensorVariable]): # mul(sparse, sparse) # See the doc of mul() for more detail __props__ = () @@ -704,7 +705,7 @@ def infer_shape(self, fgraph, node, shapes): mul_s_s = SparseSparseMultiply() -class SparseDenseMultiply(Op): +class SparseDenseMultiply(Op[TensorVariable]): # mul(sparse, dense) # See the doc of mul() for more detail __props__ = () @@ -793,7 +794,7 @@ def infer_shape(self, fgraph, node, shapes): mul_s_d = SparseDenseMultiply() -class SparseDenseVectorMultiply(Op): +class SparseDenseVectorMultiply(Op[TensorVariable]): """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. Notes @@ -941,7 +942,7 @@ def mul(x, y): mul.__doc__ = multiply.__doc__ -class __ComparisonOpSS(Op): +class __ComparisonOpSS(Op[TensorVariable]): """ Used as a superclass for all comparisons between two sparses matrices. @@ -991,7 +992,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -class __ComparisonOpSD(Op): +class __ComparisonOpSD(Op[TensorVariable]): """ Used as a superclass for all comparisons between sparse and dense matrix. @@ -1195,7 +1196,7 @@ def comparison(self, x, y): ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d) -class TrueDot(Op): +class TrueDot(Op[TensorVariable]): # TODO # Simplify code by splitting into DotSS and DotSD. @@ -1335,7 +1336,7 @@ def true_dot(x, y, grad_preserves_dense=True): return psb.transpose(TrueDot(grad_preserves_dense)(y.T, x.T)) -class StructuredDot(Op): +class StructuredDot(Op[TensorVariable]): __props__ = () def make_node(self, a, b): @@ -1466,7 +1467,7 @@ def structured_dot(x, y): return _structured_dot(y.T, x.T).T -class StructuredDotGradCSC(COp): +class StructuredDotGradCSC(COp[TensorVariable]): # Op that produces the grad of StructuredDot. 
# :param a_indices: Matrix indices @@ -1601,7 +1602,7 @@ def infer_shape(self, fgraph, node, shapes): sdg_csc = StructuredDotGradCSC() -class StructuredDotGradCSR(COp): +class StructuredDotGradCSR(COp[TensorVariable]): # Op that produces the grad of StructuredDot. # :param a_indices: Matrix indices @@ -1758,7 +1759,7 @@ def structured_dot_grad(sparse_A, dense_B, ga): raise NotImplementedError() -class SamplingDot(Op): +class SamplingDot(Op[TensorVariable]): """Compute the dot product ``dot(x, y.T) = z`` for only a subset of `z`. This is equivalent to ``p * (x . y.T)`` where ``*`` is the element-wise @@ -1834,7 +1835,7 @@ def infer_shape(self, fgraph, node, ins_shapes): sampling_dot = SamplingDot() -class Dot(Op): +class Dot(Op[TensorVariable]): __props__ = () def __str__(self): @@ -1985,7 +1986,7 @@ def dot(x, y): return _dot(x, y) -class Usmm(Op): +class Usmm(Op[TensorVariable]): """Computes the dense matrix resulting from ``alpha * x @ y + z``. Notes diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index d992635298..d9ffed5f3e 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -21,6 +21,7 @@ from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.shape import shape, specify_shape from pytensor.tensor.type import TensorType, tensor +from pytensor.tensor.variable import TensorVariable _is_sparse_variable = sparse._is_sparse_variable @@ -241,7 +242,7 @@ def local_addsd_ccode(fgraph, node): ) -class StructuredDotCSC(COp): +class StructuredDotCSC(COp[TensorVariable]): """ Structured Dot CSC is like `dot`, except that only the gradient wrt non-zero elements of a sparse matrix are calculated and propagated. 
@@ -438,7 +439,7 @@ def c_code_cache_version(self): sd_csc = StructuredDotCSC() -class StructuredDotCSR(COp): +class StructuredDotCSR(COp[TensorVariable]): """ Structured Dot CSR is like dot, except that only the gradient wrt non-zero elements of a sparse matrix diff --git a/pytensor/tensor/_linalg/solve/linear_control.py b/pytensor/tensor/_linalg/solve/linear_control.py index 861ba6b9b8..f797a1d90f 100644 --- a/pytensor/tensor/_linalg/solve/linear_control.py +++ b/pytensor/tensor/_linalg/solve/linear_control.py @@ -20,7 +20,7 @@ from pytensor.tensor.variable import TensorVariable -class TRSYL(Op): +class TRSYL(Op[TensorVariable]): """ Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation: diff --git a/pytensor/tensor/_linalg/solve/tridiagonal.py b/pytensor/tensor/_linalg/solve/tridiagonal.py index a97d8eaf68..d0cfe43073 100644 --- a/pytensor/tensor/_linalg/solve/tridiagonal.py +++ b/pytensor/tensor/_linalg/solve/tridiagonal.py @@ -15,7 +15,7 @@ from pytensor.tensor import TensorLike -class LUFactorTridiagonal(Op): +class LUFactorTridiagonal(Op[TensorVariable]): """Compute LU factorization of a tridiagonal matrix (lapack gttrf)""" __props__ = ( @@ -89,7 +89,7 @@ def perform(self, node, inputs, output_storage): output_storage[4][0] = ipiv -class SolveLUFactorTridiagonal(Op): +class SolveLUFactorTridiagonal(Op[TensorVariable]): """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs).""" __props__ = ("b_ndim", "overwrite_b", "transposed") diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e54240e429..fd245ddb30 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -616,7 +616,7 @@ def get_scalar_constant_value( ) -class TensorFromScalar(COp): +class TensorFromScalar(COp[TensorVariable]): __props__ = () def make_node(self, s): @@ -673,7 +673,7 @@ def vectorize_tensor_from_scalar(op, node, batch_x): return identity(batch_x).owner -class ScalarFromTensor(COp): +class 
ScalarFromTensor(COp[ScalarVariable]): __props__ = () def __call__(self, *args, **kwargs) -> ScalarVariable: @@ -949,7 +949,7 @@ def ones(shape, dtype=None) -> TensorVariable: return alloc(np.array(1, dtype=dtype), *shape) -class Nonzero(Op): +class Nonzero(Op[TensorVariable]): """ Return the indices of the elements that are non-zero. @@ -1340,7 +1340,7 @@ def triu_indices_from( return triu_indices(a.shape[0], k=k, m=a.shape[1]) -class Eye(Op): +class Eye(Op[TensorVariable]): _output_type_depends_on_input_value = True __props__ = ("dtype",) @@ -1534,7 +1534,7 @@ def check_type(s): return sh, static_shape -class Alloc(COp): +class Alloc(COp[TensorVariable]): """Create a `TensorVariable` from an initial value and a desired shape. Usage: @@ -1879,7 +1879,7 @@ def full_like( return fill(a, fill_value) -class MakeVector(COp): +class MakeVector(COp[TensorVariable]): """Concatenate a number of scalars together into a vector. This is a simple version of stack() that introduces far less cruft @@ -2060,7 +2060,7 @@ def register_transfer(fn): identity = tensor_copy -class Default(Op): +class Default(Op[TensorVariable]): """ Takes an input x and a default value. @@ -2191,7 +2191,7 @@ def split(x, splits_size, *, n_splits=None, axis=0): return Split(n_splits)(x, axis, splits_size) -class Split(COp): +class Split(COp[TensorVariable]): """Partition a `TensorVariable` along some axis. Examples @@ -2428,7 +2428,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ -class Join(COp): +class Join(COp[TensorVariable]): r""" Concatenate multiple `TensorVariable`\s along some axis. @@ -3249,7 +3249,7 @@ def tile( return A_replicated.reshape(tiled_shape) -class ARange(COp): +class ARange(COp[TensorVariable]): """Create an array containing evenly spaced values within a given interval. Parameters and behaviour are the same as numpy.arange(). 
@@ -3536,7 +3536,7 @@ def __getitem__(self, *args): ogrid = _nd_grid(sparse=True) -class PermuteRowElements(Op): +class PermuteRowElements(Op[TensorVariable]): """Permute the elements of each row (inner-most dim) of a tensor. A permutation will be applied to every row (vector) of the input tensor x. @@ -3746,7 +3746,7 @@ def inverse_permutation(perm): ) -class ExtractDiag(COp): +class ExtractDiag(COp[TensorVariable]): """ Return specified diagonals. @@ -4261,7 +4261,7 @@ def choose(a, choices, mode="raise"): return Choose(mode)(a, choices) -class Choose(Op): +class Choose(Op[TensorVariable]): __props__ = ("mode",) def __init__(self, mode): @@ -4326,7 +4326,7 @@ def perform(self, node, inputs, outputs): z[0] = np.choose(a, choice, mode=self.mode) -class AllocEmpty(COp): +class AllocEmpty(COp[TensorVariable]): """Implement Alloc on the cpu, but without initializing memory.""" _output_type_depends_on_input_value = True diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index acfb69fe8e..daa5cdd16f 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -108,6 +108,7 @@ from pytensor.tensor.math import dot, tensordot from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import DenseTensorType, tensor +from pytensor.tensor.variable import TensorVariable _logger = logging.getLogger("pytensor.tensor.blas") @@ -150,7 +151,7 @@ def must_initialize_y_gemv(): must_initialize_y_gemv._result = None # type: ignore -class Gemv(Op): +class Gemv(Op[TensorVariable]): """ expression is beta * y + alpha * A x @@ -256,7 +257,7 @@ def infer_shape(self, fgraph, node, input_shapes): gemv = gemv_no_inplace -class Ger(Op): +class Ger(Op[TensorVariable]): """ BLAS defines general rank-1 update GER as A <- A + alpha x y' @@ -468,7 +469,7 @@ def _ldflags( return rval -class GemmRelated(COp): +class GemmRelated(COp[TensorVariable]): """Base class for Gemm and Dot22. This class provides a kind of templated gemm Op. 
@@ -1304,7 +1305,7 @@ def c_code_cache_version(self): _dot22scalar = Dot22Scalar() -class BatchedDot(COp): +class BatchedDot(COp[TensorVariable]): """ Computes a batch matrix-matrix dot with tensor3 variables diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index 83cd87796a..548bc45623 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -8,9 +8,10 @@ blas_header_version, ldflags, ) +from pytensor.tensor.variable import TensorVariable -class BaseBLAS(COp): +class BaseBLAS(COp[TensorVariable]): def c_libraries(self, **kwargs): return ldflags() diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index aa5aacbb7b..84ac7d6543 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -151,7 +151,7 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ ) -class Blockwise(COp): +class Blockwise(COp[TensorVariable]): """Generalizes a core `Op` to work with batched dimensions. TODO: C implementation? diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index a6d5a358f1..88ef39042b 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -35,7 +35,7 @@ CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str] -class Einsum(OpFromGraph): +class Einsum(OpFromGraph[TensorVariable]): """ Wrapper Op for Einsum graphs diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index ddb7376fce..aa72145e62 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1134,7 +1134,7 @@ def outer(self, x, y): return self(x_, y_) -class CAReduce(COp): +class CAReduce(COp[TensorVariable]): """Reduces a scalar operation along specified axes. The scalar op should be both commutative and associative. 
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 4191d678cd..31c95e44a8 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -50,7 +50,7 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH -class CpuContiguous(COp): +class CpuContiguous(COp[TensorVariable]): """ Check to see if the input is c-contiguous. @@ -109,7 +109,7 @@ def c_code_cache_version(self): cpu_contiguous = CpuContiguous() -class SearchsortedOp(COp): +class SearchsortedOp(COp[TensorVariable]): """Wrapper for ``numpy.searchsorted``. For full documentation, see :func:`searchsorted`. @@ -284,7 +284,7 @@ def searchsorted(x, v, side="left", sorter=None): return SearchsortedOp(side=side)(x, v, sorter) -class CumOp(COp): +class CumOp(COp[TensorVariable]): # See function cumsum/cumprod for docstring __props__ = ("axis", "mode") @@ -620,7 +620,7 @@ def compress(condition, x, axis=None): return _x.take(indices, axis=axis) -class Repeat(Op): +class Repeat(Op[TensorVariable]): # See the repeat function for docstring __props__ = ("axis",) @@ -831,7 +831,7 @@ def repeat( return broadcast_a.reshape(repeat_shape) -class Bartlett(Op): +class Bartlett(Op[TensorVariable]): # See function bartlett for docstring __props__ = () @@ -888,7 +888,7 @@ def bartlett(M): return bartlett_(M) -class FillDiagonal(Op): +class FillDiagonal(Op[TensorVariable]): # See function fill_diagonal for docstring __props__ = () @@ -989,7 +989,7 @@ def fill_diagonal(a, val): return fill_diagonal_(a, val) -class FillDiagonalOffset(Op): +class FillDiagonalOffset(Op[TensorVariable]): # See function fill_diagonal_offset for docstring __props__ = () @@ -1161,7 +1161,7 @@ def to_one_hot(y, nb_class, dtype=None): return ret -class Unique(Op): +class Unique(Op[TensorVariable]): """ Wraps `numpy.unique`. 
@@ -1283,7 +1283,7 @@ def unique( return Unique(return_index, return_inverse, return_counts, axis)(ar) -class UnravelIndex(Op): +class UnravelIndex(Op[TensorVariable]): __props__ = ("order",) def __init__(self, order="C"): @@ -1360,7 +1360,7 @@ def unravel_index(indices, dims, order="C"): return tuple(res) -class RavelMultiIndex(Op): +class RavelMultiIndex(Op[TensorVariable]): __props__ = ("mode", "order") def __init__(self, mode="raise", order="C"): diff --git a/pytensor/tensor/fft.py b/pytensor/tensor/fft.py index 89742248a6..fa18312ff8 100644 --- a/pytensor/tensor/fft.py +++ b/pytensor/tensor/fft.py @@ -7,9 +7,10 @@ from pytensor.tensor.math import sqrt from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.type import TensorType, integer_dtypes +from pytensor.tensor.variable import TensorVariable -class RFFTOp(Op): +class RFFTOp(Op[TensorVariable]): __props__ = () def output_type(self, inp): @@ -69,7 +70,7 @@ def connection_pattern(self, node): rfft_op = RFFTOp() -class IRFFTOp(Op): +class IRFFTOp(Op[TensorVariable]): __props__ = () def output_type(self, inp): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 033d46222c..34c185ee89 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -16,10 +16,10 @@ from pytensor.tensor.shape import shape from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.type import TensorType, integer_dtypes -from pytensor.tensor.variable import TensorConstant +from pytensor.tensor.variable import TensorConstant, TensorVariable -class Fourier(Op): +class Fourier(Op[TensorVariable]): """ WARNING: for officially supported FFTs, use pytensor.tensor.fft, which provides real-input FFTs. Gradients are supported. 
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 422ce324e0..1baa312f01 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -139,7 +139,7 @@ def _allclose(a, b, rtol=None, atol=None): return np.allclose(a, b, atol=atol_, rtol=rtol_) -class Argmax(COp): +class Argmax(COp[TensorVariable]): """ Calculate the argmax over a given axis or over all axes. """ @@ -3019,7 +3019,7 @@ def clip(x, min, max): pprint.assign(pow, printing.OperatorPrinter("**", 1, "right")) -class Dot(Op): +class Dot(Op[TensorVariable]): """ Computes the dot product of two matrices variables diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 65c8a256f2..7267fa9afb 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -1,5 +1,5 @@ import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Literal, cast @@ -26,9 +26,10 @@ tensor, vector, ) +from pytensor.tensor.variable import TensorVariable -class MatrixPinv(Op): +class MatrixPinv(Op[TensorVariable]): __props__ = ("hermitian",) gufunc_signature = "(m,n)->(n,m)" @@ -208,7 +209,7 @@ def trace(X): return diagonal(X).sum() -class Det(Op): +class Det(Op[TensorVariable]): """ Matrix determinant. Input should be a square matrix. @@ -259,7 +260,7 @@ def __str__(self): det = Blockwise(Det()) -class SLogDet(Op): +class SLogDet(Op[TensorVariable]): """ Compute the log determinant and its sign of the matrix. Input should be a square matrix. """ @@ -323,7 +324,7 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) -class Eig(Op): +class Eig(Op[TensorVariable]): """ Compute the eigenvalues and right eigenvectors of a square array. """ @@ -465,7 +466,7 @@ def _zero_disconnected(outputs, grads): return l -class EighGrad(Op): +class EighGrad(Op[TensorVariable]): """ Gradient of an eigensystem of a Hermitian matrix. 
@@ -535,7 +536,7 @@ def eigh(a, UPLO="L"): return Eigh(UPLO)(a) -class SVD(Op): +class SVD(Op[TensorVariable]): """ Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V @@ -607,12 +608,7 @@ def infer_shape(self, fgraph, node, shapes): else: return [s_shape] - def L_op( - self, - inputs: Sequence[Variable], - outputs: Sequence[Variable], - output_grads: Sequence[Variable], - ) -> list[Variable]: + def L_op(self, inputs, outputs, output_grads): """ Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here: https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194 @@ -746,7 +742,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True): return Blockwise(SVD(full_matrices, compute_uv))(a) -class Lstsq(Op): +class Lstsq(Op[TensorVariable]): __props__ = () def make_node(self, x, y, rcond): @@ -1017,7 +1013,7 @@ def norm( ) -class TensorInv(Op): +class TensorInv(Op[TensorVariable]): """ Class wrapper for tensorinv() function; PyTensor utilization of numpy.linalg.tensorinv; @@ -1075,7 +1071,7 @@ def tensorinv(a, ind=2): return TensorInv(ind)(a) -class TensorSolve(Op): +class TensorSolve(Op[TensorVariable]): """ PyTensor utilization of numpy.linalg.tensorsolve Class wrapper for tensorsolve function. 
diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index da8064b370..4760dc2bc7 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -163,7 +163,7 @@ def _depends_only_on_constants(var: Variable) -> bool: ] -class ScipyWrapperOp(Op, HasInnerGraph): +class ScipyWrapperOp(Op[TensorVariable], HasInnerGraph): """Shared logic for scipy optimization ops""" def build_fn(self): diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py index efe7da88dc..441e21c882 100644 --- a/pytensor/tensor/pad.py +++ b/pytensor/tensor/pad.py @@ -413,7 +413,7 @@ def _reflect_inner(i, x, x_flipped, padding_left): return x -class Pad(OpFromGraph): +class Pad(OpFromGraph[TensorVariable]): """ Wrapper Op for Pad graphs """ diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 02f1840521..c5f3d60ff4 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -33,7 +33,7 @@ from pytensor.tensor.variable import TensorVariable -class RNGConsumerOp(Op): +class RNGConsumerOp(Op[TensorVariable]): """Baseclass for Ops that consume RNGs.""" @abc.abstractmethod diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index 98024f28ca..0ee6446083 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -23,7 +23,7 @@ ) -class JoinDims(Op): +class JoinDims(Op[TensorVariable]): __props__ = ("start_axis", "n_axes") view_map = {0: [0]} @@ -154,7 +154,7 @@ def join_dims( return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value] -class SplitDims(Op): +class SplitDims(Op[TensorVariable]): __props__ = ("axis",) view_map = {0: [0]} diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 7c8120981f..e30841c2c8 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -50,7 +50,7 @@ def register_shape_c_code(type, code, version=()): Shape.c_code_and_version[type] = (code, version) -class Shape(COp): +class Shape(COp[TensorVariable]): """ L{Op} to return 
the shape of a matrix. @@ -198,7 +198,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: return res -class Shape_i(COp): +class Shape_i(COp[TensorVariable]): """ L{Op} to return the shape of a matrix. @@ -380,7 +380,7 @@ def register_shape_i_c_code(typ, code, check_input, version=()): Shape_i.c_code_and_version[typ] = (code, check_input, version) -class SpecifyShape(COp): +class SpecifyShape(COp[TensorVariable]): """ L{Op} that puts into the graph the user-provided shape. @@ -625,7 +625,7 @@ def _vectorize_specify_shape(op, node, x, *shape): return specify_shape(x, new_shape).owner -class Reshape(COp): +class Reshape(COp[TensorVariable]): """Perform a reshape operation of the input x to the new shape shp. The number of dimensions to which to reshape to (ndim) must be known at graph build time. diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index 3014823edf..062f0cdb47 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -113,7 +113,7 @@ def L_op(self, inputs, outputs, output_grads): ] -class Convolve1d(AbstractConvolveNd, COp): # type: ignore[misc] +class Convolve1d(AbstractConvolveNd, COp[TensorVariable]): # type: ignore[misc] __props__ = () ndim = 1 @@ -246,7 +246,7 @@ def convolve1d( return type_cast(TensorVariable, _blockwise_convolve_1d(in1, in2, full_mode)) -class Convolve2d(AbstractConvolveNd, Op): # type: ignore[misc] +class Convolve2d(AbstractConvolveNd, Op[TensorVariable]): # type: ignore[misc] __props__ = ("method",) # type: ignore[assignment] ndim = 2 diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 1695ade729..da40ac99db 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class Cholesky(Op): +class Cholesky(Op[TensorVariable]): # TODO: LAPACK wrapper with in-place behavior, for solve also __props__ = ("lower", "overwrite_a") @@ -210,7 +210,7 @@ def cholesky( return res -class 
SolveBase(Op): +class SolveBase(Op[TensorVariable]): """Base class for `scipy.linalg` matrix equation solvers.""" __props__: tuple[str, ...] = ( @@ -412,7 +412,7 @@ def cho_solve( return Blockwise(CholeskySolve(lower=lower, b_ndim=b_ndim))(A, b) -class LU(Op): +class LU(Op[TensorVariable]): """Decompose a matrix into lower and upper triangular matrices.""" __props__ = ("permute_l", "overwrite_a", "p_indices") @@ -504,10 +504,10 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": def L_op( self, - inputs: Sequence[ptb.Variable], - outputs: Sequence[ptb.Variable], - output_grads: Sequence[ptb.Variable], - ) -> list[ptb.Variable]: + inputs, + outputs, + output_grads, + ): r""" Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization F. R. De Hoog, R.S. Anderssen, M. A. Lukas @@ -521,9 +521,7 @@ def L_op( # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass - P_or_indices, L, U = lu( # type: ignore - A, permute_l=False, p_indices=False - ) + P_or_indices, L, U = lu(A, permute_l=False, p_indices=False) else: # In both other cases, there are 3 outputs. The first output will either be the permutation index itself, @@ -603,7 +601,7 @@ def lu( ) -class PivotToPermutations(Op): +class PivotToPermutations(Op[TensorVariable]): gufunc_signature = "(x)->(x)" __props__ = ("inverse",) @@ -636,7 +634,7 @@ def pivot_to_permutation(p: TensorLike, inverse=False): return PivotToPermutations(inverse=inverse)(p) -class LUFactor(Op): +class LUFactor(Op[TensorVariable]): __props__ = ("overwrite_a",) gufunc_signature = "(m,m)->(m,m),(m)" @@ -1129,7 +1127,7 @@ def solve( )(a, b) -class Eigvalsh(Op): +class Eigvalsh(Op[TensorVariable]): """ Generalized eigenvalues of a Hermitian positive definite eigensystem. 
@@ -1176,7 +1174,7 @@ def infer_shape(self, fgraph, node, shapes): return [(n,)] -class EigvalshGrad(Op): +class EigvalshGrad(Op[TensorVariable]): """ Gradient of generalized eigenvalues of a Hermitian positive definite eigensystem. @@ -1236,7 +1234,7 @@ def eigvalsh(a, b, lower=True): return Eigvalsh(lower)(a, b) -class Expm(Op): +class Expm(Op[TensorVariable]): """ Compute the matrix exponential of a square array. """ @@ -1299,7 +1297,7 @@ def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) -class BaseBlockDiagonal(Op): +class BaseBlockDiagonal(Op[TensorVariable]): __props__: tuple[str, ...] = ("n_inputs",) def __init__(self, n_inputs): @@ -1410,7 +1408,7 @@ def block_diag(*matrices: TensorVariable): return _block_diagonal_matrix(*matrices) -class QR(Op): +class QR(Op[TensorVariable]): """ QR Decomposition """ @@ -1780,7 +1778,7 @@ def qr( return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A) -class Schur(Op): +class Schur(Op[TensorVariable]): """ Schur Decomposition """ @@ -1967,7 +1965,7 @@ def schur( return Blockwise(Schur(output=output, sort=sort))(A) # type: ignore[return-value] -class QZ(Op): +class QZ(Op[TensorVariable]): """ QZ Decomposition """ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index 92b48011f0..7ccbe48490 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -8,6 +8,7 @@ from pytensor.tensor.basic import arange, as_tensor_variable, switch from pytensor.tensor.math import eq, ge from pytensor.tensor.type import TensorType +from pytensor.tensor.variable import TensorVariable KIND = typing.Literal["quicksort", "mergesort", "heapsort", "stable"] @@ -28,7 +29,7 @@ def _parse_sort_args(kind: KIND | None, order, stable: bool | None) -> KIND: return kind -class SortOp(Op): +class SortOp(Op[TensorVariable]): """ This class is a wrapper for numpy sort function. 
@@ -153,7 +154,7 @@ def sort( return SortOp(kind)(a, axis) -class ArgSortOp(Op): +class ArgSortOp(Op[TensorVariable]): """ This class is a wrapper for numpy argsort function. diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index b845e69b37..9a8ded48da 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -9,9 +9,10 @@ from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, log, neg, sum +from pytensor.tensor.variable import TensorVariable -class SoftmaxGrad(COp): +class SoftmaxGrad(COp[TensorVariable]): """ Gradient wrt x of the Softmax Op. @@ -239,7 +240,7 @@ def c_code(self, node, name, inp, out, sub): ) -class Softmax(COp): +class Softmax(COp[TensorVariable]): r""" Softmax activation function :math:`\\varphi(\\mathbf{x})_j = @@ -494,7 +495,7 @@ def softmax(c, axis=None): return Softmax(axis=axis)(c) -class LogSoftmax(COp): +class LogSoftmax(COp[TensorVariable]): r""" LogSoftmax activation function :math:`\\varphi(\\mathbf{x})_j = diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 135d1e947b..eafa37aaf1 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -788,7 +788,7 @@ def __hash__(self): return hash((type(self), props_values)) -class Subtensor(BaseSubtensor, COp): +class Subtensor(BaseSubtensor, COp[TensorVariable]): """Basic NumPy indexing operator.""" check_input = False @@ -1362,7 +1362,7 @@ def process_slice_component(comp): pprint.assign(Subtensor, SubtensorPrinter()) -class IncSubtensor(BaseSubtensor, COp): +class IncSubtensor(BaseSubtensor, COp[TensorVariable]): """ Increment a subtensor. @@ -1792,7 +1792,7 @@ def _sum_grad_over_bcasted_dims(x, gx): return gx -class AdvancedSubtensor1(COp): +class AdvancedSubtensor1(COp[TensorVariable]): """ Implement x[ilist] where ilist is a vector of integers. 
@@ -1957,7 +1957,7 @@ def _idx_may_be_invalid(x, idx) -> bool: advanced_subtensor1 = AdvancedSubtensor1() -class AdvancedIncSubtensor1(BaseSubtensor, COp): +class AdvancedIncSubtensor1(BaseSubtensor, COp[TensorVariable]): """ Increments a subtensor using advanced slicing (list of index). @@ -2269,7 +2269,7 @@ def as_tensor_index_variable(idx): return idx -class AdvancedSubtensor(BaseSubtensor, COp): +class AdvancedSubtensor(BaseSubtensor, COp[TensorVariable]): """Implements NumPy's advanced indexing.""" __props__ = ("idx_list",) @@ -2550,7 +2550,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): return type(op)(new_idx_list).make_node(batch_x, *batch_idxs) -class AdvancedIncSubtensor(BaseSubtensor, Op): +class AdvancedIncSubtensor(BaseSubtensor, Op[TensorVariable]): """Increments a subtensor using advanced indexing.""" __props__ = ( diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index 2154e4aa1a..b93f28520e 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -5,6 +5,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.op import Op from pytensor.link.c.op import COp +from pytensor.scalar.basic import ScalarVariable from pytensor.tensor.type import lscalar from pytensor.tensor.type_other import SliceType from pytensor.tensor.variable import TensorVariable @@ -68,7 +69,7 @@ class TypedListConstant(_typed_list_py_operators, Constant): TypedListType.constant_type = TypedListConstant -class GetItem(COp): +class GetItem(COp[Variable]): # See doc in instance of this Op or function after this class definition. view_map = {0: [0]} __props__ = () @@ -130,7 +131,7 @@ def c_code_cache_version(self): """ -class Append(COp): +class Append(COp[TypedListVariable]): # See doc in instance of this Op after the class definition. 
__props__ = ("inplace",) @@ -209,7 +210,7 @@ def c_code_cache_version(self): """ -class Extend(COp): +class Extend(COp[TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -293,7 +294,7 @@ def c_code_cache_version_(self): """ -class Insert(COp): +class Insert(COp[TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -379,7 +380,7 @@ def c_code_cache_version(self): """ -class Remove(Op): +class Remove(Op[TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -436,7 +437,7 @@ def __str__(self): """ -class Reverse(COp): +class Reverse(COp[TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -503,7 +504,7 @@ def c_code_cache_version(self): """ -class Index(Op): +class Index(Op[ScalarVariable]): # See doc in instance of this Op after the class definition. __props__ = () @@ -532,7 +533,7 @@ def __str__(self): index_ = Index() -class Count(Op): +class Count(Op[ScalarVariable]): # See doc in instance of this Op after the class definition. __props__ = () @@ -579,7 +580,7 @@ def __str__(self): """ -class Length(COp): +class Length(COp[ScalarVariable]): # See doc in instance of this Op after the class definition. 
__props__ = () @@ -620,7 +621,7 @@ def c_code_cache_version(self): """ -class MakeList(Op): +class MakeList(Op[TypedListVariable]): __props__ = () def make_node(self, a): diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 132b92bff2..07e90819b7 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -4,10 +4,10 @@ from pytensor.graph import Apply, Op from pytensor.graph.basic import Variable from pytensor.tensor.type import TensorType -from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor +from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor -class XOp(Op): +class XOp(Op[XTensorVariable]): """A base class for XOps that shouldn't be materialized""" def perform(self, node, inputs, outputs): From a2bbd09d8112311d82dc227f85b8bf95cf95db74 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Mon, 30 Mar 2026 12:28:34 +0200 Subject: [PATCH 2/4] Add signature for ops with more than 1 output --- pytensor/compile/builders.py | 24 +++++++---- pytensor/compile/ops.py | 4 +- pytensor/graph/basic.py | 36 ++++++++++------ pytensor/graph/op.py | 43 +++++++++++-------- pytensor/link/c/op.py | 9 ++-- pytensor/raise_op.py | 2 +- pytensor/scalar/basic.py | 2 +- pytensor/scan/basic.py | 5 +-- pytensor/scan/op.py | 2 +- pytensor/scan/rewriting.py | 10 +++-- pytensor/sparse/basic.py | 41 +++++++++--------- pytensor/sparse/math.py | 35 ++++++++------- pytensor/sparse/rewriting.py | 6 +-- pytensor/sparse/variable.py | 3 +- .../tensor/_linalg/solve/linear_control.py | 2 +- pytensor/tensor/_linalg/solve/tridiagonal.py | 4 +- pytensor/tensor/basic.py | 26 +++++------ pytensor/tensor/blas.py | 8 ++-- pytensor/tensor/blas_c.py | 2 +- pytensor/tensor/blockwise.py | 12 +++--- pytensor/tensor/einsum.py | 2 +- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/extra_ops.py | 20 ++++----- pytensor/tensor/fft.py | 4 +- pytensor/tensor/fourier.py | 2 +- pytensor/tensor/math.py | 4 +- pytensor/tensor/nlinalg.py | 18 
++++---- pytensor/tensor/optimize.py | 2 +- pytensor/tensor/pad.py | 2 +- pytensor/tensor/random/op.py | 25 ++++++++--- pytensor/tensor/reshape.py | 4 +- pytensor/tensor/shape.py | 8 ++-- pytensor/tensor/signal/conv.py | 4 +- pytensor/tensor/slinalg.py | 24 +++++------ pytensor/tensor/sort.py | 4 +- pytensor/tensor/special.py | 6 +-- pytensor/tensor/subtensor.py | 12 +++--- pytensor/tensor/variable.py | 2 +- pytensor/typed_list/basic.py | 20 ++++----- pytensor/xtensor/basic.py | 2 +- 40 files changed, 241 insertions(+), 202 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 9f64b45e26..7883bd6256 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -19,7 +19,13 @@ ) from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType -from pytensor.graph.op import HasInnerGraph, Op, OpOutputType, io_connection_pattern +from pytensor.graph.op import ( + HasInnerGraph, + Op, + OpDefaultOutputType, + OpOutputsType, + io_connection_pattern, +) from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError @@ -154,7 +160,7 @@ def construct_nominal_fgraph( return fgraph, implicit_shared_inputs, update_d, update_expr -class OpFromGraph(Op, HasInnerGraph, Generic[OpOutputType]): +class OpFromGraph(Op, HasInnerGraph, Generic[OpOutputsType, OpDefaultOutputType]): r""" This creates an `Op` from inputs and outputs lists of variables. 
The signature is similar to :func:`pytensor.function ` @@ -253,7 +259,7 @@ def rescale_dy(inps, outputs, out_grads): def __init__( self, inputs: list[Variable], - outputs: list[OpOutputType], + outputs: OpOutputsType, *, inline: bool = False, lop_overrides: Union[Callable, "OpFromGraph", None] = None, @@ -716,9 +722,9 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]: def L_op( self, inputs: Sequence[Variable], - outputs: Sequence[OpOutputType], - output_grads: Sequence[OpOutputType], - ) -> list[OpOutputType]: + outputs: Sequence[OpDefaultOutputType], + output_grads: Sequence[OpDefaultOutputType], + ) -> list[OpDefaultOutputType]: disconnected_output_grads = tuple( isinstance(og.type, DisconnectedType) for og in output_grads ) @@ -728,12 +734,12 @@ def L_op( def R_op( self, inputs: Sequence[Variable], - eval_points: OpOutputType | list[OpOutputType], - ) -> list[OpOutputType]: + eval_points: OpDefaultOutputType | list[OpDefaultOutputType], + ) -> list[OpDefaultOutputType]: rop_op = self._build_and_cache_rop_op() return rop_op(*inputs, *eval_points, return_list=True) - def __call__(self, *inputs, **kwargs) -> OpOutputType | list[OpOutputType]: + def __call__(self, *inputs, **kwargs) -> OpOutputsType | OpOutputsType: # The user interface doesn't expect the shared variable inputs of the # inner-graph, but, since `Op.make_node` does (and `Op.__call__` # dispatches to `Op.make_node`), we need to compensate here diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 51398cd7d8..c4f3bf858c 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -9,7 +9,7 @@ import pickle import warnings -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op from pytensor.link.c.op import COp from pytensor.link.c.type import CType @@ -221,7 +221,7 @@ def load_back(mod, name): return obj -class FromFunctionOp(Op): +class FromFunctionOp(Op[tuple[Variable, ...], Variable]): """ 
Build a basic PyTensor Op around a function. diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 5d1d9f72be..210b07aa5e 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -37,7 +37,8 @@ OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True) _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) -ApplyOutType = TypeVar("ApplyOutType", bound="Variable") +ApplyOutputsType = TypeVar("ApplyOutputsType", bound=tuple["Variable", ...]) +ApplyDefaultOutputType = TypeVar("ApplyDefaultOutputType", bound="Variable") _MOVED_FUNCTIONS = { "walk", @@ -107,7 +108,7 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) -class Apply(Node, Generic[OpType, ApplyOutType]): +class Apply(Node, Generic[OpType, ApplyOutputsType, ApplyDefaultOutputType]): """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -146,7 +147,7 @@ def __init__( self, op: OpType, inputs: Sequence["Variable"], - outputs: Sequence[ApplyOutType], + outputs: ApplyOutputsType, ): if not isinstance(inputs, Sequence): raise TypeError("The inputs of an Apply must be a sequence type") @@ -166,7 +167,8 @@ def __init__( raise TypeError( f"The 'inputs' argument to Apply must contain Variable instances, not {input}" ) - self.outputs: list[ApplyOutType] = [] + self.outputs: ApplyOutputsType + _outputs: list[Any] = [] # filter outputs to make sure each element is a Variable for i, output in enumerate(outputs): if isinstance(output, Variable): @@ -177,11 +179,17 @@ def __init__( raise ValueError( "All output variables passed to Apply must belong to it." ) - self.outputs.append(output) + _outputs.append(output) else: raise TypeError( f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}" ) + # The _outputs will be a list of Variables and we cannot type hint each separately. 
+ # We could use cast(ApplyOutputsType, tuple(_outputs)) to attach the type hint + # information for each output entry, but that could introduce a call overhead + # to cast. + # Instead, we will just ignore the type in this assignment + self.outputs = tuple(_outputs) # type: ignore def __getstate__(self): d = self.__dict__ @@ -193,7 +201,7 @@ def __getstate__(self): d["tag"] = t return d - def default_output(self) -> ApplyOutType: + def default_output(self) -> ApplyDefaultOutputType: """ Returns the default output for this node. @@ -211,12 +219,12 @@ def default_output(self) -> ApplyOutType: do = getattr(self.op, "default_output", None) if do is None: if len(self.outputs) == 1: - return self.outputs[0] + return cast(ApplyDefaultOutputType, self.outputs[0]) else: raise ValueError( f"Multi-output Op {self.op} default_output not specified" ) - return cast(ApplyOutType, self.outputs[do]) + return cast(ApplyDefaultOutputType, self.outputs[do]) def __str__(self): # FIXME: The called function is too complicated for this simple use case. @@ -225,7 +233,9 @@ def __str__(self): def __repr__(self): return str(self) - def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType, ApplyOutType]": + def clone( + self, clone_inner_graph: bool = False + ) -> "Apply[OpType, ApplyOutputsType, ApplyDefaultOutputType]": r"""Clone this `Apply` instance. Parameters @@ -250,14 +260,16 @@ def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType, ApplyOutType] new_op = new_op.clone() # type: ignore cp = self.__class__( - new_op, self.inputs, [output.clone() for output in self.outputs] + new_op, + self.inputs, + cast(ApplyOutputsType, tuple([output.clone() for output in self.outputs])), ) cp.tag = copy(self.tag) return cp def clone_with_new_inputs( self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False - ) -> "Apply[OpType, ApplyOutType]": + ) -> "Apply[OpType, ApplyOutputsType, ApplyDefaultOutputType]": r"""Duplicate this `Apply` instance in a new graph. 
Parameters @@ -325,7 +337,7 @@ def get_parents(self): return list(self.inputs) @property - def out(self) -> ApplyOutType: + def out(self) -> ApplyDefaultOutputType: """An alias for `self.default_output`""" return self.default_output() diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 354506674f..e8140fb82f 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -50,10 +50,11 @@ def is_thunk_type(thunk: ThunkCallableType) -> ThunkType: return res -OpOutputType = TypeVar("OpOutputType", bound=Variable) +OpOutputsType = TypeVar("OpOutputsType", bound=tuple[Variable, ...]) +OpDefaultOutputType = TypeVar("OpDefaultOutputType", bound=Variable) -class Op(MetaObject, Generic[OpOutputType]): +class Op(MetaObject, Generic[OpOutputsType, OpDefaultOutputType]): """A class that models and constructs operations in a graph. A `Op` instance has several responsibilities: @@ -124,7 +125,9 @@ class Op(MetaObject, Generic[OpOutputType]): as nodes with these Ops must be rebuilt even if the input types haven't changed. """ - def make_node(self, *inputs: Variable) -> Apply[Self, OpOutputType]: + def make_node( + self, *inputs: Variable + ) -> Apply[Self, OpOutputsType, OpDefaultOutputType]: """Construct an `Apply` node that represent the application of this operation to the given inputs. This must be implemented by sub-classes. @@ -164,11 +167,13 @@ def make_node(self, *inputs: Variable) -> Apply[Self, OpOutputType]: if inp != out ) ) - return Apply(self, inputs, [cast(OpOutputType, o()) for o in self.otypes]) + return Apply( + self, inputs, cast(OpOutputsType, tuple([o() for o in self.otypes])) + ) def __call__( - self, *inputs: Any, name=None, return_list=False, **kwargs - ) -> OpOutputType | list[OpOutputType]: + self, *inputs: Any, name=None, return_list: bool = False, **kwargs + ) -> OpOutputsType | OpDefaultOutputType | tuple[OpDefaultOutputType]: r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs. 
This method is just a wrapper around :meth:`Op.make_node`. @@ -223,15 +228,15 @@ def __call__( if self.default_output is not None: rval = node.outputs[self.default_output] if return_list: - return [rval] - return rval + return cast(tuple[OpDefaultOutputType], (rval,)) + return cast(OpDefaultOutputType, rval) else: if return_list: - return list(node.outputs) + return cast(OpOutputsType, tuple(node.outputs)) elif len(node.outputs) == 1: - return node.outputs[0] + return cast(OpDefaultOutputType, node.outputs[0]) else: - return node.outputs + return cast(OpOutputsType, tuple(node.outputs)) def __ne__(self, other: Any) -> bool: return not (self == other) @@ -241,8 +246,8 @@ def __ne__(self, other: Any) -> bool: add_tag_trace = staticmethod(add_tag_trace) def grad( - self, inputs: Sequence[Variable], output_grads: Sequence[OpOutputType] - ) -> list[OpOutputType]: + self, inputs: Sequence[Variable], output_grads: Sequence[OpDefaultOutputType] + ) -> list[OpDefaultOutputType]: r"""Construct a graph for the gradient with respect to each input variable. Each returned `Variable` represents the gradient with respect to that @@ -288,9 +293,9 @@ def grad( def L_op( self, inputs: Sequence[Variable], - outputs: Sequence[OpOutputType], - output_grads: Sequence[OpOutputType], - ) -> list[OpOutputType]: + outputs: Sequence[OpDefaultOutputType], + output_grads: Sequence[OpDefaultOutputType], + ) -> list[OpDefaultOutputType]: r"""Construct a graph for the L-operator. The L-operator computes a row vector times the Jacobian. @@ -315,8 +320,10 @@ def L_op( return self.grad(inputs, output_grads) def R_op( - self, inputs: list[Variable], eval_points: OpOutputType | list[OpOutputType] - ) -> list[OpOutputType]: + self, + inputs: list[Variable], + eval_points: OpDefaultOutputType | list[OpDefaultOutputType], + ) -> list[OpDefaultOutputType]: r"""Construct a graph for the R-operator. This method is primarily used by `Rop`. 
diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 71374bd777..332592a310 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -13,7 +13,8 @@ from pytensor.graph.op import ( ComputeMapType, Op, - OpOutputType, + OpDefaultOutputType, + OpOutputsType, StorageMapType, ThunkType, ) @@ -38,7 +39,7 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType: return res -class COp(Op, CLinkerOp, Generic[OpOutputType]): +class COp(Op, CLinkerOp, Generic[OpOutputsType, OpDefaultOutputType]): """An `Op` with a C implementation.""" def make_c_thunk( @@ -139,7 +140,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): ) -class OpenMPOp(COp, Generic[OpOutputType]): +class OpenMPOp(COp, Generic[OpOutputsType, OpDefaultOutputType]): r"""Base class for `Op`\s using OpenMP. This `Op` will check that the compiler support correctly OpenMP code. @@ -260,7 +261,7 @@ def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]: return define_all, undef_all -class ExternalCOp(COp, Generic[OpOutputType]): +class ExternalCOp(COp, Generic[OpOutputsType, OpDefaultOutputType]): """Class for an `Op` with an external C implementation. One can inherit from this class, provide its constructor with a path to diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index 01765b2742..9d305bf6c5 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -24,7 +24,7 @@ def __hash__(self): exception_type = ExceptionType() -class CheckAndRaise(COp[TensorVariable]): +class CheckAndRaise(COp[tuple[TensorVariable], TensorVariable]): """An `Op` that checks conditions and raises an exception if they fail. 
This `Op` returns its "value" argument if its condition arguments are all diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 7431171d3e..c5b0aeeb7e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1182,7 +1182,7 @@ def _cast_to_promised_scalar_dtype(x, dtype): return getattr(np, dtype)(x) -class ScalarOp(COp[ScalarVariable]): +class ScalarOp(COp[tuple[ScalarVariable], ScalarVariable]): nin = -1 nout = 1 diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 6370067bca..a658a0080d 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -1106,9 +1106,8 @@ def wrap_into_list(x): # to make sure all inputs are tensors. pass scan_inputs += [arg] - scan_outs = local_op(*scan_inputs) - if not isinstance(scan_outs, list | tuple): - scan_outs = [scan_outs] + _scan_outs = local_op(*scan_inputs) + scan_outs = [_scan_outs] if not isinstance(_scan_outs, list | tuple) else _scan_outs ## # Step 9. Figure out which outs are update rules for shared variables # and so on ... diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 60837ae4e3..57a597d994 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -710,7 +710,7 @@ def validate_inner_graph(self): ) -class Scan(Op[Variable], ScanMethodsMixin, HasInnerGraph): +class Scan(Op[tuple[Variable, ...], Variable], ScanMethodsMixin, HasInnerGraph): r"""An `Op` implementing `for` and `while` loops. 
This `Op` has an "inner-graph" that represents the steps performed during diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index bb1f3a40f9..72e7db6821 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -775,7 +775,7 @@ def add_nitsot_outputs( # Create the Apply node for the scan op new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True) - assert isinstance(new_scan_outs, list) + assert isinstance(new_scan_outs, list | tuple) new_scan_node = new_scan_outs[0].owner assert new_scan_node is not None @@ -948,9 +948,9 @@ def add_requirements(self, fgraph): def attempt_scan_inplace( self, fgraph: FunctionGraph, - node: Apply[Scan, Variable], + node: Apply[Scan, tuple[Variable, ...], Variable], output_indices: list[int], - ) -> Apply | None: + ) -> Apply[Scan, tuple[Variable, ...], Variable] | None: """Attempt to replace a `Scan` node by one which computes the specified outputs inplace. Parameters @@ -1015,7 +1015,9 @@ def attempt_scan_inplace( k: v for k, v in new_op.view_map.items() if k not in destroy_map } - new_node: Apply[Scan, Variable] = new_op.make_node(*inputs) + new_node: Apply[Scan, tuple[Variable, ...], Variable] = new_op.make_node( + *inputs + ) try: fgraph.replace_all_validate_remove( # type: ignore diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 25bddbf89c..a079883225 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -30,7 +30,6 @@ from pytensor.tensor.type import TensorType, ivector, scalar, tensor, vector from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes -from pytensor.tensor.variable import TensorVariable sparse_formats = ["csc", "csr"] @@ -245,7 +244,7 @@ def bsr_matrix(name=None, dtype=None): discrete_dtypes = int_dtypes + uint_dtypes -class CSMProperties(Op[TensorVariable]): +class CSMProperties(Op[tuple["SparseVariable"], 
"SparseVariable"]): """Create arrays containing all the properties of a given sparse matrix. More specifically, this `Op` extracts the ``.data``, ``.indices``, @@ -361,7 +360,7 @@ def csm_shape(csm): return csm_properties(csm)[3] -class CSM(Op[TensorVariable]): +class CSM(Op[tuple["SparseVariable"], "SparseVariable"]): """Construct a CSM matrix from constituent parts. Notes @@ -505,7 +504,7 @@ def infer_shape(self, fgraph, node, shapes): CSR = CSM("csr") -class CSMGrad(Op[TensorVariable]): +class CSMGrad(Op[tuple["SparseVariable"], "SparseVariable"]): """Compute the gradient of a CSM. Note @@ -592,7 +591,7 @@ def infer_shape(self, fgraph, node, shapes): csm_grad = CSMGrad -class Cast(Op[TensorVariable]): +class Cast(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = ("out_type",) def __init__(self, out_type): @@ -670,7 +669,7 @@ def cast(variable, dtype): return Cast(dtype)(variable) -class DenseFromSparse(Op[TensorVariable]): +class DenseFromSparse(Op[tuple["SparseVariable"], "SparseVariable"]): """Convert a sparse matrix to a dense one. 
Notes @@ -750,7 +749,7 @@ def infer_shape(self, fgraph, node, shapes): dense_from_sparse = DenseFromSparse() -class SparseFromDense(Op[TensorVariable]): +class SparseFromDense(Op[tuple["SparseVariable"], "SparseVariable"]): """Convert a dense matrix to a sparse matrix.""" __props__ = () @@ -816,7 +815,7 @@ def infer_shape(self, fgraph, node, shapes): csc_from_dense = SparseFromDense("csc") -class GetItemList(Op[TensorVariable]): +class GetItemList(Op[tuple["SparseVariable"], "SparseVariable"]): """Select row of sparse matrix, returning them as a new sparse matrix.""" __props__ = () @@ -863,7 +862,7 @@ def grad(self, inputs, g_outputs): get_item_list = GetItemList() -class GetItemListGrad(Op[TensorVariable]): +class GetItemListGrad(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = () def infer_shape(self, fgraph, node, shapes): @@ -906,7 +905,7 @@ def perform(self, node, inp, outputs): get_item_list_grad = GetItemListGrad() -class GetItem2Lists(Op[TensorVariable]): +class GetItem2Lists(Op[tuple["SparseVariable"], "SparseVariable"]): """Select elements of sparse matrix, returning them in a vector.""" __props__ = () @@ -956,7 +955,7 @@ def grad(self, inputs, g_outputs): get_item_2lists = GetItem2Lists() -class GetItem2ListsGrad(Op[TensorVariable]): +class GetItem2ListsGrad(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = () def infer_shape(self, fgraph, node, shapes): @@ -997,7 +996,7 @@ def perform(self, node, inp, outputs): get_item_2lists_grad = GetItem2ListsGrad() -class GetItem2d(Op[TensorVariable]): +class GetItem2d(Op[tuple["SparseVariable"], "SparseVariable"]): """Implement a subtensor of sparse variable, returning a sparse matrix. 
If you want to take only one element of a sparse matrix see @@ -1126,7 +1125,7 @@ def perform(self, node, inputs, outputs): get_item_2d = GetItem2d() -class GetItemScalar(Op[TensorVariable]): +class GetItemScalar(Op[tuple["SparseVariable"], "SparseVariable"]): """Subtensor of a sparse variable that takes two scalars as index and returns a scalar. If you want to take a slice of a sparse matrix see `GetItem2d` that returns a @@ -1187,7 +1186,7 @@ def perform(self, node, inputs, outputs): get_item_scalar = GetItemScalar() -class Transpose(Op[TensorVariable]): +class Transpose(Op[tuple["SparseVariable"], "SparseVariable"]): """Transpose of a sparse matrix. Notes @@ -1247,7 +1246,7 @@ def infer_shape(self, fgraph, node, shapes): transpose = Transpose() -class ColScaleCSC(Op[TensorVariable]): +class ColScaleCSC(Op[tuple["SparseVariable"], "SparseVariable"]): # Scale each columns of a sparse matrix by the corresponding # element of a dense vector @@ -1293,7 +1292,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -class RowScaleCSC(Op[TensorVariable]): +class RowScaleCSC(Op[tuple["SparseVariable"], "SparseVariable"]): # Scale each row of a sparse matrix by the corresponding element of # a dense vector @@ -1401,7 +1400,7 @@ def row_scale(x, s): return col_scale(x.T, s).T -class Diag(Op[TensorVariable]): +class Diag(Op[tuple["SparseVariable"], "SparseVariable"]): """Extract the diagonal of a square sparse matrix as a dense vector. Notes @@ -1455,7 +1454,7 @@ def square_diagonal(diag): return CSC(data, indices, indptr, ptb.as_tensor((n, n))) -class EnsureSortedIndices(Op[TensorVariable]): +class EnsureSortedIndices(Op[tuple["SparseVariable"], "SparseVariable"]): """Re-sort indices of a sparse matrix. CSR column indices are not necessarily sorted. 
Likewise @@ -1540,7 +1539,7 @@ def clean(x): return ensure_sorted_indices(remove0(x)) -class Stack(Op[TensorVariable]): +class Stack(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = ("format", "dtype") def __init__(self, format=None, dtype=None): @@ -1751,7 +1750,7 @@ def vstack(blocks, format=None, dtype=None): return VStack(format=format, dtype=dtype)(*blocks) -class Remove0(Op[TensorVariable]): +class Remove0(Op[tuple["SparseVariable"], "SparseVariable"]): """Remove explicit zeros from a sparse matrix. Notes @@ -1808,7 +1807,7 @@ def infer_shape(self, fgraph, node, i0_shapes): remove0 = Remove0() -class ConstructSparseFromList(Op[TensorVariable]): +class ConstructSparseFromList(Op[tuple["SparseVariable"], "SparseVariable"]): """Constructs a sparse matrix out of a list of 2-D matrix rows. Notes diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index fa942b56e6..d64a9c523a 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -15,7 +15,6 @@ from pytensor.sparse.type import SparseTensorType from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor -from pytensor.tensor.variable import TensorVariable def structured_elemwise(tensor_op): @@ -255,7 +254,7 @@ def conjugate(x): structured_conjugate = conj = conjugate -class SpSum(Op[TensorVariable]): +class SpSum(Op[tuple["SparseVariable"], "SparseVariable"]): """ WARNING: judgement call... @@ -375,7 +374,7 @@ def sp_sum(x, axis=None, sparse_grad=False): return SpSum(axis, sparse_grad)(x) -class AddSS(Op[TensorVariable]): +class AddSS(Op[tuple["SparseVariable"], "SparseVariable"]): # add(sparse, sparse). # see the doc of add() for more detail. 
__props__ = () @@ -412,7 +411,7 @@ def infer_shape(self, fgraph, node, shapes): add_s_s = AddSS() -class AddSSData(Op[TensorVariable]): +class AddSSData(Op[tuple["SparseVariable"], "SparseVariable"]): """Add two sparse matrices assuming they have the same sparsity pattern. Notes @@ -473,7 +472,7 @@ def infer_shape(self, fgraph, node, ins_shapes): add_s_s_data = AddSSData() -class AddSD(Op[TensorVariable]): +class AddSD(Op[tuple["SparseVariable"], "SparseVariable"]): # add(sparse, sparse). # see the doc of add() for more detail. __props__ = () @@ -515,7 +514,7 @@ def infer_shape(self, fgraph, node, shapes): add_s_d = AddSD() -class StructuredAddSV(Op[TensorVariable]): +class StructuredAddSV(Op[tuple["SparseVariable"], "SparseVariable"]): """Structured addition of a sparse matrix and a dense vector. The elements of the vector are only added to the corresponding @@ -667,7 +666,7 @@ def sub(x, y): sub.__doc__ = subtract.__doc__ -class SparseSparseMultiply(Op[TensorVariable]): +class SparseSparseMultiply(Op[tuple["SparseVariable"], "SparseVariable"]): # mul(sparse, sparse) # See the doc of mul() for more detail __props__ = () @@ -705,7 +704,7 @@ def infer_shape(self, fgraph, node, shapes): mul_s_s = SparseSparseMultiply() -class SparseDenseMultiply(Op[TensorVariable]): +class SparseDenseMultiply(Op[tuple["SparseVariable"], "SparseVariable"]): # mul(sparse, dense) # See the doc of mul() for more detail __props__ = () @@ -794,7 +793,7 @@ def infer_shape(self, fgraph, node, shapes): mul_s_d = SparseDenseMultiply() -class SparseDenseVectorMultiply(Op[TensorVariable]): +class SparseDenseVectorMultiply(Op[tuple["SparseVariable"], "SparseVariable"]): """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. 
Notes @@ -942,7 +941,7 @@ def mul(x, y): mul.__doc__ = multiply.__doc__ -class __ComparisonOpSS(Op[TensorVariable]): +class __ComparisonOpSS(Op[tuple["SparseVariable"], "SparseVariable"]): """ Used as a superclass for all comparisons between two sparses matrices. @@ -992,7 +991,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -class __ComparisonOpSD(Op[TensorVariable]): +class __ComparisonOpSD(Op[tuple["SparseVariable"], "SparseVariable"]): """ Used as a superclass for all comparisons between sparse and dense matrix. @@ -1196,7 +1195,7 @@ def comparison(self, x, y): ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d) -class TrueDot(Op[TensorVariable]): +class TrueDot(Op[tuple["SparseVariable"], "SparseVariable"]): # TODO # Simplify code by splitting into DotSS and DotSD. @@ -1336,7 +1335,7 @@ def true_dot(x, y, grad_preserves_dense=True): return psb.transpose(TrueDot(grad_preserves_dense)(y.T, x.T)) -class StructuredDot(Op[TensorVariable]): +class StructuredDot(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = () def make_node(self, a, b): @@ -1467,7 +1466,7 @@ def structured_dot(x, y): return _structured_dot(y.T, x.T).T -class StructuredDotGradCSC(COp[TensorVariable]): +class StructuredDotGradCSC(COp[tuple["SparseVariable"], "SparseVariable"]): # Op that produces the grad of StructuredDot. # :param a_indices: Matrix indices @@ -1602,7 +1601,7 @@ def infer_shape(self, fgraph, node, shapes): sdg_csc = StructuredDotGradCSC() -class StructuredDotGradCSR(COp[TensorVariable]): +class StructuredDotGradCSR(COp[tuple["SparseVariable"], "SparseVariable"]): # Op that produces the grad of StructuredDot. # :param a_indices: Matrix indices @@ -1759,7 +1758,7 @@ def structured_dot_grad(sparse_A, dense_B, ga): raise NotImplementedError() -class SamplingDot(Op[TensorVariable]): +class SamplingDot(Op[tuple["SparseVariable"], "SparseVariable"]): """Compute the dot product ``dot(x, y.T) = z`` for only a subset of `z`. 
This is equivalent to ``p * (x . y.T)`` where ``*`` is the element-wise @@ -1835,7 +1834,7 @@ def infer_shape(self, fgraph, node, ins_shapes): sampling_dot = SamplingDot() -class Dot(Op[TensorVariable]): +class Dot(Op[tuple["SparseVariable"], "SparseVariable"]): __props__ = () def __str__(self): @@ -1986,7 +1985,7 @@ def dot(x, y): return _dot(x, y) -class Usmm(Op[TensorVariable]): +class Usmm(Op[tuple["SparseVariable"], "SparseVariable"]): """Computes the dense matrix resulting from ``alpha * x @ y + z``. Notes diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index d9ffed5f3e..9603c2b9aa 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -15,13 +15,13 @@ from pytensor.link.c.op import COp, _NoPythonCOp from pytensor.sparse.basic import csm_properties from pytensor.sparse.math import usmm +from pytensor.sparse.variable import SparseVariable from pytensor.tensor import blas from pytensor.tensor.basic import as_tensor_variable, cast from pytensor.tensor.math import mul, neg, sub from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.shape import shape, specify_shape from pytensor.tensor.type import TensorType, tensor -from pytensor.tensor.variable import TensorVariable _is_sparse_variable = sparse._is_sparse_variable @@ -242,7 +242,7 @@ def local_addsd_ccode(fgraph, node): ) -class StructuredDotCSC(COp[TensorVariable]): +class StructuredDotCSC(COp[tuple[SparseVariable], SparseVariable]): """ Structured Dot CSC is like `dot`, except that only the gradient wrt non-zero elements of a sparse matrix are calculated and propagated. 
@@ -439,7 +439,7 @@ def c_code_cache_version(self): sd_csc = StructuredDotCSC() -class StructuredDotCSR(COp[TensorVariable]): +class StructuredDotCSR(COp[tuple[SparseVariable], SparseVariable]): """ Structured Dot CSR is like dot, except that only the gradient wrt non-zero elements of a sparse matrix diff --git a/pytensor/sparse/variable.py b/pytensor/sparse/variable.py index 7fbe9eb366..e4fa5a0e17 100644 --- a/pytensor/sparse/variable.py +++ b/pytensor/sparse/variable.py @@ -72,7 +72,8 @@ def to_dense(self, *args, **kwargs): class _sparse_py_operators: T = property( - lambda self: transpose(self), doc="Return aliased transpose of self (read-only)" + lambda self: transpose(self), # type: ignore + doc="Return aliased transpose of self (read-only)", ) def astype(self, dtype): diff --git a/pytensor/tensor/_linalg/solve/linear_control.py b/pytensor/tensor/_linalg/solve/linear_control.py index f797a1d90f..94f419007c 100644 --- a/pytensor/tensor/_linalg/solve/linear_control.py +++ b/pytensor/tensor/_linalg/solve/linear_control.py @@ -20,7 +20,7 @@ from pytensor.tensor.variable import TensorVariable -class TRSYL(Op[TensorVariable]): +class TRSYL(Op[tuple[TensorVariable], TensorVariable]): """ Wrapper around LAPACK's `trsyl` function to solve the Sylvester equation: diff --git a/pytensor/tensor/_linalg/solve/tridiagonal.py b/pytensor/tensor/_linalg/solve/tridiagonal.py index d0cfe43073..a1ddc3f9ae 100644 --- a/pytensor/tensor/_linalg/solve/tridiagonal.py +++ b/pytensor/tensor/_linalg/solve/tridiagonal.py @@ -15,7 +15,7 @@ from pytensor.tensor import TensorLike -class LUFactorTridiagonal(Op[TensorVariable]): +class LUFactorTridiagonal(Op[tuple[TensorVariable], TensorVariable]): """Compute LU factorization of a tridiagonal matrix (lapack gttrf)""" __props__ = ( @@ -89,7 +89,7 @@ def perform(self, node, inputs, output_storage): output_storage[4][0] = ipiv -class SolveLUFactorTridiagonal(Op[TensorVariable]): +class SolveLUFactorTridiagonal(Op[tuple[TensorVariable], 
TensorVariable]): """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs).""" __props__ = ("b_ndim", "overwrite_b", "transposed") diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index fd245ddb30..145a277f9c 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -616,7 +616,7 @@ def get_scalar_constant_value( ) -class TensorFromScalar(COp[TensorVariable]): +class TensorFromScalar(COp[tuple[TensorVariable], TensorVariable]): __props__ = () def make_node(self, s): @@ -949,7 +949,7 @@ def ones(shape, dtype=None) -> TensorVariable: return alloc(np.array(1, dtype=dtype), *shape) -class Nonzero(Op[TensorVariable]): +class Nonzero(Op[tuple[TensorVariable], TensorVariable]): """ Return the indices of the elements that are non-zero. @@ -1340,7 +1340,7 @@ def triu_indices_from( return triu_indices(a.shape[0], k=k, m=a.shape[1]) -class Eye(Op[TensorVariable]): +class Eye(Op[tuple[TensorVariable], TensorVariable]): _output_type_depends_on_input_value = True __props__ = ("dtype",) @@ -1534,7 +1534,7 @@ def check_type(s): return sh, static_shape -class Alloc(COp[TensorVariable]): +class Alloc(COp[tuple[TensorVariable], TensorVariable]): """Create a `TensorVariable` from an initial value and a desired shape. Usage: @@ -1879,7 +1879,7 @@ def full_like( return fill(a, fill_value) -class MakeVector(COp[TensorVariable]): +class MakeVector(COp[tuple[TensorVariable], TensorVariable]): """Concatenate a number of scalars together into a vector. This is a simple version of stack() that introduces far less cruft @@ -2060,7 +2060,7 @@ def register_transfer(fn): identity = tensor_copy -class Default(Op[TensorVariable]): +class Default(Op[tuple[TensorVariable], TensorVariable]): """ Takes an input x and a default value. 
@@ -2191,7 +2191,7 @@ def split(x, splits_size, *, n_splits=None, axis=0): return Split(n_splits)(x, axis, splits_size) -class Split(COp[TensorVariable]): +class Split(COp[tuple[TensorVariable], TensorVariable]): """Partition a `TensorVariable` along some axis. Examples @@ -2428,7 +2428,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ -class Join(COp[TensorVariable]): +class Join(COp[tuple[TensorVariable], TensorVariable]): r""" Concatenate multiple `TensorVariable`\s along some axis. @@ -3249,7 +3249,7 @@ def tile( return A_replicated.reshape(tiled_shape) -class ARange(COp[TensorVariable]): +class ARange(COp[tuple[TensorVariable], TensorVariable]): """Create an array containing evenly spaced values within a given interval. Parameters and behaviour are the same as numpy.arange(). @@ -3536,7 +3536,7 @@ def __getitem__(self, *args): ogrid = _nd_grid(sparse=True) -class PermuteRowElements(Op[TensorVariable]): +class PermuteRowElements(Op[tuple[TensorVariable], TensorVariable]): """Permute the elements of each row (inner-most dim) of a tensor. A permutation will be applied to every row (vector) of the input tensor x. @@ -3746,7 +3746,7 @@ def inverse_permutation(perm): ) -class ExtractDiag(COp[TensorVariable]): +class ExtractDiag(COp[tuple[TensorVariable], TensorVariable]): """ Return specified diagonals. 
@@ -4261,7 +4261,7 @@ def choose(a, choices, mode="raise"): return Choose(mode)(a, choices) -class Choose(Op[TensorVariable]): +class Choose(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("mode",) def __init__(self, mode): @@ -4326,7 +4326,7 @@ def perform(self, node, inputs, outputs): z[0] = np.choose(a, choice, mode=self.mode) -class AllocEmpty(COp[TensorVariable]): +class AllocEmpty(COp[tuple[TensorVariable], TensorVariable]): """Implement Alloc on the cpu, but without initializing memory.""" _output_type_depends_on_input_value = True diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index daa5cdd16f..9fbb09be7a 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -151,7 +151,7 @@ def must_initialize_y_gemv(): must_initialize_y_gemv._result = None # type: ignore -class Gemv(Op[TensorVariable]): +class Gemv(Op[tuple[TensorVariable], TensorVariable]): """ expression is beta * y + alpha * A x @@ -257,7 +257,7 @@ def infer_shape(self, fgraph, node, input_shapes): gemv = gemv_no_inplace -class Ger(Op[TensorVariable]): +class Ger(Op[tuple[TensorVariable], TensorVariable]): """ BLAS defines general rank-1 update GER as A <- A + alpha x y' @@ -469,7 +469,7 @@ def _ldflags( return rval -class GemmRelated(COp[TensorVariable]): +class GemmRelated(COp[tuple[TensorVariable], TensorVariable]): """Base class for Gemm and Dot22. This class provides a kind of templated gemm Op. 
@@ -1305,7 +1305,7 @@ def c_code_cache_version(self): _dot22scalar = Dot22Scalar() -class BatchedDot(COp[TensorVariable]): +class BatchedDot(COp[tuple[TensorVariable], TensorVariable]): """ Computes a batch matrix-matrix dot with tensor3 variables diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index 548bc45623..52658a21b6 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -11,7 +11,7 @@ from pytensor.tensor.variable import TensorVariable -class BaseBLAS(COp[TensorVariable]): +class BaseBLAS(COp[tuple[TensorVariable], TensorVariable]): def c_libraries(self, **kwargs): return ldflags() diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 84ac7d6543..84c1cfb6b1 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Any, Literal, cast, overload +from typing import Any, Generic, Literal, cast, overload import numpy as np from numpy import broadcast_shapes, empty @@ -9,7 +9,7 @@ from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType -from pytensor.graph.op import Op +from pytensor.graph.op import Op, OpDefaultOutputType, OpOutputsType from pytensor.graph.replace import ( _vectorize_node, _vectorize_not_needed, @@ -151,7 +151,7 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ ) -class Blockwise(COp[TensorVariable]): +class Blockwise(COp[tuple[TensorVariable], TensorVariable]): """Generalizes a core `Op` to work with batched dimensions. TODO: C implementation? 
@@ -589,7 +589,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: _vectorize_node.register(Blockwise, _vectorize_not_needed) -class OpWithCoreShape(OpFromGraph): +class OpWithCoreShape(OpFromGraph, Generic[OpOutputsType, OpDefaultOutputType]): """Generalizes an `Op` to include core shape as an additional input.""" def __init__(self, *args, on_unused_input="ignore", **kwargs): @@ -600,7 +600,9 @@ def __init__(self, *args, on_unused_input="ignore", **kwargs): return super().__init__(*args, on_unused_input=on_unused_input, **kwargs) -class BlockwiseWithCoreShape(OpWithCoreShape): +class BlockwiseWithCoreShape( + OpWithCoreShape, Generic[OpOutputsType, OpDefaultOutputType] +): """Generalizes a Blockwise `Op` to include a core shape parameter.""" @property diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 88ef39042b..61a53e804d 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -35,7 +35,7 @@ CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str] -class Einsum(OpFromGraph[TensorVariable]): +class Einsum(OpFromGraph[tuple[TensorVariable], TensorVariable]): """ Wrapper Op for Einsum graphs diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index aa72145e62..5b02d4a949 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1134,7 +1134,7 @@ def outer(self, x, y): return self(x_, y_) -class CAReduce(COp[TensorVariable]): +class CAReduce(COp[tuple[TensorVariable], TensorVariable]): """Reduces a scalar operation along specified axes. The scalar op should be both commutative and associative. 
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 31c95e44a8..0d6c50dfdf 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -50,7 +50,7 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH -class CpuContiguous(COp[TensorVariable]): +class CpuContiguous(COp[tuple[TensorVariable], TensorVariable]): """ Check to see if the input is c-contiguous. @@ -109,7 +109,7 @@ def c_code_cache_version(self): cpu_contiguous = CpuContiguous() -class SearchsortedOp(COp[TensorVariable]): +class SearchsortedOp(COp[tuple[TensorVariable], TensorVariable]): """Wrapper for ``numpy.searchsorted``. For full documentation, see :func:`searchsorted`. @@ -284,7 +284,7 @@ def searchsorted(x, v, side="left", sorter=None): return SearchsortedOp(side=side)(x, v, sorter) -class CumOp(COp[TensorVariable]): +class CumOp(COp[tuple[TensorVariable], TensorVariable]): # See function cumsum/cumprod for docstring __props__ = ("axis", "mode") @@ -620,7 +620,7 @@ def compress(condition, x, axis=None): return _x.take(indices, axis=axis) -class Repeat(Op[TensorVariable]): +class Repeat(Op[tuple[TensorVariable], TensorVariable]): # See the repeat function for docstring __props__ = ("axis",) @@ -831,7 +831,7 @@ def repeat( return broadcast_a.reshape(repeat_shape) -class Bartlett(Op[TensorVariable]): +class Bartlett(Op[tuple[TensorVariable], TensorVariable]): # See function bartlett for docstring __props__ = () @@ -888,7 +888,7 @@ def bartlett(M): return bartlett_(M) -class FillDiagonal(Op[TensorVariable]): +class FillDiagonal(Op[tuple[TensorVariable], TensorVariable]): # See function fill_diagonal for docstring __props__ = () @@ -989,7 +989,7 @@ def fill_diagonal(a, val): return fill_diagonal_(a, val) -class FillDiagonalOffset(Op[TensorVariable]): +class FillDiagonalOffset(Op[tuple[TensorVariable], TensorVariable]): # See function fill_diagonal_offset for docstring __props__ = () @@ -1161,7 +1161,7 @@ def to_one_hot(y, nb_class, dtype=None): 
return ret -class Unique(Op[TensorVariable]): +class Unique(Op[tuple[TensorVariable], TensorVariable]): """ Wraps `numpy.unique`. @@ -1283,7 +1283,7 @@ def unique( return Unique(return_index, return_inverse, return_counts, axis)(ar) -class UnravelIndex(Op[TensorVariable]): +class UnravelIndex(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("order",) def __init__(self, order="C"): @@ -1360,7 +1360,7 @@ def unravel_index(indices, dims, order="C"): return tuple(res) -class RavelMultiIndex(Op[TensorVariable]): +class RavelMultiIndex(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("mode", "order") def __init__(self, mode="raise", order="C"): diff --git a/pytensor/tensor/fft.py b/pytensor/tensor/fft.py index fa18312ff8..7e4bddf691 100644 --- a/pytensor/tensor/fft.py +++ b/pytensor/tensor/fft.py @@ -10,7 +10,7 @@ from pytensor.tensor.variable import TensorVariable -class RFFTOp(Op[TensorVariable]): +class RFFTOp(Op[tuple[TensorVariable], TensorVariable]): __props__ = () def output_type(self, inp): @@ -70,7 +70,7 @@ def connection_pattern(self, node): rfft_op = RFFTOp() -class IRFFTOp(Op[TensorVariable]): +class IRFFTOp(Op[tuple[TensorVariable], TensorVariable]): __props__ = () def output_type(self, inp): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 34c185ee89..0e1e69f50c 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -19,7 +19,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable -class Fourier(Op[TensorVariable]): +class Fourier(Op[tuple[TensorVariable], TensorVariable]): """ WARNING: for officially supported FFTs, use pytensor.tensor.fft, which provides real-input FFTs. Gradients are supported. 
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 1baa312f01..d30ce63491 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -139,7 +139,7 @@ def _allclose(a, b, rtol=None, atol=None): return np.allclose(a, b, atol=atol_, rtol=rtol_) -class Argmax(COp[TensorVariable]): +class Argmax(COp[tuple[TensorVariable], TensorVariable]): """ Calculate the argmax over a given axis or over all axes. """ @@ -3019,7 +3019,7 @@ def clip(x, min, max): pprint.assign(pow, printing.OperatorPrinter("**", 1, "right")) -class Dot(Op[TensorVariable]): +class Dot(Op[tuple[TensorVariable], TensorVariable]): """ Computes the dot product of two matrices variables diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 7267fa9afb..b22207450f 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -29,7 +29,7 @@ from pytensor.tensor.variable import TensorVariable -class MatrixPinv(Op[TensorVariable]): +class MatrixPinv(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("hermitian",) gufunc_signature = "(m,n)->(n,m)" @@ -209,7 +209,7 @@ def trace(X): return diagonal(X).sum() -class Det(Op[TensorVariable]): +class Det(Op[tuple[TensorVariable], TensorVariable]): """ Matrix determinant. Input should be a square matrix. @@ -260,7 +260,7 @@ def __str__(self): det = Blockwise(Det()) -class SLogDet(Op[TensorVariable]): +class SLogDet(Op[tuple[TensorVariable], TensorVariable]): """ Compute the log determinant and its sign of the matrix. Input should be a square matrix. """ @@ -324,7 +324,7 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) -class Eig(Op[TensorVariable]): +class Eig(Op[tuple[TensorVariable], TensorVariable]): """ Compute the eigenvalues and right eigenvectors of a square array. 
""" @@ -466,7 +466,7 @@ def _zero_disconnected(outputs, grads): return l -class EighGrad(Op[TensorVariable]): +class EighGrad(Op[tuple[TensorVariable], TensorVariable]): """ Gradient of an eigensystem of a Hermitian matrix. @@ -536,7 +536,7 @@ def eigh(a, UPLO="L"): return Eigh(UPLO)(a) -class SVD(Op[TensorVariable]): +class SVD(Op[tuple[TensorVariable], TensorVariable]): """ Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V @@ -742,7 +742,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True): return Blockwise(SVD(full_matrices, compute_uv))(a) -class Lstsq(Op[TensorVariable]): +class Lstsq(Op[tuple[TensorVariable], TensorVariable]): __props__ = () def make_node(self, x, y, rcond): @@ -1013,7 +1013,7 @@ def norm( ) -class TensorInv(Op[TensorVariable]): +class TensorInv(Op[tuple[TensorVariable], TensorVariable]): """ Class wrapper for tensorinv() function; PyTensor utilization of numpy.linalg.tensorinv; @@ -1071,7 +1071,7 @@ def tensorinv(a, ind=2): return TensorInv(ind)(a) -class TensorSolve(Op[TensorVariable]): +class TensorSolve(Op[tuple[TensorVariable], TensorVariable]): """ PyTensor utilization of numpy.linalg.tensorsolve Class wrapper for tensorsolve function. 
diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 4760dc2bc7..6913e5424b 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -163,7 +163,7 @@ def _depends_only_on_constants(var: Variable) -> bool: ] -class ScipyWrapperOp(Op[TensorVariable], HasInnerGraph): +class ScipyWrapperOp(Op[tuple[TensorVariable], TensorVariable], HasInnerGraph): """Shared logic for scipy optimization ops""" def build_fn(self): diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py index 441e21c882..704b22b6f7 100644 --- a/pytensor/tensor/pad.py +++ b/pytensor/tensor/pad.py @@ -413,7 +413,7 @@ def _reflect_inner(i, x, x_flipped, padding_left): return x -class Pad(OpFromGraph[TensorVariable]): +class Pad(OpFromGraph[tuple[TensorVariable], TensorVariable]): """ Wrapper Op for Pad graphs """ diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index c5f3d60ff4..612a40a76f 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,14 +1,14 @@ import abc import warnings from collections.abc import Sequence -from typing import Any, cast +from typing import Any, Generic, Self, cast import numpy as np import pytensor from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations -from pytensor.graph.op import Op +from pytensor.graph.op import Op, OpDefaultOutputType, OpOutputsType from pytensor.graph.replace import _vectorize_node from pytensor.scalar import ScalarVariable from pytensor.tensor.basic import ( @@ -26,6 +26,7 @@ explicit_expand_dims, normalize_size_param, ) +from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import NoneConst, NoneTypeT @@ -33,7 +34,9 @@ from pytensor.tensor.variable import TensorVariable -class RNGConsumerOp(Op[TensorVariable]): +class RNGConsumerOp( + 
Op[tuple[RandomGeneratorSharedVariable, TensorVariable], TensorVariable] +): """Baseclass for Ops that consume RNGs.""" @abc.abstractmethod @@ -443,8 +446,8 @@ def R_op(self, inputs, eval_points): return [None for i in eval_points] -class AbstractRNGConstructor(Op): - def make_node(self, seed=None): +class AbstractRNGConstructor(Op, Generic[OpOutputsType, OpDefaultOutputType]): + def make_node(self, seed=None) -> Apply[Self, OpOutputsType, OpDefaultOutputType]: if seed is None: seed = NoneConst elif isinstance(seed, Variable) and isinstance(seed.type, NoneTypeT): @@ -462,7 +465,11 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed) -class DefaultGeneratorMakerOp(AbstractRNGConstructor): +class DefaultGeneratorMakerOp( + AbstractRNGConstructor[ + tuple[RandomGeneratorSharedVariable], RandomGeneratorSharedVariable + ] +): random_type = RandomGeneratorType() random_constructor = "default_rng" @@ -496,7 +503,11 @@ def vectorize_random_variable( return op.make_node(rng, size, *dist_params) -class RandomVariableWithCoreShape(OpWithCoreShape): +class RandomVariableWithCoreShape( + OpWithCoreShape[ + tuple[RandomGeneratorSharedVariable, TensorVariable], TensorVariable + ] +): """Generalizes a random variable `Op` to include a core shape parameter.""" @property diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index 0ee6446083..767b112656 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -23,7 +23,7 @@ ) -class JoinDims(Op[TensorVariable]): +class JoinDims(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("start_axis", "n_axes") view_map = {0: [0]} @@ -154,7 +154,7 @@ def join_dims( return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value] -class SplitDims(Op[TensorVariable]): +class SplitDims(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("axis",) view_map = {0: [0]} diff --git a/pytensor/tensor/shape.py 
b/pytensor/tensor/shape.py index e30841c2c8..b1e1767a1e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -50,7 +50,7 @@ def register_shape_c_code(type, code, version=()): Shape.c_code_and_version[type] = (code, version) -class Shape(COp[TensorVariable]): +class Shape(COp[tuple[TensorVariable], TensorVariable]): """ L{Op} to return the shape of a matrix. @@ -198,7 +198,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: return res -class Shape_i(COp[TensorVariable]): +class Shape_i(COp[tuple[TensorVariable], TensorVariable]): """ L{Op} to return the shape of a matrix. @@ -380,7 +380,7 @@ def register_shape_i_c_code(typ, code, check_input, version=()): Shape_i.c_code_and_version[typ] = (code, check_input, version) -class SpecifyShape(COp[TensorVariable]): +class SpecifyShape(COp[tuple[TensorVariable], TensorVariable]): """ L{Op} that puts into the graph the user-provided shape. @@ -625,7 +625,7 @@ def _vectorize_specify_shape(op, node, x, *shape): return specify_shape(x, new_shape).owner -class Reshape(COp[TensorVariable]): +class Reshape(COp[tuple[TensorVariable], TensorVariable]): """Perform a reshape operation of the input x to the new shape shp. The number of dimensions to which to reshape to (ndim) must be known at graph build time. 
diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index 062f0cdb47..6f719fd3a1 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -113,7 +113,7 @@ def L_op(self, inputs, outputs, output_grads): ] -class Convolve1d(AbstractConvolveNd, COp[TensorVariable]): # type: ignore[misc] +class Convolve1d(AbstractConvolveNd, COp[tuple[TensorVariable], TensorVariable]): # type: ignore[misc] __props__ = () ndim = 1 @@ -246,7 +246,7 @@ def convolve1d( return type_cast(TensorVariable, _blockwise_convolve_1d(in1, in2, full_mode)) -class Convolve2d(AbstractConvolveNd, Op[TensorVariable]): # type: ignore[misc] +class Convolve2d(AbstractConvolveNd, Op[tuple[TensorVariable], TensorVariable]): # type: ignore[misc] __props__ = ("method",) # type: ignore[assignment] ndim = 2 diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index da40ac99db..15a4f4e1f5 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class Cholesky(Op[TensorVariable]): +class Cholesky(Op[tuple[TensorVariable], TensorVariable]): # TODO: LAPACK wrapper with in-place behavior, for solve also __props__ = ("lower", "overwrite_a") @@ -210,7 +210,7 @@ def cholesky( return res -class SolveBase(Op[TensorVariable]): +class SolveBase(Op[tuple[TensorVariable], TensorVariable]): """Base class for `scipy.linalg` matrix equation solvers.""" __props__: tuple[str, ...] 
= ( @@ -412,7 +412,7 @@ def cho_solve( return Blockwise(CholeskySolve(lower=lower, b_ndim=b_ndim))(A, b) -class LU(Op[TensorVariable]): +class LU(Op[tuple[TensorVariable], TensorVariable]): """Decompose a matrix into lower and upper triangular matrices.""" __props__ = ("permute_l", "overwrite_a", "p_indices") @@ -601,7 +601,7 @@ def lu( ) -class PivotToPermutations(Op[TensorVariable]): +class PivotToPermutations(Op[tuple[TensorVariable], TensorVariable]): gufunc_signature = "(x)->(x)" __props__ = ("inverse",) @@ -634,7 +634,7 @@ def pivot_to_permutation(p: TensorLike, inverse=False): return PivotToPermutations(inverse=inverse)(p) -class LUFactor(Op[TensorVariable]): +class LUFactor(Op[tuple[TensorVariable], TensorVariable]): __props__ = ("overwrite_a",) gufunc_signature = "(m,m)->(m,m),(m)" @@ -1127,7 +1127,7 @@ def solve( )(a, b) -class Eigvalsh(Op[TensorVariable]): +class Eigvalsh(Op[tuple[TensorVariable], TensorVariable]): """ Generalized eigenvalues of a Hermitian positive definite eigensystem. @@ -1174,7 +1174,7 @@ def infer_shape(self, fgraph, node, shapes): return [(n,)] -class EigvalshGrad(Op[TensorVariable]): +class EigvalshGrad(Op[tuple[TensorVariable], TensorVariable]): """ Gradient of generalized eigenvalues of a Hermitian positive definite eigensystem. @@ -1234,7 +1234,7 @@ def eigvalsh(a, b, lower=True): return Eigvalsh(lower)(a, b) -class Expm(Op[TensorVariable]): +class Expm(Op[tuple[TensorVariable], TensorVariable]): """ Compute the matrix exponential of a square array. """ @@ -1297,7 +1297,7 @@ def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) -class BaseBlockDiagonal(Op[TensorVariable]): +class BaseBlockDiagonal(Op[tuple[TensorVariable], TensorVariable]): __props__: tuple[str, ...] 
= ("n_inputs",) def __init__(self, n_inputs): @@ -1408,7 +1408,7 @@ def block_diag(*matrices: TensorVariable): return _block_diagonal_matrix(*matrices) -class QR(Op[TensorVariable]): +class QR(Op[tuple[TensorVariable], TensorVariable]): """ QR Decomposition """ @@ -1778,7 +1778,7 @@ def qr( return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A) -class Schur(Op[TensorVariable]): +class Schur(Op[tuple[TensorVariable], TensorVariable]): """ Schur Decomposition """ @@ -1965,7 +1965,7 @@ def schur( return Blockwise(Schur(output=output, sort=sort))(A) # type: ignore[return-value] -class QZ(Op[TensorVariable]): +class QZ(Op[tuple[TensorVariable], TensorVariable]): """ QZ Decomposition """ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index 7ccbe48490..7045c514df 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -29,7 +29,7 @@ def _parse_sort_args(kind: KIND | None, order, stable: bool | None) -> KIND: return kind -class SortOp(Op[TensorVariable]): +class SortOp(Op[tuple[TensorVariable], TensorVariable]): """ This class is a wrapper for numpy sort function. @@ -154,7 +154,7 @@ def sort( return SortOp(kind)(a, axis) -class ArgSortOp(Op[TensorVariable]): +class ArgSortOp(Op[tuple[TensorVariable], TensorVariable]): """ This class is a wrapper for numpy argsort function. diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index 9a8ded48da..91ebcdae8f 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -12,7 +12,7 @@ from pytensor.tensor.variable import TensorVariable -class SoftmaxGrad(COp[TensorVariable]): +class SoftmaxGrad(COp[tuple[TensorVariable], TensorVariable]): """ Gradient wrt x of the Softmax Op. 
@@ -240,7 +240,7 @@ def c_code(self, node, name, inp, out, sub): ) -class Softmax(COp[TensorVariable]): +class Softmax(COp[tuple[TensorVariable], TensorVariable]): r""" Softmax activation function :math:`\\varphi(\\mathbf{x})_j = @@ -495,7 +495,7 @@ def softmax(c, axis=None): return Softmax(axis=axis)(c) -class LogSoftmax(COp[TensorVariable]): +class LogSoftmax(COp[tuple[TensorVariable], TensorVariable]): r""" LogSoftmax activation function :math:`\\varphi(\\mathbf{x})_j = diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index eafa37aaf1..d3c8d07613 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -788,7 +788,7 @@ def __hash__(self): return hash((type(self), props_values)) -class Subtensor(BaseSubtensor, COp[TensorVariable]): +class Subtensor(BaseSubtensor, COp[tuple[TensorVariable], TensorVariable]): """Basic NumPy indexing operator.""" check_input = False @@ -1362,7 +1362,7 @@ def process_slice_component(comp): pprint.assign(Subtensor, SubtensorPrinter()) -class IncSubtensor(BaseSubtensor, COp[TensorVariable]): +class IncSubtensor(BaseSubtensor, COp[tuple[TensorVariable], TensorVariable]): """ Increment a subtensor. @@ -1792,7 +1792,7 @@ def _sum_grad_over_bcasted_dims(x, gx): return gx -class AdvancedSubtensor1(COp[TensorVariable]): +class AdvancedSubtensor1(COp[tuple[TensorVariable], TensorVariable]): """ Implement x[ilist] where ilist is a vector of integers. @@ -1957,7 +1957,7 @@ def _idx_may_be_invalid(x, idx) -> bool: advanced_subtensor1 = AdvancedSubtensor1() -class AdvancedIncSubtensor1(BaseSubtensor, COp[TensorVariable]): +class AdvancedIncSubtensor1(BaseSubtensor, COp[tuple[TensorVariable], TensorVariable]): """ Increments a subtensor using advanced slicing (list of index). 
@@ -2269,7 +2269,7 @@ def as_tensor_index_variable(idx): return idx -class AdvancedSubtensor(BaseSubtensor, COp[TensorVariable]): +class AdvancedSubtensor(BaseSubtensor, COp[tuple[TensorVariable], TensorVariable]): """Implements NumPy's advanced indexing.""" __props__ = ("idx_list",) @@ -2550,7 +2550,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): return type(op)(new_idx_list).make_node(batch_x, *batch_idxs) -class AdvancedIncSubtensor(BaseSubtensor, Op[TensorVariable]): +class AdvancedIncSubtensor(BaseSubtensor, Op[tuple[TensorVariable], TensorVariable]): """Increments a subtensor using advanced indexing.""" __props__ = ( diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 028be59be9..075d64a417 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -247,7 +247,7 @@ def transpose(self, *axes): return pt.basic.transpose(self, axes) @property - def shape(self): + def shape(self) -> "TensorVariable": return pt.shape(self) @property diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py index b93f28520e..ed830a4c94 100644 --- a/pytensor/typed_list/basic.py +++ b/pytensor/typed_list/basic.py @@ -69,7 +69,7 @@ class TypedListConstant(_typed_list_py_operators, Constant): TypedListType.constant_type = TypedListConstant -class GetItem(COp[Variable]): +class GetItem(COp[tuple[Variable], Variable]): # See doc in instance of this Op or function after this class definition. view_map = {0: [0]} __props__ = () @@ -131,7 +131,7 @@ def c_code_cache_version(self): """ -class Append(COp[TypedListVariable]): +class Append(COp[tuple[TypedListVariable], TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -210,7 +210,7 @@ def c_code_cache_version(self): """ -class Extend(COp[TypedListVariable]): +class Extend(COp[tuple[TypedListVariable], TypedListVariable]): # See doc in instance of this Op after the class definition. 
__props__ = ("inplace",) @@ -294,7 +294,7 @@ def c_code_cache_version_(self): """ -class Insert(COp[TypedListVariable]): +class Insert(COp[tuple[TypedListVariable], TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -380,7 +380,7 @@ def c_code_cache_version(self): """ -class Remove(Op[TypedListVariable]): +class Remove(Op[tuple[TypedListVariable], TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -437,7 +437,7 @@ def __str__(self): """ -class Reverse(COp[TypedListVariable]): +class Reverse(COp[tuple[TypedListVariable], TypedListVariable]): # See doc in instance of this Op after the class definition. __props__ = ("inplace",) @@ -504,7 +504,7 @@ def c_code_cache_version(self): """ -class Index(Op[ScalarVariable]): +class Index(Op[tuple[ScalarVariable], ScalarVariable]): # See doc in instance of this Op after the class definition. __props__ = () @@ -533,7 +533,7 @@ def __str__(self): index_ = Index() -class Count(Op[ScalarVariable]): +class Count(Op[tuple[ScalarVariable], ScalarVariable]): # See doc in instance of this Op after the class definition. __props__ = () @@ -580,7 +580,7 @@ def __str__(self): """ -class Length(COp[ScalarVariable]): +class Length(COp[tuple[ScalarVariable], ScalarVariable]): # See doc in instance of this Op after the class definition. 
__props__ = () @@ -621,7 +621,7 @@ def c_code_cache_version(self): """ -class MakeList(Op[TypedListVariable]): +class MakeList(Op[tuple[TypedListVariable], TypedListVariable]): __props__ = () def make_node(self, a): diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 07e90819b7..b4dad18f66 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -7,7 +7,7 @@ from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor -class XOp(Op[XTensorVariable]): +class XOp(Op[tuple[XTensorVariable], XTensorVariable]): """A base class for XOps that shouldn't be materialized""" def perform(self, node, inputs, outputs): From be9fb26905128b154b2574d6a84606c3c8b860c9 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 1 Apr 2026 13:03:05 +0200 Subject: [PATCH 3/4] Remove conditional imports of broadcast_shape --- pytensor/tensor/basic.py | 2 +- tests/tensor/random/test_basic.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 145a277f9c..cbd644b238 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -673,7 +673,7 @@ def vectorize_tensor_from_scalar(op, node, batch_x): return identity(batch_x).owner -class ScalarFromTensor(COp[ScalarVariable]): +class ScalarFromTensor(COp[tuple[ScalarVariable], ScalarVariable]): __props__ = () def __call__(self, *args, **kwargs) -> ScalarVariable: diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..364d013a48 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -25,7 +25,6 @@ beta, betabinom, binomial, - broadcast_shapes, categorical, cauchy, chisquare, @@ -78,7 +77,7 @@ def _rvs(*args, size=None, **kwargs): res, size if size is not None - else broadcast_shapes(*[np.shape(a) for a in args]), + else np.broadcast_shapes(*[np.shape(a) for a in args]), ) return res From 2f30ee5d2cd56d86adb5b00bc648badbe5074488 
Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 1 Apr 2026 13:12:05 +0200 Subject: [PATCH 4/4] Fix mypy failures in tensor/random --- pytensor/tensor/random/basic.py | 29 +++++-------- pytensor/tensor/random/op.py | 77 ++++++++++++++++++++------------- pytensor/tensor/random/utils.py | 55 ++++++++++++++++------- scripts/mypy-failing.txt | 3 -- 4 files changed, 100 insertions(+), 64 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 59e02ee3b4..005b27f7af 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1,9 +1,9 @@ import abc import warnings -from typing import Literal +from types import ModuleType +from typing import Literal, cast import numpy as np -from numpy import broadcast_shapes as np_broadcast_shapes from numpy import einsum as np_einsum from numpy import sqrt as np_sqrt from numpy.linalg import cholesky as np_cholesky @@ -23,16 +23,7 @@ # Scipy.stats is considerably slow to import # We import scipy.stats lazily inside `ScipyRandomVariable` -stats = None - - -try: - broadcast_shapes = np.broadcast_shapes -except AttributeError: - from numpy.lib.stride_tricks import _broadcast_shape - - def broadcast_shapes(*shapes): - return _broadcast_shape(*[np.empty(x, dtype=[]) for x in shapes]) +stats: ModuleType = None # type: ignore[assignment] class ScipyRandomVariable(RandomVariable): @@ -76,7 +67,7 @@ def rng_fn(cls, *args, **kwargs): if size is None: # SciPy will sometimes drop broadcastable dimensions; we need to # check and, if necessary, add them back - exp_shape = broadcast_shapes(*[np.shape(a) for a in args[1:-1]]) + exp_shape = np.broadcast_shapes(*[np.shape(a) for a in args[1:-1]]) if res.shape != exp_shape: return np.broadcast_to(res, exp_shape).copy() @@ -622,13 +613,14 @@ class GumbelRV(ScipyRandomVariable): dtype = "floatX" _print_name = ("Gumbel", "\\operatorname{Gumbel}") - def __call__( + # mypy doesn't like the added scale kwarg because it breaks the signature 
of the parent class. + def __call__( # type: ignore[override] self, loc: np.ndarray | float, scale: np.ndarray | float = 1.0, size: list[int] | int | None = None, **kwargs, - ) -> RandomVariable: + ): r"""Draw samples from a gumbel distribution. Signature @@ -659,7 +651,10 @@ def rng_fn_scipy( scale: np.ndarray | float, size: list[int] | int | None, ) -> np.ndarray: - return stats.gumbel_r.rvs(loc=loc, scale=scale, size=size, random_state=rng) + return cast( + np.ndarray, + stats.gumbel_r.rvs(loc=loc, scale=scale, size=size, random_state=rng), + ) gumbel = GumbelRV() @@ -906,7 +901,7 @@ def __call__(self, mean, cov, size=None, method=None, **kwargs): def rng_fn(self, rng, mean, cov, size): if size is None: - size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + size = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) if self.method == "cholesky": A = np_cholesky(cov) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 612a40a76f..c4136fc1c2 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,7 +1,7 @@ import abc import warnings -from collections.abc import Sequence -from typing import Any, Generic, Self, cast +from collections.abc import Callable, Sequence +from typing import Generic, cast import numpy as np @@ -60,7 +60,7 @@ class RandomVariable(RNGConsumerOp): _output_type_depends_on_input_value = True - __props__ = ("name", "signature", "dtype", "inplace") + __props__: tuple[str, ...] = ("name", "signature", "dtype", "inplace") default_output = 1 def __init__( @@ -68,7 +68,7 @@ def __init__( name=None, ndim_supp=None, ndims_params=None, - dtype: str | None = None, + dtype: str | np.dtype | None = None, inplace=None, signature: str | None = None, ): @@ -112,13 +112,13 @@ def __init__( ) if not isinstance(ndims_params, Sequence): raise TypeError("Parameter ndims_params must be sequence type.") - self.ndims_params = tuple(ndims_params) + self.ndims_params: tuple[int, ...] 
= tuple(ndims_params) self.signature = signature or getattr(self, "signature", None) if self.signature is not None: # Assume a single output. Several methods need to be updated to handle multiple outputs. self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature) - self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig] + self.ndims_params = tuple([len(input_sig) for input_sig in self.inputs_sig]) self.ndim_supp = len(self.output_sig) else: if ( @@ -192,9 +192,11 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): "when signature is not sufficient to infer the support shape" ) - def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray: + def rng_fn( + self, rng: np.random.Generator, *args, **kwargs + ) -> int | float | np.ndarray: """Sample a numeric random variate.""" - return getattr(rng, self.name)(*args, **kwargs) + return getattr(rng, self.name)(*args, **kwargs) # type: ignore[no-any-return] def __str__(self): # Only show signature from core props @@ -241,7 +243,7 @@ def _infer_shape( from pytensor.tensor.extra_ops import broadcast_shape_iter - supp_shape: tuple[Any] + supp_shape: tuple[int | ScalarVariable, ...] if self.ndim_supp == 0: supp_shape = () else: @@ -264,7 +266,9 @@ def _infer_shape( f"Size must be None or have length >= {param_batched_dims}" ) - return tuple(size) + supp_shape + # TODO: This type ignore is because the size tensor is not interpreted as an iterable. + # Once that's fixed, this ignore could be removed. 
+ return (*tuple(size), *supp_shape) # type: ignore[arg-type] # Size was not provided, we must infer it from the shape of the parameters if param_shapes is None: @@ -305,7 +309,7 @@ def extract_batch_shape(p, ps, n): # Distribution has no parameters batch_shape = () - shape = batch_shape + supp_shape + shape = (*batch_shape, *supp_shape) return shape @@ -333,9 +337,14 @@ def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs): ) props = self._props_dict() props["dtype"] = dtype - new_op = type(self)(**props) - return new_op.__call__( - *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs + new_op: RandomVariable = type(self)(**props) + return cast( + tuple[RandomGeneratorSharedVariable, TensorVariable] + | TensorVariable + | tuple[TensorVariable], + new_op.__call__( + *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs + ), ) res = super().__call__(rng, size, *args, **kwargs) @@ -385,7 +394,7 @@ def make_node(self, rng, size, *dist_params): inferred_shape = self._infer_shape(size, dist_params) _, static_shape = infer_static_shape(inferred_shape) - dist_params = explicit_expand_dims( + _dist_params = explicit_expand_dims( dist_params, self.ndims_params, size_length=None @@ -393,9 +402,12 @@ def make_node(self, rng, size, *dist_params): else get_vector_length(size), ) - inputs = (rng, size, *dist_params) + inputs = (rng, size, *_dist_params) out_type = TensorType(dtype=self.dtype, shape=static_shape) - outputs = (rng.type(), out_type()) + outputs = cast( + tuple[RandomGeneratorSharedVariable, TensorVariable], + (rng.type(), out_type()), + ) if self.dtype == "floatX": # Commit to a specific float type if the Op is still using "floatX" @@ -404,22 +416,27 @@ def make_node(self, rng, size, *dist_params): props["dtype"] = dtype self = type(self)(**props) - return Apply(self, inputs, outputs) + node: Apply[ + RandomVariable, + tuple[RandomGeneratorSharedVariable, TensorVariable], + TensorVariable, + ] = Apply(self, inputs, outputs) + 
return node def batch_ndim(self, node: Apply) -> int: return cast(int, node.default_output().type.ndim - self.ndim_supp) - def rng_param(self, node) -> Variable: + def rng_param(self, node) -> RandomGeneratorSharedVariable: """Return the node input corresponding to the rng""" - return node.inputs[0] + return cast(RandomGeneratorSharedVariable, node.inputs[0]) - def size_param(self, node) -> Variable: + def size_param(self, node) -> TensorVariable: """Return the node input corresponding to the size""" - return node.inputs[1] + return cast(TensorVariable, node.inputs[1]) - def dist_params(self, node) -> Sequence[Variable]: + def dist_params(self, node) -> Sequence[TensorVariable]: """Return the node inpust corresponding to dist params""" - return node.inputs[2:] + return tuple(cast(TensorVariable, inp) for inp in node.inputs[2:]) def perform(self, node, inputs, outputs): rng, size, *args = inputs @@ -447,7 +464,9 @@ def R_op(self, inputs, eval_points): class AbstractRNGConstructor(Op, Generic[OpOutputsType, OpDefaultOutputType]): - def make_node(self, seed=None) -> Apply[Self, OpOutputsType, OpDefaultOutputType]: + random_type: Callable[[], OpDefaultOutputType] + + def make_node(self, seed=None): if seed is None: seed = NoneConst elif isinstance(seed, Variable) and isinstance(seed.type, NoneTypeT): @@ -470,7 +489,9 @@ class DefaultGeneratorMakerOp( tuple[RandomGeneratorSharedVariable], RandomGeneratorSharedVariable ] ): - random_type = RandomGeneratorType() + random_type = cast( + Callable[[], RandomGeneratorSharedVariable], RandomGeneratorType() + ) random_constructor = "default_rng" @@ -478,9 +499,7 @@ class DefaultGeneratorMakerOp( @_vectorize_node.register(RandomVariable) -def vectorize_random_variable( - op: RandomVariable, node: Apply, rng, size, *dist_params -) -> Apply: +def vectorize_random_variable(op: RandomVariable, node: Apply, rng, size, *dist_params): # If size was provided originally and a new size hasn't been provided, # We extend it to accommodate 
the new input batch dimensions. # Otherwise, we assume the new size already has the right values diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 74b70617f6..7e0c624a59 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -3,7 +3,8 @@ from functools import wraps from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload +from typing import cast as typing_cast import numpy as np from numpy.random import Generator @@ -15,6 +16,7 @@ from pytensor.tensor.basic import as_tensor_variable, cast from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to from pytensor.tensor.math import maximum +from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.type import int_dtypes from pytensor.tensor.type_other import NoneTypeT @@ -79,9 +81,30 @@ def max_bcast(x, y): return bcast_shapes +@overload +def broadcast_params( # type: ignore[overload-overlap] + params: Sequence[np.ndarray], ndims_params: Sequence[int] +) -> list[np.ndarray]: ... + + +@overload +def broadcast_params( + params: Sequence[TensorVariable], ndims_params: Sequence[int] +) -> list[TensorVariable]: ... + + +@overload def broadcast_params( params: Sequence[np.ndarray | TensorVariable], ndims_params: Sequence[int] -) -> list[np.ndarray]: +) -> list[TensorVariable]: ... + + +def broadcast_params( + params: Sequence[np.ndarray] + | Sequence[TensorVariable] + | Sequence[np.ndarray | TensorVariable], + ndims_params: Sequence[int], +) -> list[np.ndarray] | list[TensorVariable]: """Broadcast parameters that have different dimensions. 
>>> ndims_params = [1, 2] @@ -187,23 +210,23 @@ def normalize_size_param( return shape if isinstance(shape, int): - shape = as_tensor_variable([shape], ndim=1) + _shape = as_tensor_variable([shape], ndim=1) else: if not isinstance(shape, Sequence | Variable | np.ndarray): raise TypeError( "Parameter size must be None, an integer, or a sequence with integers." ) - shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64") + _shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64") - if shape.type.shape == (None,): + if _shape.type.shape == (None,): # This should help ensure that the length of non-constant `size`s # will be available after certain types of cloning (e.g. the kind `Scan` performs) - shape = specify_shape(shape, (get_vector_length(shape),)) + _shape = specify_shape(_shape, (get_vector_length(_shape),)) - assert shape.type.shape != (None,) - assert shape.dtype in int_dtypes + assert _shape.type.shape != (None,) + assert _shape.dtype in int_dtypes - return shape + return _shape def custom_rng_deepcopy(rng): @@ -254,7 +277,9 @@ def __init__( self.namespaces = [(namespace, set(namespace.__all__))] self.default_instance_seed = seed - self.state_updates = [] + self.state_updates: list[ + tuple[RandomGeneratorSharedVariable, RandomGeneratorSharedVariable] + ] = [] self.gen_seedgen = np.random.SeedSequence(seed) self.rng_ctor = rng_ctor @@ -335,7 +360,7 @@ def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable: rng = shared(self.rng_ctor(seed), borrow=True) # Generate the sample - out = op(*args, **kwargs, rng=rng) + out: TensorVariable = op(*args, **kwargs, rng=rng) # This is the value that should be used to replace the old state # (i.e. `rng`) after `out` is sampled/evaluated. 
@@ -353,7 +378,7 @@ def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable: def supp_shape_from_ref_param_shape( *, ndim_supp: int, - dist_params: Sequence[Variable], + dist_params: Sequence[TensorVariable], param_shapes: Sequence[tuple[ScalarVariable, ...]] | None = None, ref_param_idx: int, ) -> TensorVariable | tuple[ScalarVariable, ...]: @@ -390,8 +415,8 @@ def supp_shape_from_ref_param_shape( if ndim_supp <= 0: raise ValueError("ndim_supp must be greater than 0") if param_shapes is not None: - ref_param = param_shapes[ref_param_idx] - return tuple(ref_param[i] for i in range(-ndim_supp, 0)) + ref_param_shape = param_shapes[ref_param_idx] + return tuple(ref_param_shape[i] for i in range(-ndim_supp, 0)) else: ref_param = dist_params[ref_param_idx] if ref_param.ndim < ndim_supp: @@ -399,4 +424,4 @@ def supp_shape_from_ref_param_shape( "Reference parameter does not match the expected dimensions; " f"{ref_param} has less than {ndim_supp} dim(s)." ) - return ref_param.shape[-ndim_supp:] + return typing_cast(TensorVariable, ref_param.shape[-ndim_supp:]) diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index ff73de2605..778bd55dc9 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -16,9 +16,6 @@ pytensor/tensor/elemwise.py pytensor/tensor/extra_ops.py pytensor/tensor/math.py pytensor/tensor/optimize.py -pytensor/tensor/random/basic.py -pytensor/tensor/random/op.py -pytensor/tensor/random/utils.py pytensor/tensor/rewriting/basic.py pytensor/tensor/type.py pytensor/tensor/type_other.py