diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 1644712704..b145593fc5 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -24,11 +24,12 @@ SequenceDB, TopoDB, ) -from pytensor.link.basic import Linker, PerformLinker +from pytensor.link.basic import Linker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker from pytensor.link.mlx.linker import MLXLinker from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.python.linker import PythonLinker from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -40,7 +41,10 @@ # Mode, it will be used as the key to retrieve the real linker in this # dictionary predefined_linkers = { - "py": PerformLinker(), # Use allow_gc PyTensor flag + "py": VMLinker( # Robust per-node Python VM over python_funcify; handles lazy ops + use_cloop=False, c_thunks=False + ), + "pyjit": PythonLinker(), # Whole-graph python_funcify composition; no lazy ops "c": CLinker(), # Don't support gc. so don't check allow_gc "c|py": OpWiseCLinker(), # Use allow_gc PyTensor flag "c|py_nogc": OpWiseCLinker(allow_gc=False), @@ -476,6 +480,16 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): RewriteDatabaseQuery(include=["fast_run", "mlx"]), ) +PYTHON = Mode( + VMLinker(use_cloop=False, c_thunks=False), + RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"), +) + +PYJIT = Mode( + PythonLinker(), + RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"), +) + FAST_COMPILE = Mode( VMLinker(use_cloop=False, c_thunks=False), RewriteDatabaseQuery(include=["fast_compile", "py_only"]), @@ -495,6 +509,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "NUMBA": NUMBA, "PYTORCH": PYTORCH, "MLX": MLX, + "PYTHON": PYTHON, + "PYJIT": PYJIT, } _CACHED_RUNTIME_MODES: dict[Any, Mode] = {} diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index ebfa067b71..dd1fa42aac 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -454,7 +454,6 @@ def R_op( return [None if isinstance(r.type, DisconnectedType) else r for r in result] # type: ignore[misc] raise NotImplementedError() - @abstractmethod def perform( self, node: Apply, @@ -489,7 +488,21 @@ def perform( An `Op` is free to reuse `output_storage` as it sees fit, or to discard it and allocate new memory. + The default implementation runs the `Op`'s ``python_funcify`` dispatch: + it builds the callable, runs it on the inputs, and writes the results + into ``output_storage``. An `Op` only needs to override this to keep a + bespoke ``perform``; otherwise its numeric behaviour lives entirely in + the ``python_funcify`` registry. """ + from pytensor.link.python.dispatch.basic import python_funcify + + fn = python_funcify(self, node=node) + results = fn(*inputs) + if len(output_storage) == 1: + output_storage[0][0] = results + else: + for output, value in zip(output_storage, results, strict=True): + output[0] = value def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: """Determine whether constant folding should be performed for the given node. diff --git a/pytensor/link/python/__init__.py b/pytensor/link/python/__init__.py new file mode 100644 index 0000000000..105dfba68e --- /dev/null +++ b/pytensor/link/python/__init__.py @@ -0,0 +1 @@ +from pytensor.link.python.linker import PythonLinker diff --git a/pytensor/link/python/dispatch/__init__.py b/pytensor/link/python/dispatch/__init__.py new file mode 100644 index 0000000000..b885f9a876 --- /dev/null +++ b/pytensor/link/python/dispatch/__init__.py @@ -0,0 +1,8 @@ +# isort: off +from pytensor.link.python.dispatch.basic import python_funcify + +# Load dispatch specializations +import pytensor.link.python.dispatch.blockwise +import pytensor.link.python.dispatch.linalg + +# isort: on diff --git a/pytensor/link/python/dispatch/basic.py b/pytensor/link/python/dispatch/basic.py new file mode 100644 index 0000000000..529eaf312a --- /dev/null +++ b/pytensor/link/python/dispatch/basic.py @@ -0,0 +1,99 @@ +from functools import singledispatch + +from pytensor.graph.fg import AbstractFunctionGraph +from pytensor.link.utils import fgraph_to_python + + +@singledispatch +def python_funcify(op, node=None, **kwargs): + """Return a fast pure-Python implementation of ``op`` as a callable. + + The callable takes the node's inputs positionally and returns its output (a + single value, or a tuple for multi-output nodes). Register a specialization + to override an `Op`'s ``perform`` with a faster numpy/scipy path on the + Python backend. + + Unregistered ops raise `NotImplementedError`. The ``py`` (VM) linker catches + this and falls back to ``Op.make_thunk(impl="py")``; the whole-graph ``pyjit`` + linker catches it and falls back to a ``perform`` wrapper. + """ + raise NotImplementedError( + f"No python_funcify implementation registered for {type(op).__name__}" + ) + + +def _perform_wrapper(op, node): + """Wrap an `Op`'s ``perform`` into a `python_funcify`-style callable.""" + n_outputs = len(node.outputs) + single_output = n_outputs == 1 + + def perform(*inputs): + output_storage = [[None] for _ in range(n_outputs)] + op.perform(node, list(inputs), output_storage) + if single_output: + return output_storage[0][0] + return tuple(storage[0] for storage in output_storage) + + return perform + + +def _funcify_or_perform(op, node=None, **kwargs): + try: + return python_funcify(op, node=node) + except NotImplementedError: + return _perform_wrapper(op, node) + + +@python_funcify.register(AbstractFunctionGraph) +def python_funcify_FunctionGraph( + fgraph, node=None, fgraph_name="python_funcified_fgraph", **kwargs +): + return fgraph_to_python( + fgraph, + op_conversion_fn=_funcify_or_perform, + fgraph_name=fgraph_name, + **kwargs, + ) + + +def make_node_thunk_with_python_dispatch( + node, storage_map, compute_map, *, fallback, implementation +): + """Build a per-node thunk, preferring a registered `python_funcify` implementation. + + When `python_funcify` has a specialization for ``node.op``, its callable is + wrapped into a thunk that reads inputs from and writes outputs to + ``storage_map``. Otherwise ``fallback`` (``Op.make_thunk``) is used, which + covers ``perform`` ops and lazy ops like ``IfElse`` unchanged. + """ + try: + fn = python_funcify(node.op, node=node) + except NotImplementedError: + return fallback(node, storage_map, compute_map, implementation) + + return _wrap_callable_as_thunk(fn, node, storage_map, compute_map) + + +def _wrap_callable_as_thunk(fn, node, storage_map, compute_map): + input_storage = [storage_map[variable] for variable in node.inputs] + output_compute = [compute_map[variable] for variable in node.outputs] + + if len(node.outputs) == 1: + [output] = (storage_map[variable] for variable in node.outputs) + + def thunk(fn=fn, inputs=input_storage, output=output, compute=output_compute): + output[0] = fn(*(inp[0] for inp in inputs)) + compute[0][0] = True + else: + output_storage = [storage_map[variable] for variable in node.outputs] + + def thunk( + fn=fn, inputs=input_storage, outputs=output_storage, compute=output_compute + ): + for output, value in zip(outputs, fn(*(inp[0] for inp in inputs))): + output[0] = value + for entry in compute: + entry[0] = True + + thunk.lazy = False + return thunk diff --git a/pytensor/link/python/dispatch/blockwise.py b/pytensor/link/python/dispatch/blockwise.py new file mode 100644 index 0000000000..354cb548bb --- /dev/null +++ b/pytensor/link/python/dispatch/blockwise.py @@ -0,0 +1,17 @@ +import numpy as np + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.blockwise import Blockwise + + +@python_funcify.register(Blockwise) +def python_funcify_Blockwise(op, node=None, **kwargs): + core_node = op._create_dummy_core_node( + node.inputs, propagate_unbatched_core_inputs=True + ) + # Raises NotImplementedError when the core Op has no dispatch, which makes the + # whole Blockwise fall back to its (vectorized) perform. + core_fn = python_funcify(op.core_op, node=core_node) + + output_dtypes = [output.type.dtype for output in node.outputs] + return np.vectorize(core_fn, signature=op.signature, otypes=output_dtypes) diff --git a/pytensor/link/python/dispatch/linalg/__init__.py b/pytensor/link/python/dispatch/linalg/__init__.py new file mode 100644 index 0000000000..f5165bc0f1 --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/__init__.py @@ -0,0 +1,4 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.decomposition +import pytensor.link.python.dispatch.linalg.solvers +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/decomposition/__init__.py b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py new file mode 100644 index 0000000000..c1bfeab072 --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py @@ -0,0 +1,4 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.decomposition.cholesky +import pytensor.link.python.dispatch.linalg.decomposition.qr +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py b/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py new file mode 100644 index 0000000000..9042f88336 --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py @@ -0,0 +1,32 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.decomposition.cholesky import Cholesky + + +@python_funcify.register(Cholesky) +def python_funcify_Cholesky(op, node=None, **kwargs): + lower = op.lower + overwrite_a = op.overwrite_a + (potrf,) = get_lapack_funcs(("potrf",), dtype=node.inputs[0].type.dtype) + + def cholesky(x): + if x.size == 0: + return np.empty_like(x) + + # potrf only honors overwrite_a for F-contiguous input; transpose a + # C-contiguous array to benefit from it. + c_contiguous_input = overwrite_a and x.flags["C_CONTIGUOUS"] + if c_contiguous_input: + x = x.T + factor, info = potrf(x, lower=not lower, overwrite_a=True, clean=True) + factor = factor.T + else: + factor, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) + + if info != 0: + factor[...] = np.nan + return factor + + return cholesky diff --git a/pytensor/link/python/dispatch/linalg/decomposition/qr.py b/pytensor/link/python/dispatch/linalg/decomposition/qr.py new file mode 100644 index 0000000000..903a3b6c9b --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/qr.py @@ -0,0 +1,54 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.decomposition.qr import QR + + +@python_funcify.register(QR) +def python_funcify_QR(op, node=None, **kwargs): + mode = op.mode + pivoting = op.pivoting + overwrite_a = op.overwrite_a + call_and_get_lwork = op._call_and_get_lwork + + def qr(x): + M, N = x.shape + + if pivoting: + (geqp3,) = get_lapack_funcs(("geqp3",), (x,)) + factor, jpvt, tau, *_ = call_and_get_lwork( + geqp3, x, lwork=-1, overwrite_a=overwrite_a + ) + jpvt -= 1 # geqp3 returns 1-based indices + else: + (geqrf,) = get_lapack_funcs(("geqrf",), (x,)) + factor, tau, *_ = call_and_get_lwork( + geqrf, x, lwork=-1, overwrite_a=overwrite_a + ) + + if mode not in ("economic", "raw") or M < N: + R = np.triu(factor) + else: + R = np.triu(factor[:N, :]) + + if mode == "r": + return (R, jpvt) if pivoting else R + if mode == "raw": + return (factor, tau, R, jpvt) if pivoting else (factor, tau, R) + + (orgqr,) = get_lapack_funcs(("orgqr",), (factor,)) + if M < N: + Q, *_ = call_and_get_lwork( + orgqr, factor[:, :M], tau, lwork=-1, overwrite_a=1 + ) + elif mode == "economic": + Q, *_ = call_and_get_lwork(orgqr, factor, tau, lwork=-1, overwrite_a=1) + else: + square = np.empty((M, M), dtype=factor.dtype.char) + square[:, :N] = factor + Q, *_ = call_and_get_lwork(orgqr, square, tau, lwork=-1, overwrite_a=1) + + return (Q, R, jpvt) if pivoting else (Q, R) + + return qr diff --git a/pytensor/link/python/dispatch/linalg/solvers/__init__.py b/pytensor/link/python/dispatch/linalg/solvers/__init__.py new file mode 100644 index 0000000000..d2a7665d4b --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/solvers/__init__.py @@ -0,0 +1,3 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.solvers.triangular +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/solvers/triangular.py b/pytensor/link/python/dispatch/linalg/solvers/triangular.py new file mode 100644 index 0000000000..6b9ac2c4bc --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/solvers/triangular.py @@ -0,0 +1,43 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.solvers.triangular import SolveTriangular + + +@python_funcify.register(SolveTriangular) +def python_funcify_SolveTriangular(op, node=None, **kwargs): + lower = op.lower + unit_diagonal = op.unit_diagonal + overwrite_b = op.overwrite_b + (trtrs,) = get_lapack_funcs(("trtrs",), dtype=node.outputs[0].type.dtype) + + def solve_triangular(A, b): + if b.size == 0: + return np.empty_like(b) + + if A.flags["F_CONTIGUOUS"]: + x, info = trtrs( + A, + b, + overwrite_b=overwrite_b, + lower=lower, + trans=0, + unitdiag=unit_diagonal, + ) + else: + # trtrs expects Fortran ordering, so solve the transposed system. + x, info = trtrs( + A.T, + b, + overwrite_b=overwrite_b, + lower=not lower, + trans=1, + unitdiag=unit_diagonal, + ) + + if info != 0: + x[...] = np.nan + return x + + return solve_triangular diff --git a/pytensor/link/python/linker.py b/pytensor/link/python/linker.py new file mode 100644 index 0000000000..e50d64b72f --- /dev/null +++ b/pytensor/link/python/linker.py @@ -0,0 +1,25 @@ +from pytensor.link.basic import JITLinker + + +class PythonLinker(JITLinker): + """Compose a `FunctionGraph` into a single pure-Python function. + + The whole graph is turned into one straight-line Python function by + `fgraph_to_python`, dispatching each `Op` through the `python_funcify` + registry (falling back to ``perform`` for unregistered ops). There is no + compilation step, so `jit_compile` is the identity. + """ + + required_rewrites = ("minimum_compile", "py_only") + incompatible_rewrites = ("cxx_only",) + + def fgraph_convert(self, fgraph, **kwargs): + from pytensor.link.python.dispatch.basic import python_funcify + + return python_funcify(fgraph, **kwargs) + + def jit_compile(self, fn): + return fn + + def create_thunk_inputs(self, storage_map): + return [storage_map[n] for n in self.fgraph.inputs] diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 239f73df80..ced605600e 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -1203,6 +1203,35 @@ def make_vm( ) return vm + def _make_node_thunk(self, node, storage_map, compute_map, implementation): + """Create the thunk for a single node. + + On the pure-Python path (``implementation == "py"``) a registered + ``python_funcify`` implementation is wrapped into a thunk; otherwise the + node's ``Op.make_thunk`` is used (which also covers lazy ops like + ``IfElse``). + """ + if implementation == "py": + from pytensor.link.python.dispatch.basic import ( + make_node_thunk_with_python_dispatch, + ) + + return make_node_thunk_with_python_dispatch( + node, + storage_map, + compute_map, + fallback=self._make_perform_thunk, + implementation=implementation, + ) + return self._make_perform_thunk(node, storage_map, compute_map, implementation) + + def _make_perform_thunk(self, node, storage_map, compute_map, implementation): + # no-recycling is done at each VM.__call__, so there is no need to cause + # duplicate C code by passing no_recycling here. + return node.op.make_thunk( + node, storage_map, compute_map, [], impl=implementation + ) + def make_all( self, profiler=None, @@ -1224,17 +1253,16 @@ def make_all( t0 = time.perf_counter() linker_make_thunk_time = {} - impl = None + implementation = None if self.c_thunks is False: - impl = "py" + implementation = "py" for node in order: try: thunk_start = time.perf_counter() - # no-recycling is done at each VM.__call__ So there is - # no need to cause duplicate c code by passing - # no_recycling here. thunks.append( - node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) + self._make_node_thunk( + node, storage_map, compute_map, implementation + ) ) linker_make_thunk_time[node] = time.perf_counter() - thunk_start if not hasattr(thunks[-1], "lazy"): diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 0fa7a34b3b..95babfb6e9 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -10,7 +10,6 @@ from pytensor.tensor import math as ptm from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.linalg._lazy import scipy_linalg from pytensor.tensor.linalg.dtype_utils import linalg_output_dtype from pytensor.tensor.type import tensor @@ -45,43 +44,6 @@ def make_node(self, x): dtype = linalg_output_dtype(x.type.dtype) return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) - def perform(self, node, inputs, outputs): - [x] = inputs - [out] = outputs - - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) - - # Quick return for square empty array - if x.size == 0: - out[0] = np.empty_like(x, dtype=potrf.dtype) - return - - # Squareness check - if x.shape[0] != x.shape[1]: - raise ValueError( - f"Input array is expected to be square but has the shape: {x.shape}." - ) - - # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS - # If we have a `C_CONTIGUOUS` array we transpose to benefit from it - c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"] - if c_contiguous_input: - x = x.T - lower = not self.lower - overwrite_a = True - else: - lower = self.lower - overwrite_a = self.overwrite_a - - c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) - - if info != 0: - c[...] = np.nan - out[0] = c - else: - # Transpose result if input was transposed - out[0] = c.T if c_contiguous_input else c - def pullback(self, inputs, outputs, gradients): """ Cholesky decomposition reverse-mode gradient update. diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 9e9270259d..5ac6f5626f 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -1,7 +1,5 @@ from typing import Literal, cast -import numpy as np - from pytensor import ifelse from pytensor import tensor as pt from pytensor.gradient import DisconnectedType, disconnected_type @@ -51,7 +49,7 @@ def __init__( case "r": self.gufunc_signature = "(m,n)->(m,n)" case "raw": - self.gufunc_signature = "(m,n)->(n,m),(k),(m,n)" + self.gufunc_signature = "(m,n)->(m,n),(k),(k,n)" case _: raise ValueError( f"Invalid mode '{mode}'. Supported modes are 'full', 'economic', 'r', and 'raw'." @@ -91,9 +89,9 @@ def make_node(self, x): ] case "raw": outputs = [ - tensor(shape=(M, M), dtype=out_dtype), - tensor(shape=(K,), dtype=out_dtype), tensor(shape=(M, N), dtype=out_dtype), + tensor(shape=(K,), dtype=out_dtype), + tensor(shape=(K, N), dtype=out_dtype), ] case _: raise NotImplementedError @@ -124,9 +122,9 @@ def infer_shape(self, fgraph, node, shapes): case "r": R_shape = (M, N) case "raw": - Q_shape = (M, M) # Actually this is H in this case + Q_shape = (M, N) # Actually this is H in this case tau_shape = (K,) - R_shape = (M, N) + R_shape = (K, N) if self.pivoting: P_shape = (N,) @@ -153,72 +151,17 @@ def _call_and_get_lwork(self, fn, *args, lwork, **kwargs): def perform(self, node, inputs, outputs): (x,) = inputs - M, N = x.shape - - if self.pivoting: - (geqp3,) = scipy_linalg.get_lapack_funcs(("geqp3",), (x,)) - qr, jpvt, tau, *_work_info = self._call_and_get_lwork( - geqp3, x, lwork=-1, overwrite_a=self.overwrite_a - ) - jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1 - else: - (geqrf,) = scipy_linalg.get_lapack_funcs(("geqrf",), (x,)) - qr, tau, *_work_info = self._call_and_get_lwork( - geqrf, x, lwork=-1, overwrite_a=self.overwrite_a - ) - - if self.mode not in ["economic", "raw"] or M < N: - R = np.triu(qr) - else: - R = np.triu(qr[:N, :]) - - if self.mode == "r" and self.pivoting: - outputs[0][0] = R - outputs[1][0] = jpvt - return - - elif self.mode == "r": - outputs[0][0] = R - return - - elif self.mode == "raw" and self.pivoting: - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R - outputs[3][0] = jpvt - return - - elif self.mode == "raw": - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R - return - - (gor_un_gqr,) = scipy_linalg.get_lapack_funcs(("orgqr",), (qr,)) - - if M < N: - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qr[:, :M], tau, lwork=-1, overwrite_a=1 - ) - elif self.mode == "economic": - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qr, tau, lwork=-1, overwrite_a=1 - ) + result = scipy_linalg.qr( + x, mode=self.mode, pivoting=self.pivoting, overwrite_a=self.overwrite_a + ) + if self.mode == "raw": + # scipy nests the Householder reflectors as ((qr, tau), R[, jpvt]). + (factor, tau), *rest = result + values = [factor, tau, *rest] else: - t = qr.dtype.char - qqr = np.empty((M, M), dtype=t) - qqr[:, :N] = qr - - # Always overwite qqr -- it's a meaningless intermediate value - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1 - ) - - outputs[0][0] = Q - outputs[1][0] = R - - if self.pivoting: - outputs[2][0] = jpvt + values = list(result) + for storage, value in zip(outputs, values): + storage[0] = value def pullback(self, inputs, outputs, output_grads): """ diff --git a/pytensor/tensor/linalg/solvers/triangular.py b/pytensor/tensor/linalg/solvers/triangular.py index e3adee0612..9d8eb1172c 100644 --- a/pytensor/tensor/linalg/solvers/triangular.py +++ b/pytensor/tensor/linalg/solvers/triangular.py @@ -31,44 +31,17 @@ def __init__(self, *, unit_diagonal=False, **kwargs): def perform(self, node, inputs, outputs): A, b = inputs - - if A.ndim != 2 or A.shape[0] != A.shape[1]: - raise ValueError("expected square matrix") - - if A.shape[0] != b.shape[0]: - raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible") - - (trtrs,) = scipy_linalg.get_lapack_funcs(("trtrs",), (A, b)) - - # Quick return for empty arrays - if b.size == 0: - outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype) - return - - if A.flags["F_CONTIGUOUS"]: - x, info = trtrs( + try: + outputs[0][0] = scipy_linalg.solve_triangular( A, b, - overwrite_b=self.overwrite_b, lower=self.lower, - trans=0, - unitdiag=self.unit_diagonal, - ) - else: - # transposed system is solved since trtrs expects Fortran ordering - x, info = trtrs( - A.T, - b, + unit_diagonal=self.unit_diagonal, overwrite_b=self.overwrite_b, - lower=not self.lower, - trans=1, - unitdiag=self.unit_diagonal, + check_finite=False, ) - - if info != 0: - x[...] = np.nan - - outputs[0][0] = x + except scipy_linalg.LinAlgError: + outputs[0][0] = np.full(b.shape, np.nan, dtype=node.outputs[0].type.dtype) def pullback(self, inputs, outputs, output_gradients): res = super().pullback(inputs, outputs, output_gradients) diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index a4a699fe21..ddcf5c0d7d 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -99,11 +99,10 @@ def test_modes(self): # regression check: # there should be # - NumbaLinker - # - `VMLinker` - # - OpWiseCLinker (FAST_RUN) - # - PerformLinker (FAST_COMPILE) + # - `VMLinker` (py, vm, cvm, FAST_COMPILE) + # - OpWiseCLinker (c|py) # - DebugMode's Linker (DEBUG_MODE) - assert 5 == len(set(linker_classes_involved)) + assert 4 == len(set(linker_classes_involved)) class TestOldModesProblem: diff --git a/tests/link/python/__init__.py b/tests/link/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/python/test_basic.py b/tests/link/python/test_basic.py new file mode 100644 index 0000000000..4826f03d67 --- /dev/null +++ b/tests/link/python/test_basic.py @@ -0,0 +1,234 @@ +import subprocess +import sys + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.compile.mode import get_mode, predefined_linkers +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.ifelse import ifelse +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.link.python.linker import PythonLinker +from pytensor.link.vm import VMLinker +from pytensor.raise_op import Assert +from pytensor.scalar.basic import Composite +from pytensor.tensor.blas import Gemv +from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.type import matrix, vector + + +def python_function(inputs, outputs, **kwargs): + return pytensor.function(inputs, outputs, mode="PYTHON", **kwargs) + + +def compare_py_and_pyjit(graph_inputs, graph_outputs, test_inputs, assert_fn=None): + """Compare the per-node ``py`` (VM) backend against the whole-graph ``pyjit``. + + Both run the same ``python_funcify`` dispatches, so this checks that the VM + per-node wiring and the JIT whole-graph composition agree. Only valid for + graphs without lazy ops, which the JIT cannot compose. + """ + if assert_fn is None: + + def assert_fn(a, b): + np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-10) + + py_fn = pytensor.function(graph_inputs, graph_outputs, mode="PYTHON") + pyjit_fn = pytensor.function(graph_inputs, graph_outputs, mode="PYJIT") + + py_res = py_fn(*test_inputs) + pyjit_res = pyjit_fn(*test_inputs) + for py_out, pyjit_out in zip( + py_res if isinstance(py_res, list) else [py_res], + pyjit_res if isinstance(pyjit_res, list) else [pyjit_res], + ): + assert_fn(py_out, pyjit_out) + return py_fn, py_res + + +def test_mode_and_linker_registered(): + # "py" is the robust per-node VM backend; "pyjit" the whole-graph JIT. + assert isinstance(predefined_linkers["py"], VMLinker) + assert isinstance(predefined_linkers["pyjit"], PythonLinker) + assert isinstance(get_mode("PYTHON").linker, VMLinker) + assert isinstance(get_mode("PYJIT").linker, PythonLinker) + + +@pytest.mark.parametrize( + "build, values, expected", + [ + pytest.param( + lambda x, y: pt.exp(x) + y * 2.0, + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: np.exp(xv) + yv * 2.0, + id="elemwise", + ), + pytest.param( + lambda x, y: x.sum() + y.mean(), + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: xv.sum() + yv.mean(), + id="reduction", + ), + pytest.param( + lambda x, y: x[1:3] - y[::-1][1:3], + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: xv[1:3] - yv[::-1][1:3], + id="subtensor", + ), + ], +) +def test_vector_graphs(build, values, expected): + x = vector("x") + y = vector("y") + fn = python_function([x, y], build(x, y)) + np.testing.assert_allclose(fn(*values), expected(*values)) + + +def test_matmul(): + A = matrix("A") + B = matrix("B") + Av = np.arange(6.0).reshape(2, 3) + Bv = np.arange(6.0).reshape(3, 2) + fn = python_function([A, B], A @ B) + np.testing.assert_allclose(fn(Av, Bv), Av @ Bv) + + +def test_multiple_outputs(): + x = vector("x") + y = vector("y") + fn = python_function([x, y], [x + y, x - y]) + xv, yv = np.arange(4.0), np.arange(4.0) + 1 + out_add, out_sub = fn(xv, yv) + np.testing.assert_allclose(out_add, xv + yv) + np.testing.assert_allclose(out_sub, xv - yv) + + +def test_constant_in_graph(): + x = vector("x") + fn = python_function([x], x + pt.constant(np.ones(4))) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv + 1.0) + + +def test_constant_only_output(): + # An output with no owner (a bare constant) must still be returned. + fn = python_function([], pt.constant(5.0)) + np.testing.assert_allclose(fn(), 5.0) + + +def test_shared_input(): + x = vector("x") + s = pytensor.shared(2.0) + fn = python_function([x], x * s) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv * 2.0) + s.set_value(3.0) + np.testing.assert_allclose(fn(xv), xv * 3.0) + + +def test_no_outputs(): + x = vector("x") + fn = python_function([x], [], on_unused_input="ignore") + assert fn(np.arange(4.0)) == [] + + +def test_fusion_excluded(): + x = vector("x") + y = vector("y") + fn = python_function([x, y], pt.exp(x) * y + pt.log(x) - y**2) + elemwise_nodes = [ + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise) + ] + # Without fusion every scalar op stays its own vectorized Elemwise node, + # rather than collapsing into a single Composite. + assert len(elemwise_nodes) > 1 + assert not any(isinstance(node.op.scalar_op, Composite) for node in elemwise_nodes) + + +def test_cxx_only_excluded(): + # The `use_c_blas` rewrite (tagged cxx_only) would turn Gemv into the + # C-only CGemv, which has no perform and would strand a pure-Python linker. + # Excluding cxx_only keeps the perform-backed Gemv. + A = matrix("A") + x = vector("x") + y = vector("y") + fn = python_function([A, x, y], 2.0 * y + 3.0 * (A @ x)) + ops = {type(node.op) for node in fn.maker.fgraph.apply_nodes} + assert Gemv in ops + assert CGemv not in ops + Av, xv, yv = np.arange(6.0).reshape(2, 3), np.arange(3.0), np.arange(2.0) + np.testing.assert_allclose(fn(Av, xv, yv), 2.0 * yv + 3.0 * (Av @ xv)) + + +def test_ifelse_lazy(): + # IfElse has no perform (only a lazy make_thunk). The py (VM) backend runs it + # via the fallback AND short-circuits it: the unused branch (which raises if + # evaluated) must not run. The whole-graph pyjit backend cannot do this. + c = pt.scalar("c") + x = vector("x") + boom = Assert("unused branch must not run")(x, pt.eq(x.sum(), -999.0)) + fn = python_function([c, x], ifelse(c > 0, x * 2.0, boom)) + np.testing.assert_allclose(fn(1.0, np.ones(3)), np.full(3, 2.0)) + + +class _PerformOnlyOp(Op): + __props__ = () + + def make_node(self, x): + x = pt.as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.square(inputs[0]) + + +def test_default_dispatch_uses_perform(): + x = vector("x") + fn = python_function([x], _PerformOnlyOp()(x)) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv**2) + + +class _DispatchedOp(Op): + __props__ = () + + def make_node(self, x): + x = pt.as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + + +@python_funcify.register(_DispatchedOp) +def _python_funcify_dispatched_op(op, node=None, **kwargs): + # A fast path distinguishable from the identity perform above. + def impl(x): + return x + 100.0 + + return impl + + +def test_registered_dispatch_overrides_perform(): + x = vector("x") + fn = python_function([x], _DispatchedOp()(x)) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv + 100.0) + + +def test_dispatch_loaded_lazily(): + # Importing pytensor must not pull in the dispatch package; it should only + # load on the first PYTHON compile. + script = ( + "import sys, pytensor, pytensor.tensor as pt;" + "mod='pytensor.link.python.dispatch.basic';" + "assert mod not in sys.modules, 'loaded too early';" + "x=pt.vector('x');" + "pytensor.function([x], x+1, mode='PYTHON');" + "assert mod in sys.modules, 'not loaded after compile'" + ) + subprocess.run([sys.executable, "-c", script], check=True) diff --git a/tests/link/python/test_blockwise.py b/tests/link/python/test_blockwise.py new file mode 100644 index 0000000000..74c2259cbf --- /dev/null +++ b/tests/link/python/test_blockwise.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.linalg.decomposition.cholesky import Cholesky +from tests.link.python.test_basic import compare_py_and_pyjit + + +def _has_cholesky(fgraph): + # The useless-Blockwise rewrite leaves a bare Cholesky for unbatched inputs. + for node in fgraph.apply_nodes: + op = node.op + if isinstance(op, Cholesky) or ( + isinstance(op, Blockwise) and isinstance(op.core_op, Cholesky) + ): + return True + return False + + +def _pd_matrices(shape, seed=0): + rng = np.random.default_rng(seed) + B = rng.standard_normal(shape) + A = B @ np.swapaxes(B, -1, -2) + shape[-1] * np.eye(shape[-1]) + return A.astype("float64") + + +@pytest.mark.parametrize("lower", [True, False]) +@pytest.mark.parametrize("shape", [(3, 3), (4, 3, 3), (2, 5, 3, 3)]) +def test_cholesky_dispatch(shape, lower): + A = pt.tensor("A", shape=(None,) * len(shape)) + out = pt.linalg.cholesky(A, lower=lower) + fn, _ = compare_py_and_pyjit([A], out, [_pd_matrices(shape)]) + assert _has_cholesky(fn.maker.fgraph) + + +@pytest.mark.parametrize("lower", [True, False]) +@pytest.mark.parametrize("b_shape", [(4,), (4, 2)]) +@pytest.mark.parametrize("batch", [(), (3,)]) +def test_solve_triangular_dispatch(batch, b_shape, lower): + rng = np.random.default_rng(1) + Av = np.tril(rng.standard_normal((*batch, 4, 4))) + 4 * np.eye(4) + if not lower: + Av = np.swapaxes(Av, -1, -2) + bv = rng.standard_normal((*batch, *b_shape)) + A = pt.tensor("A", shape=(None,) * Av.ndim) + b = pt.tensor("b", shape=(None,) * bv.ndim) + out = pt.linalg.solve_triangular(A, b, lower=lower, b_ndim=len(b_shape)) + compare_py_and_pyjit([A, b], out, [Av.astype("float64"), bv.astype("float64")]) + + +@pytest.mark.parametrize("pivoting", [False, True]) +@pytest.mark.parametrize("mode", ["full", "economic", "r", "raw"]) +@pytest.mark.parametrize("shape", [(4, 3), (3, 4), (5, 3, 4)]) +def test_qr_dispatch(shape, mode, pivoting): + rng = np.random.default_rng(0) + A = pt.tensor("A", shape=(None,) * len(shape)) + out = pt.linalg.qr(A, mode=mode, pivoting=pivoting) + compare_py_and_pyjit([A], out, [rng.standard_normal(shape)]) + + +def test_blockwise_falls_back_without_core_dispatch(): + # The general Solve has no python_funcify dispatch, so Blockwise must fall + # back to its (vectorized) perform and still match the reference. + A = pt.matrix("A") + b = pt.vector("b") + out = pt.linalg.solve(A, b) + rng = np.random.default_rng(2) + Av = rng.standard_normal((4, 4)) + 4 * np.eye(4) + bv = rng.standard_normal(4) + compare_py_and_pyjit([A, b], out, [Av, bv])