Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
SequenceDB,
TopoDB,
)
from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.basic import Linker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.python.linker import PythonLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker

Expand All @@ -40,7 +41,10 @@
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
predefined_linkers = {
"py": PerformLinker(), # Use allow_gc PyTensor flag
"py": VMLinker( # Robust per-node Python VM over python_funcify; handles lazy ops
use_cloop=False, c_thunks=False
),
"pyjit": PythonLinker(), # Whole-graph python_funcify composition; no lazy ops
"c": CLinker(), # Don't support gc. so don't check allow_gc
"c|py": OpWiseCLinker(), # Use allow_gc PyTensor flag
"c|py_nogc": OpWiseCLinker(allow_gc=False),
Expand Down Expand Up @@ -476,6 +480,16 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
RewriteDatabaseQuery(include=["fast_run", "mlx"]),
)

PYTHON = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"),
)

PYJIT = Mode(
PythonLinker(),
RewriteDatabaseQuery(include=["fast_run"]).excluding("fusion"),
)

FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
Expand All @@ -495,6 +509,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
"PYTHON": PYTHON,
"PYJIT": PYJIT,
}

_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
Expand Down
15 changes: 14 additions & 1 deletion pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ def R_op(
return [None if isinstance(r.type, DisconnectedType) else r for r in result] # type: ignore[misc]
raise NotImplementedError()

@abstractmethod
def perform(
self,
node: Apply,
Expand Down Expand Up @@ -489,7 +488,21 @@ def perform(
An `Op` is free to reuse `output_storage` as it sees fit, or to
discard it and allocate new memory.

The default implementation runs the `Op`'s ``python_funcify`` dispatch:
it builds the callable, runs it on the inputs, and writes the results
into ``output_storage``. An `Op` only needs to override this to keep a
bespoke ``perform``; otherwise its numeric behaviour lives entirely in
the ``python_funcify`` registry.
"""
from pytensor.link.python.dispatch.basic import python_funcify

fn = python_funcify(self, node=node)
results = fn(*inputs)
if len(output_storage) == 1:
output_storage[0][0] = results
else:
for output, value in zip(output_storage, results, strict=True):
output[0] = value

def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
"""Determine whether constant folding should be performed for the given node.
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pytensor.link.python.linker import PythonLinker
8 changes: 8 additions & 0 deletions pytensor/link/python/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# isort: off
from pytensor.link.python.dispatch.basic import python_funcify

# Load dispatch specializations
import pytensor.link.python.dispatch.blockwise
import pytensor.link.python.dispatch.linalg

# isort: on
99 changes: 99 additions & 0 deletions pytensor/link/python/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from functools import singledispatch

from pytensor.graph.fg import AbstractFunctionGraph
from pytensor.link.utils import fgraph_to_python


@singledispatch
def python_funcify(op, node=None, **kwargs):
"""Return a fast pure-Python implementation of ``op`` as a callable.

The callable takes the node's inputs positionally and returns its output (a
single value, or a tuple for multi-output nodes). Register a specialization
to override an `Op`'s ``perform`` with a faster numpy/scipy path on the
Python backend.

Unregistered ops raise `NotImplementedError`. The ``py`` (VM) linker catches
this and falls back to ``Op.make_thunk(impl="py")``; the whole-graph ``pyjit``
linker catches it and falls back to a ``perform`` wrapper.
"""
raise NotImplementedError(
f"No python_funcify implementation registered for {type(op).__name__}"
)


def _perform_wrapper(op, node):
"""Wrap an `Op`'s ``perform`` into a `python_funcify`-style callable."""
n_outputs = len(node.outputs)
single_output = n_outputs == 1

def perform(*inputs):
output_storage = [[None] for _ in range(n_outputs)]
op.perform(node, list(inputs), output_storage)
if single_output:
return output_storage[0][0]
return tuple(storage[0] for storage in output_storage)

return perform


def _funcify_or_perform(op, node=None, **kwargs):
try:
return python_funcify(op, node=node)
except NotImplementedError:
return _perform_wrapper(op, node)


@python_funcify.register(AbstractFunctionGraph)
def python_funcify_FunctionGraph(
fgraph, node=None, fgraph_name="python_funcified_fgraph", **kwargs
):
return fgraph_to_python(
fgraph,
op_conversion_fn=_funcify_or_perform,
fgraph_name=fgraph_name,
**kwargs,
)


def make_node_thunk_with_python_dispatch(
node, storage_map, compute_map, *, fallback, implementation
):
"""Build a per-node thunk, preferring a registered `python_funcify` implementation.

When `python_funcify` has a specialization for ``node.op``, its callable is
wrapped into a thunk that reads inputs from and writes outputs to
``storage_map``. Otherwise ``fallback`` (``Op.make_thunk``) is used, which
covers ``perform`` ops and lazy ops like ``IfElse`` unchanged.
"""
try:
fn = python_funcify(node.op, node=node)
except NotImplementedError:
return fallback(node, storage_map, compute_map, implementation)

return _wrap_callable_as_thunk(fn, node, storage_map, compute_map)


def _wrap_callable_as_thunk(fn, node, storage_map, compute_map):
input_storage = [storage_map[variable] for variable in node.inputs]
output_compute = [compute_map[variable] for variable in node.outputs]

if len(node.outputs) == 1:
[output] = (storage_map[variable] for variable in node.outputs)

def thunk(fn=fn, inputs=input_storage, output=output, compute=output_compute):
output[0] = fn(*(inp[0] for inp in inputs))
compute[0][0] = True
else:
output_storage = [storage_map[variable] for variable in node.outputs]

def thunk(
fn=fn, inputs=input_storage, outputs=output_storage, compute=output_compute
):
for output, value in zip(outputs, fn(*(inp[0] for inp in inputs))):
output[0] = value
for entry in compute:
entry[0] = True

thunk.lazy = False
return thunk
17 changes: 17 additions & 0 deletions pytensor/link/python/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.blockwise import Blockwise


@python_funcify.register(Blockwise)
def python_funcify_Blockwise(op, node=None, **kwargs):
core_node = op._create_dummy_core_node(
node.inputs, propagate_unbatched_core_inputs=True
)
# Raises NotImplementedError when the core Op has no dispatch, which makes the
# whole Blockwise fall back to its (vectorized) perform.
core_fn = python_funcify(op.core_op, node=core_node)

output_dtypes = [output.type.dtype for output in node.outputs]
return np.vectorize(core_fn, signature=op.signature, otypes=output_dtypes)
4 changes: 4 additions & 0 deletions pytensor/link/python/dispatch/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# isort: off
import pytensor.link.python.dispatch.linalg.decomposition
import pytensor.link.python.dispatch.linalg.solvers
# isort: on
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# isort: off
import pytensor.link.python.dispatch.linalg.decomposition.cholesky
import pytensor.link.python.dispatch.linalg.decomposition.qr
# isort: on
32 changes: 32 additions & 0 deletions pytensor/link/python/dispatch/linalg/decomposition/cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky


@python_funcify.register(Cholesky)
def python_funcify_Cholesky(op, node=None, **kwargs):
lower = op.lower
overwrite_a = op.overwrite_a
(potrf,) = get_lapack_funcs(("potrf",), dtype=node.inputs[0].type.dtype)

def cholesky(x):
if x.size == 0:
return np.empty_like(x)

# potrf only honors overwrite_a for F-contiguous input; transpose a
# C-contiguous array to benefit from it.
c_contiguous_input = overwrite_a and x.flags["C_CONTIGUOUS"]
if c_contiguous_input:
x = x.T
factor, info = potrf(x, lower=not lower, overwrite_a=True, clean=True)
factor = factor.T
else:
factor, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)

if info != 0:
factor[...] = np.nan
return factor

return cholesky
54 changes: 54 additions & 0 deletions pytensor/link/python/dispatch/linalg/decomposition/qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.decomposition.qr import QR


@python_funcify.register(QR)
def python_funcify_QR(op, node=None, **kwargs):
mode = op.mode
pivoting = op.pivoting
overwrite_a = op.overwrite_a
call_and_get_lwork = op._call_and_get_lwork

def qr(x):
M, N = x.shape

if pivoting:
(geqp3,) = get_lapack_funcs(("geqp3",), (x,))
factor, jpvt, tau, *_ = call_and_get_lwork(
geqp3, x, lwork=-1, overwrite_a=overwrite_a
)
jpvt -= 1 # geqp3 returns 1-based indices
else:
(geqrf,) = get_lapack_funcs(("geqrf",), (x,))
factor, tau, *_ = call_and_get_lwork(
geqrf, x, lwork=-1, overwrite_a=overwrite_a
)

if mode not in ("economic", "raw") or M < N:
R = np.triu(factor)
else:
R = np.triu(factor[:N, :])

if mode == "r":
return (R, jpvt) if pivoting else R
if mode == "raw":
return (factor, tau, R, jpvt) if pivoting else (factor, tau, R)

(orgqr,) = get_lapack_funcs(("orgqr",), (factor,))
if M < N:
Q, *_ = call_and_get_lwork(
orgqr, factor[:, :M], tau, lwork=-1, overwrite_a=1
)
elif mode == "economic":
Q, *_ = call_and_get_lwork(orgqr, factor, tau, lwork=-1, overwrite_a=1)
else:
square = np.empty((M, M), dtype=factor.dtype.char)
square[:, :N] = factor
Q, *_ = call_and_get_lwork(orgqr, square, tau, lwork=-1, overwrite_a=1)

return (Q, R, jpvt) if pivoting else (Q, R)

return qr
3 changes: 3 additions & 0 deletions pytensor/link/python/dispatch/linalg/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# isort: off
import pytensor.link.python.dispatch.linalg.solvers.triangular
# isort: on
43 changes: 43 additions & 0 deletions pytensor/link/python/dispatch/linalg/solvers/triangular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.solvers.triangular import SolveTriangular


@python_funcify.register(SolveTriangular)
def python_funcify_SolveTriangular(op, node=None, **kwargs):
lower = op.lower
unit_diagonal = op.unit_diagonal
overwrite_b = op.overwrite_b
(trtrs,) = get_lapack_funcs(("trtrs",), dtype=node.outputs[0].type.dtype)

def solve_triangular(A, b):
if b.size == 0:
return np.empty_like(b)

if A.flags["F_CONTIGUOUS"]:
x, info = trtrs(
A,
b,
overwrite_b=overwrite_b,
lower=lower,
trans=0,
unitdiag=unit_diagonal,
)
else:
# trtrs expects Fortran ordering, so solve the transposed system.
x, info = trtrs(
A.T,
b,
overwrite_b=overwrite_b,
lower=not lower,
trans=1,
unitdiag=unit_diagonal,
)

if info != 0:
x[...] = np.nan
return x

return solve_triangular
25 changes: 25 additions & 0 deletions pytensor/link/python/linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pytensor.link.basic import JITLinker


class PythonLinker(JITLinker):
"""Compose a `FunctionGraph` into a single pure-Python function.

The whole graph is turned into one straight-line Python function by
`fgraph_to_python`, dispatching each `Op` through the `python_funcify`
registry (falling back to ``perform`` for unregistered ops). There is no
compilation step, so `jit_compile` is the identity.
"""

required_rewrites = ("minimum_compile", "py_only")
incompatible_rewrites = ("cxx_only",)

def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.python.dispatch.basic import python_funcify

return python_funcify(fgraph, **kwargs)

def jit_compile(self, fn):
return fn

def create_thunk_inputs(self, storage_map):
return [storage_map[n] for n in self.fgraph.inputs]
Loading
Loading