From 71d722b236e1c21f1e181acd58fc7fae1a02a1aa Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 00:21:25 -0500 Subject: [PATCH 01/10] Add overridable per-node thunk hook to VMLinker --- pytensor/link/vm.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 239f73df80..4a53c7c216 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -1203,6 +1203,16 @@ def make_vm( ) return vm + def _make_node_thunk(self, node, storage_map, compute_map, impl): + """Create the thunk for a single node. + + Subclasses override this to intercept thunk creation (e.g. to consult a + dispatch registry) before falling back to ``Op.make_thunk``. + """ + # no-recycling is done at each VM.__call__, so there is no need to cause + # duplicate C code by passing no_recycling here. + return node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) + def make_all( self, profiler=None, @@ -1230,11 +1240,8 @@ def make_all( for node in order: try: thunk_start = time.perf_counter() - # no-recycling is done at each VM.__call__ So there is - # no need to cause duplicate c code by passing - # no_recycling here. 
thunks.append( - node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) + self._make_node_thunk(node, storage_map, compute_map, impl) ) linker_make_thunk_time[node] = time.perf_counter() - thunk_start if not hasattr(thunks[-1], "lazy"): From bbdd3495f50a9edf9ae2dfc5a97c7bec76bde1f8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 00:21:33 -0500 Subject: [PATCH 02/10] Add pure-Python backend (PythonLinker + python_funcify) --- pytensor/compile/mode.py | 8 +++ pytensor/link/python/__init__.py | 1 + pytensor/link/python/dispatch/__init__.py | 6 +++ pytensor/link/python/dispatch/basic.py | 59 +++++++++++++++++++++++ pytensor/link/python/linker.py | 52 ++++++++++++++++++++ 5 files changed, 126 insertions(+) create mode 100644 pytensor/link/python/__init__.py create mode 100644 pytensor/link/python/dispatch/__init__.py create mode 100644 pytensor/link/python/dispatch/basic.py create mode 100644 pytensor/link/python/linker.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 1644712704..7d1538f4d7 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -29,6 +29,7 @@ from pytensor.link.jax.linker import JAXLinker from pytensor.link.mlx.linker import MLXLinker from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.python.linker import PythonLinker from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.vm import VMLinker @@ -52,6 +53,7 @@ "pytorch": PytorchLinker(), "numba": NumbaLinker(), "mlx": MLXLinker(), + "python": PythonLinker(), } @@ -476,6 +478,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): RewriteDatabaseQuery(include=["fast_run", "mlx"]), ) +PYTHON = Mode( + PythonLinker(), + RewriteDatabaseQuery(include=["fast_run"]), +) + FAST_COMPILE = Mode( VMLinker(use_cloop=False, c_thunks=False), RewriteDatabaseQuery(include=["fast_compile", "py_only"]), @@ -495,6 +502,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "NUMBA": NUMBA, 
"PYTORCH": PYTORCH, "MLX": MLX, + "PYTHON": PYTHON, } _CACHED_RUNTIME_MODES: dict[Any, Mode] = {} diff --git a/pytensor/link/python/__init__.py b/pytensor/link/python/__init__.py new file mode 100644 index 0000000000..105dfba68e --- /dev/null +++ b/pytensor/link/python/__init__.py @@ -0,0 +1 @@ +from pytensor.link.python.linker import PythonLinker diff --git a/pytensor/link/python/dispatch/__init__.py b/pytensor/link/python/dispatch/__init__.py new file mode 100644 index 0000000000..45121ea5b8 --- /dev/null +++ b/pytensor/link/python/dispatch/__init__.py @@ -0,0 +1,6 @@ +# isort: off +from pytensor.link.python.dispatch.basic import python_funcify + +# Load dispatch specializations +# (none yet — per-family override modules are imported here as they are added) +# isort: on diff --git a/pytensor/link/python/dispatch/basic.py b/pytensor/link/python/dispatch/basic.py new file mode 100644 index 0000000000..92a39bcc09 --- /dev/null +++ b/pytensor/link/python/dispatch/basic.py @@ -0,0 +1,59 @@ +from functools import singledispatch + + +@singledispatch +def python_funcify(op, node=None, **kwargs): + """Return a fast pure-Python implementation of ``op`` as a callable. + + The callable takes the node's inputs positionally and returns its output (a + single value, or a tuple for multi-output nodes). Register a specialization + to override an `Op`'s ``perform`` with a faster numpy/scipy path on the + Python backend. + + Unregistered ops raise `NotImplementedError`, signalling the linker to fall + back to ``perform`` (via ``Op.make_thunk(impl="py")``). + """ + raise NotImplementedError( + f"No python_funcify implementation registered for {type(op).__name__}" + ) + + +def make_node_thunk_with_python_dispatch( + node, storage_map, compute_map, *, fallback, impl +): + """Build a per-node thunk, preferring a registered `python_funcify` impl. 
+ + When `python_funcify` has a specialization for ``node.op``, its callable is + wrapped into a thunk that reads inputs from and writes outputs to + ``storage_map``. Otherwise ``fallback`` (``Op.make_thunk``) is used, which + covers ``perform`` ops and lazy ops like ``IfElse`` unchanged. + """ + try: + fn = python_funcify(node.op, node=node) + except NotImplementedError: + return fallback(node, storage_map, compute_map, impl) + + return _wrap_callable_as_thunk(fn, node, storage_map, compute_map) + + +def _wrap_callable_as_thunk(fn, node, storage_map, compute_map): + input_storage = [storage_map[v] for v in node.inputs] + output_compute = [compute_map[v] for v in node.outputs] + + if len(node.outputs) == 1: + [out_storage] = (storage_map[v] for v in node.outputs) + + def thunk(fn=fn, inputs=input_storage, out=out_storage, cm=output_compute): + out[0] = fn(*(inp[0] for inp in inputs)) + cm[0][0] = True + else: + output_storage = [storage_map[v] for v in node.outputs] + + def thunk(fn=fn, inputs=input_storage, outs=output_storage, cm=output_compute): + for storage, value in zip(outs, fn(*(inp[0] for inp in inputs))): + storage[0] = value + for entry in cm: + entry[0] = True + + thunk.lazy = False + return thunk diff --git a/pytensor/link/python/linker.py b/pytensor/link/python/linker.py new file mode 100644 index 0000000000..3f8db83c42 --- /dev/null +++ b/pytensor/link/python/linker.py @@ -0,0 +1,52 @@ +from pytensor.link.vm import VMLinker + + +class PythonLinker(VMLinker): + """A pure-Python `VMLinker` that runs each node through the `python_funcify` registry. + + Per node, a registered `python_funcify` implementation (a fast numpy/scipy + callable) is wrapped into a thunk; unregistered ops fall back to their + ``perform`` method via ``Op.make_thunk(impl="py")``. Lazy ops such as + ``IfElse`` fall through to their own thunks, so the VM still short-circuits + them. 
Fusion is excluded because fused ``Composite`` loops run slower than + vectorized numpy on this backend. + """ + + def __init__( + self, + allow_gc=None, + use_cloop=False, + callback=None, + callback_input=None, + lazy=None, + schedule=None, + c_thunks=None, + allow_partial_eval=None, + ): + # The Python backend never emits C: per-node Python thunks, Python VM. + super().__init__( + allow_gc=allow_gc, + use_cloop=False, + callback=callback, + callback_input=callback_input, + lazy=lazy, + schedule=schedule, + c_thunks=False, + allow_partial_eval=allow_partial_eval, + ) + # ``c_thunks=False`` already gives ("minimum_compile", "py_only") / + # ("cxx_only",); add fusion for the numpy backend. + self.incompatible_rewrites = ("cxx_only", "fusion") + + def _make_node_thunk(self, node, storage_map, compute_map, impl): + from pytensor.link.python.dispatch.basic import ( + make_node_thunk_with_python_dispatch, + ) + + return make_node_thunk_with_python_dispatch( + node, + storage_map, + compute_map, + fallback=super()._make_node_thunk, + impl=impl, + ) From 4b662510fd6b06d2bf2290b94e81622339e16fed Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 00:21:38 -0500 Subject: [PATCH 03/10] Add tests for Python backend --- tests/link/python/__init__.py | 0 tests/link/python/test_basic.py | 205 ++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 tests/link/python/__init__.py create mode 100644 tests/link/python/test_basic.py diff --git a/tests/link/python/__init__.py b/tests/link/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/python/test_basic.py b/tests/link/python/test_basic.py new file mode 100644 index 0000000000..8c728d8f11 --- /dev/null +++ b/tests/link/python/test_basic.py @@ -0,0 +1,205 @@ +import subprocess +import sys + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.compile.mode import get_mode, predefined_linkers +from 
pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.ifelse import ifelse +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.link.python.linker import PythonLinker +from pytensor.raise_op import Assert +from pytensor.scalar.basic import Composite +from pytensor.tensor.blas import Gemv +from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.type import matrix, vector + + +def python_function(inputs, outputs, **kwargs): + return pytensor.function(inputs, outputs, mode="PYTHON", **kwargs) + + +def test_mode_and_linker_registered(): + assert isinstance(predefined_linkers["python"], PythonLinker) + assert isinstance(get_mode("PYTHON").linker, PythonLinker) + + +@pytest.mark.parametrize( + "build, values, expected", + [ + pytest.param( + lambda x, y: pt.exp(x) + y * 2.0, + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: np.exp(xv) + yv * 2.0, + id="elemwise", + ), + pytest.param( + lambda x, y: x.sum() + y.mean(), + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: xv.sum() + yv.mean(), + id="reduction", + ), + pytest.param( + lambda x, y: x[1:3] - y[::-1][1:3], + (np.arange(4.0), np.arange(4.0) + 1), + lambda xv, yv: xv[1:3] - yv[::-1][1:3], + id="subtensor", + ), + ], +) +def test_vector_graphs(build, values, expected): + x = vector("x") + y = vector("y") + fn = python_function([x, y], build(x, y)) + np.testing.assert_allclose(fn(*values), expected(*values)) + + +def test_matmul(): + A = matrix("A") + B = matrix("B") + Av = np.arange(6.0).reshape(2, 3) + Bv = np.arange(6.0).reshape(3, 2) + fn = python_function([A, B], A @ B) + np.testing.assert_allclose(fn(Av, Bv), Av @ Bv) + + +def test_multiple_outputs(): + x = vector("x") + y = vector("y") + fn = python_function([x, y], [x + y, x - y]) + xv, yv = np.arange(4.0), np.arange(4.0) + 1 + out_add, out_sub = fn(xv, yv) + np.testing.assert_allclose(out_add, xv + yv) + 
np.testing.assert_allclose(out_sub, xv - yv) + + +def test_constant_in_graph(): + x = vector("x") + fn = python_function([x], x + pt.constant(np.ones(4))) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv + 1.0) + + +def test_constant_only_output(): + # Exercises a graph whose only output is a constant with no owner node. + fn = python_function([], pt.constant(5.0)) + np.testing.assert_allclose(fn(), 5.0) + + +def test_shared_input(): + x = vector("x") + s = pytensor.shared(2.0) + fn = python_function([x], x * s) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv * 2.0) + s.set_value(3.0) + np.testing.assert_allclose(fn(xv), xv * 3.0) + + +def test_no_outputs(): + x = vector("x") + fn = python_function([x], [], on_unused_input="ignore") + assert fn(np.arange(4.0)) == [] + + +def test_fusion_excluded(): + x = vector("x") + y = vector("y") + fn = python_function([x, y], pt.exp(x) * y + pt.log(x) - y**2) + elemwise_nodes = [ + node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise) + ] + # Without fusion every scalar op stays its own vectorized Elemwise node, + # rather than collapsing into a single Composite. + assert len(elemwise_nodes) > 1 + assert not any(isinstance(node.op.scalar_op, Composite) for node in elemwise_nodes) + + +def test_cxx_only_excluded(): + # The `use_c_blas` rewrite (tagged cxx_only) would turn Gemv into the + # C-only CGemv, which has no perform and would strand a pure-Python linker. + # Excluding cxx_only keeps the perform-backed Gemv. + A = matrix("A") + x = vector("x") + y = vector("y") + fn = python_function([A, x, y], 2.0 * y + 3.0 * (A @ x)) + ops = {type(node.op) for node in fn.maker.fgraph.apply_nodes} + assert Gemv in ops + assert CGemv not in ops + Av, xv, yv = np.arange(6.0).reshape(2, 3), np.arange(3.0), np.arange(2.0) + np.testing.assert_allclose(fn(Av, xv, yv), 2.0 * yv + 3.0 * (Av @ xv)) + + +def test_ifelse_lazy(): + # IfElse has no perform (only a lazy make_thunk).
The VM backend runs it via + # the fallback AND short-circuits it: the unused branch (which raises if + # evaluated) must not run. + c = pt.scalar("c") + x = vector("x") + boom = Assert("unused branch must not run")(x, pt.eq(x.sum(), -999.0)) + fn = python_function([c, x], ifelse(c > 0, x * 2.0, boom)) + np.testing.assert_allclose(fn(1.0, np.ones(3)), np.full(3, 2.0)) + + +class _PerformOnlyOp(Op): + __props__ = () + + def make_node(self, x): + x = pt.as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.square(inputs[0]) + + +def test_default_dispatch_uses_perform(): + x = vector("x") + fn = python_function([x], _PerformOnlyOp()(x)) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv**2) + + +class _DispatchedOp(Op): + __props__ = () + + def make_node(self, x): + x = pt.as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + + +@python_funcify.register(_DispatchedOp) +def _python_funcify_dispatched_op(op, node=None, **kwargs): + # A fast path distinguishable from the identity perform above. + def impl(x): + return x + 100.0 + + return impl + + +def test_registered_dispatch_overrides_perform(): + x = vector("x") + fn = python_function([x], _DispatchedOp()(x)) + xv = np.arange(4.0) + np.testing.assert_allclose(fn(xv), xv + 100.0) + + +def test_dispatch_loaded_lazily(): + # Importing pytensor must not pull in the dispatch package; it should only + # load on the first PYTHON compile. 
+ script = ( + "import sys, pytensor, pytensor.tensor as pt;" + "mod='pytensor.link.python.dispatch.basic';" + "assert mod not in sys.modules, 'loaded too early';" + "x=pt.vector('x');" + "pytensor.function([x], x+1, mode='PYTHON');" + "assert mod in sys.modules, 'not loaded after compile'" + ) + subprocess.run([sys.executable, "-c", script], check=True) From ece9045f3e4579f2e6d6c3a4f972c55b7241fe5b Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 00:29:57 -0500 Subject: [PATCH 04/10] "py" -> PythonLinker, "perform" -> PerformLinker --- pytensor/compile/mode.py | 4 ++-- tests/link/python/test_basic.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 7d1538f4d7..820b75ec23 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -41,7 +41,8 @@ # Mode, it will be used as the key to retrieve the real linker in this # dictionary predefined_linkers = { - "py": PerformLinker(), # Use allow_gc PyTensor flag + "py": PythonLinker(), # Pure-Python backend with the python_funcify dispatch + "perform": PerformLinker(), # Per-node reference: runs every Op's perform method "c": CLinker(), # Don't support gc. 
so don't check allow_gc "c|py": OpWiseCLinker(), # Use allow_gc PyTensor flag "c|py_nogc": OpWiseCLinker(allow_gc=False), @@ -53,7 +54,6 @@ "pytorch": PytorchLinker(), "numba": NumbaLinker(), "mlx": MLXLinker(), - "python": PythonLinker(), } diff --git a/tests/link/python/test_basic.py b/tests/link/python/test_basic.py index 8c728d8f11..5fc9ccabb4 100644 --- a/tests/link/python/test_basic.py +++ b/tests/link/python/test_basic.py @@ -10,6 +10,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.ifelse import ifelse +from pytensor.link.basic import PerformLinker from pytensor.link.python.dispatch.basic import python_funcify from pytensor.link.python.linker import PythonLinker from pytensor.raise_op import Assert @@ -25,7 +26,9 @@ def python_function(inputs, outputs, **kwargs): def test_mode_and_linker_registered(): - assert isinstance(predefined_linkers["python"], PythonLinker) + # "py" is the pure-Python backend; "perform" is the per-node reference. + assert isinstance(predefined_linkers["py"], PythonLinker) + assert isinstance(predefined_linkers["perform"], PerformLinker) assert isinstance(get_mode("PYTHON").linker, PythonLinker) From 644eb6b9694d4741a503da8108397052cdeae88a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 01:31:35 -0500 Subject: [PATCH 05/10] Add Python backend Blockwise and Cholesky dispatches --- pytensor/link/python/dispatch/__init__.py | 4 +- pytensor/link/python/dispatch/blockwise.py | 17 +++++++ .../link/python/dispatch/linalg/__init__.py | 3 ++ .../dispatch/linalg/decomposition/__init__.py | 3 ++ .../dispatch/linalg/decomposition/cholesky.py | 32 +++++++++++++ .../tensor/linalg/decomposition/cholesky.py | 37 ++------------- tests/link/python/test_basic.py | 32 ++++++++++++- tests/link/python/test_blockwise.py | 46 +++++++++++++++++++ 8 files changed, 140 insertions(+), 34 deletions(-) create mode 100644 pytensor/link/python/dispatch/blockwise.py create mode 100644 
pytensor/link/python/dispatch/linalg/__init__.py create mode 100644 pytensor/link/python/dispatch/linalg/decomposition/__init__.py create mode 100644 pytensor/link/python/dispatch/linalg/decomposition/cholesky.py create mode 100644 tests/link/python/test_blockwise.py diff --git a/pytensor/link/python/dispatch/__init__.py b/pytensor/link/python/dispatch/__init__.py index 45121ea5b8..b885f9a876 100644 --- a/pytensor/link/python/dispatch/__init__.py +++ b/pytensor/link/python/dispatch/__init__.py @@ -2,5 +2,7 @@ from pytensor.link.python.dispatch.basic import python_funcify # Load dispatch specializations -# (none yet — per-family override modules are imported here as they are added) +import pytensor.link.python.dispatch.blockwise +import pytensor.link.python.dispatch.linalg + # isort: on diff --git a/pytensor/link/python/dispatch/blockwise.py b/pytensor/link/python/dispatch/blockwise.py new file mode 100644 index 0000000000..87c84faba1 --- /dev/null +++ b/pytensor/link/python/dispatch/blockwise.py @@ -0,0 +1,17 @@ +import numpy as np + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.blockwise import Blockwise + + +@python_funcify.register(Blockwise) +def python_funcify_Blockwise(op, node=None, **kwargs): + core_node = op._create_dummy_core_node( + node.inputs, propagate_unbatched_core_inputs=True + ) + # Raises NotImplementedError when the core Op has no dispatch, which makes the + # whole Blockwise fall back to its (vectorized) perform. 
+ core_fn = python_funcify(op.core_op, node=core_node) + + out_dtypes = [out.type.dtype for out in node.outputs] + return np.vectorize(core_fn, signature=op.signature, otypes=out_dtypes) diff --git a/pytensor/link/python/dispatch/linalg/__init__.py b/pytensor/link/python/dispatch/linalg/__init__.py new file mode 100644 index 0000000000..b30048fdac --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/__init__.py @@ -0,0 +1,3 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.decomposition +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/decomposition/__init__.py b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py new file mode 100644 index 0000000000..3d474af89d --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py @@ -0,0 +1,3 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.decomposition.cholesky +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py b/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py new file mode 100644 index 0000000000..9042f88336 --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/cholesky.py @@ -0,0 +1,32 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.decomposition.cholesky import Cholesky + + +@python_funcify.register(Cholesky) +def python_funcify_Cholesky(op, node=None, **kwargs): + lower = op.lower + overwrite_a = op.overwrite_a + (potrf,) = get_lapack_funcs(("potrf",), dtype=node.inputs[0].type.dtype) + + def cholesky(x): + if x.size == 0: + return np.empty_like(x) + + # potrf only honors overwrite_a for F-contiguous input; transpose a + # C-contiguous array to benefit from it. 
+ c_contiguous_input = overwrite_a and x.flags["C_CONTIGUOUS"] + if c_contiguous_input: + x = x.T + factor, info = potrf(x, lower=not lower, overwrite_a=True, clean=True) + factor = factor.T + else: + factor, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) + + if info != 0: + factor[...] = np.nan + return factor + + return cholesky diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 0fa7a34b3b..816e218b79 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -48,39 +48,12 @@ def make_node(self, x): def perform(self, node, inputs, outputs): [x] = inputs [out] = outputs - - (potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) - - # Quick return for square empty array - if x.size == 0: - out[0] = np.empty_like(x, dtype=potrf.dtype) - return - - # Squareness check - if x.shape[0] != x.shape[1]: - raise ValueError( - f"Input array is expected to be square but has the shape: {x.shape}." + try: + out[0] = scipy_linalg.cholesky( + x, lower=self.lower, overwrite_a=self.overwrite_a, check_finite=False ) - - # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS - # If we have a `C_CONTIGUOUS` array we transpose to benefit from it - c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"] - if c_contiguous_input: - x = x.T - lower = not self.lower - overwrite_a = True - else: - lower = self.lower - overwrite_a = self.overwrite_a - - c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) - - if info != 0: - c[...] 
= np.nan - out[0] = c - else: - # Transpose result if input was transposed - out[0] = c.T if c_contiguous_input else c + except scipy_linalg.LinAlgError: + out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) def pullback(self, inputs, outputs, gradients): """ diff --git a/tests/link/python/test_basic.py b/tests/link/python/test_basic.py index 5fc9ccabb4..45379d2b87 100644 --- a/tests/link/python/test_basic.py +++ b/tests/link/python/test_basic.py @@ -6,7 +6,7 @@ import pytensor import pytensor.tensor as pt -from pytensor.compile.mode import get_mode, predefined_linkers +from pytensor.compile.mode import Mode, get_mode, predefined_linkers from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.ifelse import ifelse @@ -25,6 +25,36 @@ def python_function(inputs, outputs, **kwargs): return pytensor.function(inputs, outputs, mode="PYTHON", **kwargs) +perform_mode = Mode(linker="perform", optimizer="fast_run") + + +def compare_python_and_perform( + graph_inputs, graph_outputs, test_inputs, assert_fn=None +): + """Compare a PYTHON-backend dispatch against the ``perform`` reference. + + Compiles ``graph_outputs`` under the Python backend (which exercises any + registered ``python_funcify`` dispatch) and under the ``perform`` linker (the + reference that runs every Op's ``perform``), then asserts they agree. 
+ """ + if assert_fn is None: + + def assert_fn(a, b): + np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-10) + + py_fn = pytensor.function(graph_inputs, graph_outputs, mode="PYTHON") + perform_fn = pytensor.function(graph_inputs, graph_outputs, mode=perform_mode) + + py_res = py_fn(*test_inputs) + perform_res = perform_fn(*test_inputs) + for py_out, perform_out in zip( + py_res if isinstance(py_res, list) else [py_res], + perform_res if isinstance(perform_res, list) else [perform_res], + ): + assert_fn(py_out, perform_out) + return py_fn, py_res + + def test_mode_and_linker_registered(): # "py" is the pure-Python backend; "perform" is the per-node reference. assert isinstance(predefined_linkers["py"], PythonLinker) diff --git a/tests/link/python/test_blockwise.py b/tests/link/python/test_blockwise.py new file mode 100644 index 0000000000..e2bb0fc1f5 --- /dev/null +++ b/tests/link/python/test_blockwise.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.linalg.decomposition.cholesky import Cholesky +from tests.link.python.test_basic import compare_python_and_perform + + +def _has_cholesky(fgraph): + # The useless-Blockwise rewrite leaves a bare Cholesky for unbatched inputs. 
+ for node in fgraph.apply_nodes: + op = node.op + if isinstance(op, Cholesky) or ( + isinstance(op, Blockwise) and isinstance(op.core_op, Cholesky) + ): + return True + return False + + +def _pd_matrices(shape, seed=0): + rng = np.random.default_rng(seed) + B = rng.standard_normal(shape) + A = B @ np.swapaxes(B, -1, -2) + shape[-1] * np.eye(shape[-1]) + return A.astype("float64") + + +@pytest.mark.parametrize("lower", [True, False]) +@pytest.mark.parametrize("shape", [(3, 3), (4, 3, 3), (2, 5, 3, 3)]) +def test_cholesky_dispatch(shape, lower): + A = pt.tensor("A", shape=(None,) * len(shape)) + out = pt.linalg.cholesky(A, lower=lower) + fn, _ = compare_python_and_perform([A], out, [_pd_matrices(shape)]) + assert _has_cholesky(fn.maker.fgraph) + + +def test_blockwise_falls_back_without_core_dispatch(): + # The general Solve has no python_funcify dispatch, so Blockwise must fall + # back to its (vectorized) perform and still match the reference. + A = pt.matrix("A") + b = pt.vector("b") + out = pt.linalg.solve(A, b) + rng = np.random.default_rng(2) + Av = rng.standard_normal((4, 4)) + 4 * np.eye(4) + bv = rng.standard_normal(4) + compare_python_and_perform([A, b], out, [Av, bv]) From df4f69c1fb521cf25dfa993f2914c373d94858f9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 01:32:03 -0500 Subject: [PATCH 06/10] Add Python backend SolveTriangular dispatch --- .../link/python/dispatch/linalg/__init__.py | 1 + .../dispatch/linalg/solvers/__init__.py | 3 ++ .../dispatch/linalg/solvers/triangular.py | 43 +++++++++++++++++++ pytensor/tensor/linalg/solvers/triangular.py | 39 +++-------------- tests/link/python/test_blockwise.py | 17 ++++++++ 5 files changed, 70 insertions(+), 33 deletions(-) create mode 100644 pytensor/link/python/dispatch/linalg/solvers/__init__.py create mode 100644 pytensor/link/python/dispatch/linalg/solvers/triangular.py diff --git a/pytensor/link/python/dispatch/linalg/__init__.py b/pytensor/link/python/dispatch/linalg/__init__.py 
index b30048fdac..f5165bc0f1 100644 --- a/pytensor/link/python/dispatch/linalg/__init__.py +++ b/pytensor/link/python/dispatch/linalg/__init__.py @@ -1,3 +1,4 @@ # isort: off import pytensor.link.python.dispatch.linalg.decomposition +import pytensor.link.python.dispatch.linalg.solvers # isort: on diff --git a/pytensor/link/python/dispatch/linalg/solvers/__init__.py b/pytensor/link/python/dispatch/linalg/solvers/__init__.py new file mode 100644 index 0000000000..d2a7665d4b --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/solvers/__init__.py @@ -0,0 +1,3 @@ +# isort: off +import pytensor.link.python.dispatch.linalg.solvers.triangular +# isort: on diff --git a/pytensor/link/python/dispatch/linalg/solvers/triangular.py b/pytensor/link/python/dispatch/linalg/solvers/triangular.py new file mode 100644 index 0000000000..6b9ac2c4bc --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/solvers/triangular.py @@ -0,0 +1,43 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.solvers.triangular import SolveTriangular + + +@python_funcify.register(SolveTriangular) +def python_funcify_SolveTriangular(op, node=None, **kwargs): + lower = op.lower + unit_diagonal = op.unit_diagonal + overwrite_b = op.overwrite_b + (trtrs,) = get_lapack_funcs(("trtrs",), dtype=node.outputs[0].type.dtype) + + def solve_triangular(A, b): + if b.size == 0: + return np.empty_like(b) + + if A.flags["F_CONTIGUOUS"]: + x, info = trtrs( + A, + b, + overwrite_b=overwrite_b, + lower=lower, + trans=0, + unitdiag=unit_diagonal, + ) + else: + # trtrs expects Fortran ordering, so solve the transposed system. + x, info = trtrs( + A.T, + b, + overwrite_b=overwrite_b, + lower=not lower, + trans=1, + unitdiag=unit_diagonal, + ) + + if info != 0: + x[...] 
= np.nan + return x + + return solve_triangular diff --git a/pytensor/tensor/linalg/solvers/triangular.py b/pytensor/tensor/linalg/solvers/triangular.py index e3adee0612..9d8eb1172c 100644 --- a/pytensor/tensor/linalg/solvers/triangular.py +++ b/pytensor/tensor/linalg/solvers/triangular.py @@ -31,44 +31,17 @@ def __init__(self, *, unit_diagonal=False, **kwargs): def perform(self, node, inputs, outputs): A, b = inputs - - if A.ndim != 2 or A.shape[0] != A.shape[1]: - raise ValueError("expected square matrix") - - if A.shape[0] != b.shape[0]: - raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible") - - (trtrs,) = scipy_linalg.get_lapack_funcs(("trtrs",), (A, b)) - - # Quick return for empty arrays - if b.size == 0: - outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype) - return - - if A.flags["F_CONTIGUOUS"]: - x, info = trtrs( + try: + outputs[0][0] = scipy_linalg.solve_triangular( A, b, - overwrite_b=self.overwrite_b, lower=self.lower, - trans=0, - unitdiag=self.unit_diagonal, - ) - else: - # transposed system is solved since trtrs expects Fortran ordering - x, info = trtrs( - A.T, - b, + unit_diagonal=self.unit_diagonal, overwrite_b=self.overwrite_b, - lower=not self.lower, - trans=1, - unitdiag=self.unit_diagonal, + check_finite=False, ) - - if info != 0: - x[...] 
= np.nan - - outputs[0][0] = x + except scipy_linalg.LinAlgError: + outputs[0][0] = np.full(b.shape, np.nan, dtype=node.outputs[0].type.dtype) def pullback(self, inputs, outputs, output_gradients): res = super().pullback(inputs, outputs, output_gradients) diff --git a/tests/link/python/test_blockwise.py b/tests/link/python/test_blockwise.py index e2bb0fc1f5..16e8bbb694 100644 --- a/tests/link/python/test_blockwise.py +++ b/tests/link/python/test_blockwise.py @@ -34,6 +34,23 @@ def test_cholesky_dispatch(shape, lower): assert _has_cholesky(fn.maker.fgraph) +@pytest.mark.parametrize("lower", [True, False]) +@pytest.mark.parametrize("b_shape", [(4,), (4, 2)]) +@pytest.mark.parametrize("batch", [(), (3,)]) +def test_solve_triangular_dispatch(batch, b_shape, lower): + rng = np.random.default_rng(1) + Av = np.tril(rng.standard_normal((*batch, 4, 4))) + 4 * np.eye(4) + if not lower: + Av = np.swapaxes(Av, -1, -2) + bv = rng.standard_normal((*batch, *b_shape)) + A = pt.tensor("A", shape=(None,) * Av.ndim) + b = pt.tensor("b", shape=(None,) * bv.ndim) + out = pt.linalg.solve_triangular(A, b, lower=lower, b_ndim=len(b_shape)) + compare_python_and_perform( + [A, b], out, [Av.astype("float64"), bv.astype("float64")] + ) + + def test_blockwise_falls_back_without_core_dispatch(): # The general Solve has no python_funcify dispatch, so Blockwise must fall # back to its (vectorized) perform and still match the reference. From bc08d09635139ee82e43386d031d28010a7bb1cc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 01:32:35 -0500 Subject: [PATCH 07/10] Add Python backend QR dispatch and fix QR raw output shapes The "raw" mode H factor and R were declared (M,M) and (M,N) but are actually (M,N) and (K,N), which np.vectorize rejects for M < N. 
--- .../dispatch/linalg/decomposition/__init__.py | 1 + .../dispatch/linalg/decomposition/qr.py | 54 ++++++++++++ pytensor/tensor/linalg/decomposition/qr.py | 87 ++++--------------- tests/link/python/test_blockwise.py | 10 +++ 4 files changed, 80 insertions(+), 72 deletions(-) create mode 100644 pytensor/link/python/dispatch/linalg/decomposition/qr.py diff --git a/pytensor/link/python/dispatch/linalg/decomposition/__init__.py b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py index 3d474af89d..c1bfeab072 100644 --- a/pytensor/link/python/dispatch/linalg/decomposition/__init__.py +++ b/pytensor/link/python/dispatch/linalg/decomposition/__init__.py @@ -1,3 +1,4 @@ # isort: off import pytensor.link.python.dispatch.linalg.decomposition.cholesky +import pytensor.link.python.dispatch.linalg.decomposition.qr # isort: on diff --git a/pytensor/link/python/dispatch/linalg/decomposition/qr.py b/pytensor/link/python/dispatch/linalg/decomposition/qr.py new file mode 100644 index 0000000000..903a3b6c9b --- /dev/null +++ b/pytensor/link/python/dispatch/linalg/decomposition/qr.py @@ -0,0 +1,54 @@ +import numpy as np +from scipy.linalg import get_lapack_funcs + +from pytensor.link.python.dispatch.basic import python_funcify +from pytensor.tensor.linalg.decomposition.qr import QR + + +@python_funcify.register(QR) +def python_funcify_QR(op, node=None, **kwargs): + mode = op.mode + pivoting = op.pivoting + overwrite_a = op.overwrite_a + call_and_get_lwork = op._call_and_get_lwork + + def qr(x): + M, N = x.shape + + if pivoting: + (geqp3,) = get_lapack_funcs(("geqp3",), (x,)) + factor, jpvt, tau, *_ = call_and_get_lwork( + geqp3, x, lwork=-1, overwrite_a=overwrite_a + ) + jpvt -= 1 # geqp3 returns 1-based indices + else: + (geqrf,) = get_lapack_funcs(("geqrf",), (x,)) + factor, tau, *_ = call_and_get_lwork( + geqrf, x, lwork=-1, overwrite_a=overwrite_a + ) + + if mode not in ("economic", "raw") or M < N: + R = np.triu(factor) + else: + R = np.triu(factor[:N, :]) + + if 
mode == "r": + return (R, jpvt) if pivoting else R + if mode == "raw": + return (factor, tau, R, jpvt) if pivoting else (factor, tau, R) + + (orgqr,) = get_lapack_funcs(("orgqr",), (factor,)) + if M < N: + Q, *_ = call_and_get_lwork( + orgqr, factor[:, :M], tau, lwork=-1, overwrite_a=1 + ) + elif mode == "economic": + Q, *_ = call_and_get_lwork(orgqr, factor, tau, lwork=-1, overwrite_a=1) + else: + square = np.empty((M, M), dtype=factor.dtype.char) + square[:, :N] = factor + Q, *_ = call_and_get_lwork(orgqr, square, tau, lwork=-1, overwrite_a=1) + + return (Q, R, jpvt) if pivoting else (Q, R) + + return qr diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 9e9270259d..5ac6f5626f 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -1,7 +1,5 @@ from typing import Literal, cast -import numpy as np - from pytensor import ifelse from pytensor import tensor as pt from pytensor.gradient import DisconnectedType, disconnected_type @@ -51,7 +49,7 @@ def __init__( case "r": self.gufunc_signature = "(m,n)->(m,n)" case "raw": - self.gufunc_signature = "(m,n)->(n,m),(k),(m,n)" + self.gufunc_signature = "(m,n)->(m,n),(k),(k,n)" case _: raise ValueError( f"Invalid mode '{mode}'. Supported modes are 'full', 'economic', 'r', and 'raw'." 
@@ -91,9 +89,9 @@ def make_node(self, x): ] case "raw": outputs = [ - tensor(shape=(M, M), dtype=out_dtype), - tensor(shape=(K,), dtype=out_dtype), tensor(shape=(M, N), dtype=out_dtype), + tensor(shape=(K,), dtype=out_dtype), + tensor(shape=(K, N), dtype=out_dtype), ] case _: raise NotImplementedError @@ -124,9 +122,9 @@ def infer_shape(self, fgraph, node, shapes): case "r": R_shape = (M, N) case "raw": - Q_shape = (M, M) # Actually this is H in this case + Q_shape = (M, N) # Actually this is H in this case tau_shape = (K,) - R_shape = (M, N) + R_shape = (K, N) if self.pivoting: P_shape = (N,) @@ -153,72 +151,17 @@ def _call_and_get_lwork(self, fn, *args, lwork, **kwargs): def perform(self, node, inputs, outputs): (x,) = inputs - M, N = x.shape - - if self.pivoting: - (geqp3,) = scipy_linalg.get_lapack_funcs(("geqp3",), (x,)) - qr, jpvt, tau, *_work_info = self._call_and_get_lwork( - geqp3, x, lwork=-1, overwrite_a=self.overwrite_a - ) - jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1 - else: - (geqrf,) = scipy_linalg.get_lapack_funcs(("geqrf",), (x,)) - qr, tau, *_work_info = self._call_and_get_lwork( - geqrf, x, lwork=-1, overwrite_a=self.overwrite_a - ) - - if self.mode not in ["economic", "raw"] or M < N: - R = np.triu(qr) - else: - R = np.triu(qr[:N, :]) - - if self.mode == "r" and self.pivoting: - outputs[0][0] = R - outputs[1][0] = jpvt - return - - elif self.mode == "r": - outputs[0][0] = R - return - - elif self.mode == "raw" and self.pivoting: - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R - outputs[3][0] = jpvt - return - - elif self.mode == "raw": - outputs[0][0] = qr - outputs[1][0] = tau - outputs[2][0] = R - return - - (gor_un_gqr,) = scipy_linalg.get_lapack_funcs(("orgqr",), (qr,)) - - if M < N: - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qr[:, :M], tau, lwork=-1, overwrite_a=1 - ) - elif self.mode == "economic": - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qr, tau, lwork=-1, 
overwrite_a=1 - ) + result = scipy_linalg.qr( + x, mode=self.mode, pivoting=self.pivoting, overwrite_a=self.overwrite_a + ) + if self.mode == "raw": + # scipy nests the Householder reflectors as ((qr, tau), R[, jpvt]). + (factor, tau), *rest = result + values = [factor, tau, *rest] else: - t = qr.dtype.char - qqr = np.empty((M, M), dtype=t) - qqr[:, :N] = qr - - # Always overwite qqr -- it's a meaningless intermediate value - Q, _work, _info = self._call_and_get_lwork( - gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1 - ) - - outputs[0][0] = Q - outputs[1][0] = R - - if self.pivoting: - outputs[2][0] = jpvt + values = list(result) + for storage, value in zip(outputs, values): + storage[0] = value def pullback(self, inputs, outputs, output_grads): """ diff --git a/tests/link/python/test_blockwise.py b/tests/link/python/test_blockwise.py index 16e8bbb694..75034a841a 100644 --- a/tests/link/python/test_blockwise.py +++ b/tests/link/python/test_blockwise.py @@ -51,6 +51,16 @@ def test_solve_triangular_dispatch(batch, b_shape, lower): ) +@pytest.mark.parametrize("pivoting", [False, True]) +@pytest.mark.parametrize("mode", ["full", "economic", "r", "raw"]) +@pytest.mark.parametrize("shape", [(4, 3), (3, 4), (5, 3, 4)]) +def test_qr_dispatch(shape, mode, pivoting): + rng = np.random.default_rng(0) + A = pt.tensor("A", shape=(None,) * len(shape)) + out = pt.linalg.qr(A, mode=mode, pivoting=pivoting) + compare_python_and_perform([A], out, [rng.standard_normal(shape)]) + + def test_blockwise_falls_back_without_core_dispatch(): # The general Solve has no python_funcify dispatch, so Blockwise must fall # back to its (vectorized) perform and still match the reference. 
From c2fa96d22cb41726a7cf736becd96538e1fb4f22 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 16:03:42 -0500 Subject: [PATCH 08/10] Route VMLinker py thunks through the python_funcify dispatch Promotes the per-subclass routing PythonLinker did into the VMLinker default, so every pure-Python VM path (pyvm, FAST_COMPILE, cxx-less vm) prefers a registered python_funcify before falling back to perform. --- pytensor/link/vm.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 4a53c7c216..2938e09aa2 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -1206,9 +1206,26 @@ def make_vm( def _make_node_thunk(self, node, storage_map, compute_map, impl): """Create the thunk for a single node. - Subclasses override this to intercept thunk creation (e.g. to consult a - dispatch registry) before falling back to ``Op.make_thunk``. + On the pure-Python path (``impl == "py"``) a registered + ``python_funcify`` implementation is wrapped into a thunk; otherwise the + node's ``Op.make_thunk`` is used (which also covers lazy ops like + ``IfElse``). """ + if impl == "py": + from pytensor.link.python.dispatch.basic import ( + make_node_thunk_with_python_dispatch, + ) + + return make_node_thunk_with_python_dispatch( + node, + storage_map, + compute_map, + fallback=self._make_perform_thunk, + impl=impl, + ) + return self._make_perform_thunk(node, storage_map, compute_map, impl) + + def _make_perform_thunk(self, node, storage_map, compute_map, impl): # no-recycling is done at each VM.__call__, so there is no need to cause # duplicate C code by passing no_recycling here. 
return node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) From 72bbd87ccd2b67d087fb9ed7e340754fb6f0d663 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 17:05:10 -0500 Subject: [PATCH 09/10] Add whole-graph JIT linker (pyjit), keep py the robust VM PythonLinker is now a whole-graph JITLinker registered as pyjit; py is a plain VMLinker -- the robust per-node Python reference (lazy ops, perform fallback). The redundant perform linker is dropped. --- pytensor/compile/mode.py | 16 +++-- pytensor/link/python/dispatch/basic.py | 72 +++++++++++++++++----- pytensor/link/python/dispatch/blockwise.py | 4 +- pytensor/link/python/linker.py | 67 ++++++-------------- pytensor/link/vm.py | 24 +++++--- tests/compile/test_mode.py | 7 +-- tests/link/python/test_basic.py | 46 +++++++------- tests/link/python/test_blockwise.py | 12 ++-- 8 files changed, 133 insertions(+), 115 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 820b75ec23..b145593fc5 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -24,7 +24,7 @@ SequenceDB, TopoDB, ) -from pytensor.link.basic import Linker, PerformLinker +from pytensor.link.basic import Linker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.jax.linker import JAXLinker from pytensor.link.mlx.linker import MLXLinker @@ -41,8 +41,10 @@ # Mode, it will be used as the key to retrieve the real linker in this # dictionary predefined_linkers = { - "py": PythonLinker(), # Pure-Python backend with the python_funcify dispatch - "perform": PerformLinker(), # Per-node reference: runs every Op's perform method + "py": VMLinker( # Robust per-node Python VM over python_funcify; handles lazy ops + use_cloop=False, c_thunks=False + ), + "pyjit": PythonLinker(), # Whole-graph python_funcify composition; no lazy ops "c": CLinker(), # Don't support gc. 
so don't check allow_gc "c|py": OpWiseCLinker(), # Use allow_gc PyTensor flag "c|py_nogc": OpWiseCLinker(allow_gc=False), @@ -479,8 +481,13 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ) PYTHON = Mode( + VMLinker(use_cloop=False, c_thunks=False), + RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"), +) + +PYJIT = Mode( PythonLinker(), - RewriteDatabaseQuery(include=["fast_run"]), + RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"), ) FAST_COMPILE = Mode( @@ -503,6 +510,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "PYTORCH": PYTORCH, "MLX": MLX, "PYTHON": PYTHON, + "PYJIT": PYJIT, } _CACHED_RUNTIME_MODES: dict[Any, Mode] = {} diff --git a/pytensor/link/python/dispatch/basic.py b/pytensor/link/python/dispatch/basic.py index 92a39bcc09..529eaf312a 100644 --- a/pytensor/link/python/dispatch/basic.py +++ b/pytensor/link/python/dispatch/basic.py @@ -1,5 +1,8 @@ from functools import singledispatch +from pytensor.graph.fg import AbstractFunctionGraph +from pytensor.link.utils import fgraph_to_python + @singledispatch def python_funcify(op, node=None, **kwargs): @@ -10,18 +13,53 @@ def python_funcify(op, node=None, **kwargs): to override an `Op`'s ``perform`` with a faster numpy/scipy path on the Python backend. - Unregistered ops raise `NotImplementedError`, signalling the linker to fall - back to ``perform`` (via ``Op.make_thunk(impl="py")``). + Unregistered ops raise `NotImplementedError`. The ``py`` (VM) linker catches + this and falls back to ``Op.make_thunk(impl="py")``; the whole-graph ``pyjit`` + linker catches it and falls back to a ``perform`` wrapper. 
""" raise NotImplementedError( f"No python_funcify implementation registered for {type(op).__name__}" ) +def _perform_wrapper(op, node): + """Wrap an `Op`'s ``perform`` into a `python_funcify`-style callable.""" + n_outputs = len(node.outputs) + single_output = n_outputs == 1 + + def perform(*inputs): + output_storage = [[None] for _ in range(n_outputs)] + op.perform(node, list(inputs), output_storage) + if single_output: + return output_storage[0][0] + return tuple(storage[0] for storage in output_storage) + + return perform + + +def _funcify_or_perform(op, node=None, **kwargs): + try: + return python_funcify(op, node=node) + except NotImplementedError: + return _perform_wrapper(op, node) + + +@python_funcify.register(AbstractFunctionGraph) +def python_funcify_FunctionGraph( + fgraph, node=None, fgraph_name="python_funcified_fgraph", **kwargs +): + return fgraph_to_python( + fgraph, + op_conversion_fn=_funcify_or_perform, + fgraph_name=fgraph_name, + **kwargs, + ) + + def make_node_thunk_with_python_dispatch( - node, storage_map, compute_map, *, fallback, impl + node, storage_map, compute_map, *, fallback, implementation ): - """Build a per-node thunk, preferring a registered `python_funcify` impl. + """Build a per-node thunk, preferring a registered `python_funcify` implementation. 
When `python_funcify` has a specialization for ``node.op``, its callable is wrapped into a thunk that reads inputs from and writes outputs to @@ -31,28 +69,30 @@ def make_node_thunk_with_python_dispatch( try: fn = python_funcify(node.op, node=node) except NotImplementedError: - return fallback(node, storage_map, compute_map, impl) + return fallback(node, storage_map, compute_map, implementation) return _wrap_callable_as_thunk(fn, node, storage_map, compute_map) def _wrap_callable_as_thunk(fn, node, storage_map, compute_map): - input_storage = [storage_map[v] for v in node.inputs] - output_compute = [compute_map[v] for v in node.outputs] + input_storage = [storage_map[variable] for variable in node.inputs] + output_compute = [compute_map[variable] for variable in node.outputs] if len(node.outputs) == 1: - [out_storage] = (storage_map[v] for v in node.outputs) + [output] = (storage_map[variable] for variable in node.outputs) - def thunk(fn=fn, inputs=input_storage, out=out_storage, cm=output_compute): - out[0] = fn(*(inp[0] for inp in inputs)) - cm[0][0] = True + def thunk(fn=fn, inputs=input_storage, output=output, compute=output_compute): + output[0] = fn(*(inp[0] for inp in inputs)) + compute[0][0] = True else: - output_storage = [storage_map[v] for v in node.outputs] + output_storage = [storage_map[variable] for variable in node.outputs] - def thunk(fn=fn, inputs=input_storage, outs=output_storage, cm=output_compute): - for storage, value in zip(outs, fn(*(inp[0] for inp in inputs))): - storage[0] = value - for entry in cm: + def thunk( + fn=fn, inputs=input_storage, outputs=output_storage, compute=output_compute + ): + for output, value in zip(outputs, fn(*(inp[0] for inp in inputs))): + output[0] = value + for entry in compute: entry[0] = True thunk.lazy = False diff --git a/pytensor/link/python/dispatch/blockwise.py b/pytensor/link/python/dispatch/blockwise.py index 87c84faba1..354cb548bb 100644 --- a/pytensor/link/python/dispatch/blockwise.py +++ 
b/pytensor/link/python/dispatch/blockwise.py @@ -13,5 +13,5 @@ def python_funcify_Blockwise(op, node=None, **kwargs): # whole Blockwise fall back to its (vectorized) perform. core_fn = python_funcify(op.core_op, node=core_node) - out_dtypes = [out.type.dtype for out in node.outputs] - return np.vectorize(core_fn, signature=op.signature, otypes=out_dtypes) + output_dtypes = [output.type.dtype for output in node.outputs] + return np.vectorize(core_fn, signature=op.signature, otypes=output_dtypes) diff --git a/pytensor/link/python/linker.py b/pytensor/link/python/linker.py index 3f8db83c42..e50d64b72f 100644 --- a/pytensor/link/python/linker.py +++ b/pytensor/link/python/linker.py @@ -1,52 +1,25 @@ -from pytensor.link.vm import VMLinker +from pytensor.link.basic import JITLinker -class PythonLinker(VMLinker): - """A pure-Python `VMLinker` that runs each node through the `python_funcify` registry. +class PythonLinker(JITLinker): + """Compose a `FunctionGraph` into a single pure-Python function. - Per node, a registered `python_funcify` implementation (a fast numpy/scipy - callable) is wrapped into a thunk; unregistered ops fall back to their - ``perform`` method via ``Op.make_thunk(impl="py")``. Lazy ops such as - ``IfElse`` fall through to their own thunks, so the VM still short-circuits - them. Fusion is excluded because fused ``Composite`` loops run slower than - vectorized numpy on this backend. + The whole graph is turned into one straight-line Python function by + `fgraph_to_python`, dispatching each `Op` through the `python_funcify` + registry (falling back to ``perform`` for unregistered ops). There is no + compilation step, so `jit_compile` is the identity. """ - def __init__( - self, - allow_gc=None, - use_cloop=False, - callback=None, - callback_input=None, - lazy=None, - schedule=None, - c_thunks=None, - allow_partial_eval=None, - ): - # The Python backend never emits C: per-node Python thunks, Python VM. 
- super().__init__( - allow_gc=allow_gc, - use_cloop=False, - callback=callback, - callback_input=callback_input, - lazy=lazy, - schedule=schedule, - c_thunks=False, - allow_partial_eval=allow_partial_eval, - ) - # ``c_thunks=False`` already gives ("minimum_compile", "py_only") / - # ("cxx_only",); add fusion for the numpy backend. - self.incompatible_rewrites = ("cxx_only", "fusion") - - def _make_node_thunk(self, node, storage_map, compute_map, impl): - from pytensor.link.python.dispatch.basic import ( - make_node_thunk_with_python_dispatch, - ) - - return make_node_thunk_with_python_dispatch( - node, - storage_map, - compute_map, - fallback=super()._make_node_thunk, - impl=impl, - ) + required_rewrites = ("minimum_compile", "py_only") + incompatible_rewrites = ("cxx_only",) + + def fgraph_convert(self, fgraph, **kwargs): + from pytensor.link.python.dispatch.basic import python_funcify + + return python_funcify(fgraph, **kwargs) + + def jit_compile(self, fn): + return fn + + def create_thunk_inputs(self, storage_map): + return [storage_map[n] for n in self.fgraph.inputs] diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 2938e09aa2..ced605600e 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -1203,15 +1203,15 @@ def make_vm( ) return vm - def _make_node_thunk(self, node, storage_map, compute_map, impl): + def _make_node_thunk(self, node, storage_map, compute_map, implementation): """Create the thunk for a single node. - On the pure-Python path (``impl == "py"``) a registered + On the pure-Python path (``implementation == "py"``) a registered ``python_funcify`` implementation is wrapped into a thunk; otherwise the node's ``Op.make_thunk`` is used (which also covers lazy ops like ``IfElse``). 
""" - if impl == "py": + if implementation == "py": from pytensor.link.python.dispatch.basic import ( make_node_thunk_with_python_dispatch, ) @@ -1221,14 +1221,16 @@ def _make_node_thunk(self, node, storage_map, compute_map, impl): storage_map, compute_map, fallback=self._make_perform_thunk, - impl=impl, + implementation=implementation, ) - return self._make_perform_thunk(node, storage_map, compute_map, impl) + return self._make_perform_thunk(node, storage_map, compute_map, implementation) - def _make_perform_thunk(self, node, storage_map, compute_map, impl): + def _make_perform_thunk(self, node, storage_map, compute_map, implementation): # no-recycling is done at each VM.__call__, so there is no need to cause # duplicate C code by passing no_recycling here. - return node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) + return node.op.make_thunk( + node, storage_map, compute_map, [], impl=implementation + ) def make_all( self, @@ -1251,14 +1253,16 @@ def make_all( t0 = time.perf_counter() linker_make_thunk_time = {} - impl = None + implementation = None if self.c_thunks is False: - impl = "py" + implementation = "py" for node in order: try: thunk_start = time.perf_counter() thunks.append( - self._make_node_thunk(node, storage_map, compute_map, impl) + self._make_node_thunk( + node, storage_map, compute_map, implementation + ) ) linker_make_thunk_time[node] = time.perf_counter() - thunk_start if not hasattr(thunks[-1], "lazy"): diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index a4a699fe21..ddcf5c0d7d 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -99,11 +99,10 @@ def test_modes(self): # regression check: # there should be # - NumbaLinker - # - `VMLinker` - # - OpWiseCLinker (FAST_RUN) - # - PerformLinker (FAST_COMPILE) + # - `VMLinker` (py, vm, cvm, FAST_COMPILE) + # - OpWiseCLinker (c|py) # - DebugMode's Linker (DEBUG_MODE) - assert 5 == len(set(linker_classes_involved)) + assert 4 == 
len(set(linker_classes_involved)) class TestOldModesProblem: diff --git a/tests/link/python/test_basic.py b/tests/link/python/test_basic.py index 45379d2b87..4826f03d67 100644 --- a/tests/link/python/test_basic.py +++ b/tests/link/python/test_basic.py @@ -6,13 +6,13 @@ import pytensor import pytensor.tensor as pt -from pytensor.compile.mode import Mode, get_mode, predefined_linkers +from pytensor.compile.mode import get_mode, predefined_linkers from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.ifelse import ifelse -from pytensor.link.basic import PerformLinker from pytensor.link.python.dispatch.basic import python_funcify from pytensor.link.python.linker import PythonLinker +from pytensor.link.vm import VMLinker from pytensor.raise_op import Assert from pytensor.scalar.basic import Composite from pytensor.tensor.blas import Gemv @@ -25,17 +25,12 @@ def python_function(inputs, outputs, **kwargs): return pytensor.function(inputs, outputs, mode="PYTHON", **kwargs) -perform_mode = Mode(linker="perform", optimizer="fast_run") +def compare_py_and_pyjit(graph_inputs, graph_outputs, test_inputs, assert_fn=None): + """Compare the per-node ``py`` (VM) backend against the whole-graph ``pyjit``. - -def compare_python_and_perform( - graph_inputs, graph_outputs, test_inputs, assert_fn=None -): - """Compare a PYTHON-backend dispatch against the ``perform`` reference. - - Compiles ``graph_outputs`` under the Python backend (which exercises any - registered ``python_funcify`` dispatch) and under the ``perform`` linker (the - reference that runs every Op's ``perform``), then asserts they agree. + Both run the same ``python_funcify`` dispatches, so this checks that the VM + per-node wiring and the JIT whole-graph composition agree. Only valid for + graphs without lazy ops, which the JIT cannot compose. 
""" if assert_fn is None: @@ -43,23 +38,24 @@ def assert_fn(a, b): np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-10) py_fn = pytensor.function(graph_inputs, graph_outputs, mode="PYTHON") - perform_fn = pytensor.function(graph_inputs, graph_outputs, mode=perform_mode) + pyjit_fn = pytensor.function(graph_inputs, graph_outputs, mode="PYJIT") py_res = py_fn(*test_inputs) - perform_res = perform_fn(*test_inputs) - for py_out, perform_out in zip( + pyjit_res = pyjit_fn(*test_inputs) + for py_out, pyjit_out in zip( py_res if isinstance(py_res, list) else [py_res], - perform_res if isinstance(perform_res, list) else [perform_res], + pyjit_res if isinstance(pyjit_res, list) else [pyjit_res], ): - assert_fn(py_out, perform_out) + assert_fn(py_out, pyjit_out) return py_fn, py_res def test_mode_and_linker_registered(): - # "py" is the pure-Python backend; "perform" is the per-node reference. - assert isinstance(predefined_linkers["py"], PythonLinker) - assert isinstance(predefined_linkers["perform"], PerformLinker) - assert isinstance(get_mode("PYTHON").linker, PythonLinker) + # "py" is the robust per-node VM backend; "pyjit" the whole-graph JIT. + assert isinstance(predefined_linkers["py"], VMLinker) + assert isinstance(predefined_linkers["pyjit"], PythonLinker) + assert isinstance(get_mode("PYTHON").linker, VMLinker) + assert isinstance(get_mode("PYJIT").linker, PythonLinker) @pytest.mark.parametrize( @@ -119,7 +115,7 @@ def test_constant_in_graph(): def test_constant_only_output(): - # Exercises fgraph_to_python's dedicated branch for outputs with no owner. + # An output with no owner (a bare constant) must still be returned. fn = python_function([], pt.constant(5.0)) np.testing.assert_allclose(fn(), 5.0) @@ -169,9 +165,9 @@ def test_cxx_only_excluded(): def test_ifelse_lazy(): - # IfElse has no perform (only a lazy make_thunk). The VM backend runs it via - # the fallback AND short-circuits it: the unused branch (which raises if - # evaluated) must not run. 
+ # IfElse has no perform (only a lazy make_thunk). The py (VM) backend runs it + # via the fallback AND short-circuits it: the unused branch (which raises if + # evaluated) must not run. The whole-graph pyjit backend cannot do this. c = pt.scalar("c") x = vector("x") boom = Assert("unused branch must not run")(x, pt.eq(x.sum(), -999.0)) diff --git a/tests/link/python/test_blockwise.py b/tests/link/python/test_blockwise.py index 75034a841a..74c2259cbf 100644 --- a/tests/link/python/test_blockwise.py +++ b/tests/link/python/test_blockwise.py @@ -4,7 +4,7 @@ import pytensor.tensor as pt from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg.decomposition.cholesky import Cholesky -from tests.link.python.test_basic import compare_python_and_perform +from tests.link.python.test_basic import compare_py_and_pyjit def _has_cholesky(fgraph): @@ -30,7 +30,7 @@ def _pd_matrices(shape, seed=0): def test_cholesky_dispatch(shape, lower): A = pt.tensor("A", shape=(None,) * len(shape)) out = pt.linalg.cholesky(A, lower=lower) - fn, _ = compare_python_and_perform([A], out, [_pd_matrices(shape)]) + fn, _ = compare_py_and_pyjit([A], out, [_pd_matrices(shape)]) assert _has_cholesky(fn.maker.fgraph) @@ -46,9 +46,7 @@ def test_solve_triangular_dispatch(batch, b_shape, lower): A = pt.tensor("A", shape=(None,) * Av.ndim) b = pt.tensor("b", shape=(None,) * bv.ndim) out = pt.linalg.solve_triangular(A, b, lower=lower, b_ndim=len(b_shape)) - compare_python_and_perform( - [A, b], out, [Av.astype("float64"), bv.astype("float64")] - ) + compare_py_and_pyjit([A, b], out, [Av.astype("float64"), bv.astype("float64")]) @pytest.mark.parametrize("pivoting", [False, True]) @@ -58,7 +56,7 @@ def test_qr_dispatch(shape, mode, pivoting): rng = np.random.default_rng(0) A = pt.tensor("A", shape=(None,) * len(shape)) out = pt.linalg.qr(A, mode=mode, pivoting=pivoting) - compare_python_and_perform([A], out, [rng.standard_normal(shape)]) + compare_py_and_pyjit([A], out, 
[rng.standard_normal(shape)]) def test_blockwise_falls_back_without_core_dispatch(): @@ -70,4 +68,4 @@ def test_blockwise_falls_back_without_core_dispatch(): rng = np.random.default_rng(2) Av = rng.standard_normal((4, 4)) + 4 * np.eye(4) bv = rng.standard_normal(4) - compare_python_and_perform([A, b], out, [Av, bv]) + compare_py_and_pyjit([A, b], out, [Av, bv]) From 29b972b8f3c179a4d0119b375db3ce052d7af713 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 17:05:23 -0500 Subject: [PATCH 10/10] Default Op.perform to python_funcify; make Cholesky perform-less The base Op.perform now delegates to the python_funcify dispatch, so an Op's numeric behaviour can live entirely in the registry. Cholesky keeps only its symbolic definition. --- pytensor/graph/op.py | 15 ++++++++++++++- pytensor/tensor/linalg/decomposition/cholesky.py | 11 ----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index ebfa067b71..dd1fa42aac 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -454,7 +454,6 @@ def R_op( return [None if isinstance(r.type, DisconnectedType) else r for r in result] # type: ignore[misc] raise NotImplementedError() - @abstractmethod def perform( self, node: Apply, @@ -489,7 +488,21 @@ def perform( An `Op` is free to reuse `output_storage` as it sees fit, or to discard it and allocate new memory. + The default implementation runs the `Op`'s ``python_funcify`` dispatch: + it builds the callable, runs it on the inputs, and writes the results + into ``output_storage``. An `Op` only needs to override this to keep a + bespoke ``perform``; otherwise its numeric behaviour lives entirely in + the ``python_funcify`` registry. 
""" + from pytensor.link.python.dispatch.basic import python_funcify + + fn = python_funcify(self, node=node) + results = fn(*inputs) + if len(output_storage) == 1: + output_storage[0][0] = results + else: + for output, value in zip(output_storage, results, strict=True): + output[0] = value def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: """Determine whether constant folding should be performed for the given node. diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 816e218b79..95babfb6e9 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -10,7 +10,6 @@ from pytensor.tensor import math as ptm from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.linalg._lazy import scipy_linalg from pytensor.tensor.linalg.dtype_utils import linalg_output_dtype from pytensor.tensor.type import tensor @@ -45,16 +44,6 @@ def make_node(self, x): dtype = linalg_output_dtype(x.type.dtype) return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)]) - def perform(self, node, inputs, outputs): - [x] = inputs - [out] = outputs - try: - out[0] = scipy_linalg.cholesky( - x, lower=self.lower, overwrite_a=self.overwrite_a, check_finite=False - ) - except scipy_linalg.LinAlgError: - out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype) - def pullback(self, inputs, outputs, gradients): """ Cholesky decomposition reverse-mode gradient update.