From 54d2ea0b37c153ad61c8ec51758cf2e1342a5ed9 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 29 Mar 2026 18:30:33 -0500
Subject: [PATCH 1/2] Add einsum dispatch for mlx

---
 pytensor/compile/mode.py               |  2 +-
 pytensor/link/mlx/dispatch/__init__.py |  1 +
 pytensor/link/mlx/dispatch/einsum.py   | 24 ++++++++++++
 pytensor/tensor/einsum.py              | 36 +++++++++++++++++
 pytensor/tensor/rewriting/einsum.py    | 29 +++++++++++++-
 tests/link/mlx/test_einsum.py          | 54 ++++++++++++++++++++++++++
 6 files changed, 143 insertions(+), 3 deletions(-)
 create mode 100644 pytensor/link/mlx/dispatch/einsum.py
 create mode 100644 tests/link/mlx/test_einsum.py

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index aa0e1b3e28..26e47be278 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -467,7 +467,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 
 MLX = Mode(
     MLXLinker(),
-    RewriteDatabaseQuery(include=["fast_run"]),
+    RewriteDatabaseQuery(include=["fast_run", "mlx"]),
 )
 
 FAST_COMPILE = Mode(
diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py
index ac59f1809c..8b6447a1fd 100644
--- a/pytensor/link/mlx/dispatch/__init__.py
+++ b/pytensor/link/mlx/dispatch/__init__.py
@@ -9,6 +9,7 @@
 import pytensor.link.mlx.dispatch.shape
 import pytensor.link.mlx.dispatch.subtensor
 import pytensor.link.mlx.dispatch.tensor_basic
+import pytensor.link.mlx.dispatch.einsum
 import pytensor.link.mlx.dispatch.signal
 import pytensor.link.mlx.dispatch.signal.conv
 import pytensor.link.mlx.dispatch.blockwise
diff --git a/pytensor/link/mlx/dispatch/einsum.py b/pytensor/link/mlx/dispatch/einsum.py
new file mode 100644
index 0000000000..7264455689
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/einsum.py
@@ -0,0 +1,24 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.einsum import AbstractEinsum, Einsum
+
+
+@mlx_funcify.register(Einsum)
+def mlx_funcify_Einsum(op, **kwargs):
+    subscripts = op.subscripts
+
+    def einsum(*operands):
+        return mx.einsum(subscripts, *operands)
+
+    return einsum
+
+
+@mlx_funcify.register(AbstractEinsum)
+def mlx_funcify_AbstractEinsum(op, **kwargs):
+    subscripts = op.subscripts
+
+    def einsum(*operands):
+        return mx.einsum(subscripts, *operands)
+
+    return einsum
diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index a6d5a358f1..c4b1aecc2d 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -13,6 +13,9 @@
 from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
 
 from pytensor.compile.builders import OpFromGraph
+from pytensor.graph.basic import Apply
+from pytensor.graph.op import Op
+from pytensor.scalar.basic import upcast
 from pytensor.tensor import TensorLike
 from pytensor.tensor.basic import (
     arange,
@@ -28,6 +31,7 @@
 from pytensor.tensor.functional import vectorize
 from pytensor.tensor.math import and_, eq, tensordot
 from pytensor.tensor.shape import shape_padright
+from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -35,6 +39,38 @@
 CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str]
 
 
+class AbstractEinsum(Op):
+    """Thin einsum Op that holds only the subscript string.
+
+    Unlike :class:`Einsum` (an ``OpFromGraph``), this Op has no inner graph.
+    Backends that natively support einsum (e.g. MLX, JAX) can dispatch it
+    directly to their own ``einsum`` implementation, avoiding decomposition
+    into lower-level ops that may not be supported.
+
+    ``perform`` falls back to :func:`numpy.einsum` so the Op is always
+    executable on the Python backend.
+    """
+
+    __props__ = ("subscripts", "out_ndim")
+
+    def __init__(self, subscripts: str, out_ndim: int):
+        self.subscripts = subscripts
+        self.out_ndim = out_ndim
+        super().__init__()
+
+    def make_node(self, *operands):
+        operands = [as_tensor(op) for op in operands]
+        dtype = upcast(*[op.dtype for op in operands])
+        out_type = TensorType(dtype=dtype, shape=(None,) * self.out_ndim)
+        return Apply(self, list(operands), [out_type()])
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.einsum(self.subscripts, *inputs)
+
+    def __str__(self):
+        return f"AbstractEinsum{{{self.subscripts}}}"
+
+
 class Einsum(OpFromGraph):
     """
     Wrapper Op for Einsum graphs
diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py
index 5e9fe2d026..c2c1fc95f0 100644
--- a/pytensor/tensor/rewriting/einsum.py
+++ b/pytensor/tensor/rewriting/einsum.py
@@ -1,8 +1,9 @@
 from typing import cast
 
+from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
-from pytensor.tensor.einsum import Einsum, einsum
+from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
+from pytensor.tensor.einsum import AbstractEinsum, Einsum, einsum
 from pytensor.tensor.rewriting.basic import register_specialize
 from pytensor.tensor.rewriting.ofg import inline_ofg_node
 from pytensor.tensor.variable import TensorVariable
@@ -51,3 +52,27 @@ def inline_optimized_einsum(
         return None
 
     return cast(list[TensorVariable], inline_ofg_node(node))
+
+
+@node_rewriter([Einsum])
+def einsum_to_abstract(
+    fgraph: FunctionGraph, node: Apply
+) -> list[TensorVariable] | None:
+    """Replace ``Einsum`` with ``AbstractEinsum``.
+
+    Backends that natively support einsum can dispatch ``AbstractEinsum`` to their
+    native implementation, rather than using the OpFromGraph defined by PyTensor.
+    """
+    op: Einsum = node.op
+    out_ndim = node.outputs[0].ndim
+    new_out = AbstractEinsum(subscripts=op.subscripts, out_ndim=out_ndim)(*node.inputs)
+    copy_stack_trace(node.outputs[0], new_out)
+    return [new_out]
+
+
+optdb.register(
+    "einsum_to_abstract",
+    dfs_rewriter(einsum_to_abstract),
+    "mlx",
+    position=1.9,  # Before specialize (2.0) which inlines the Einsum OFG
+)
diff --git a/tests/link/mlx/test_einsum.py b/tests/link/mlx/test_einsum.py
new file mode 100644
index 0000000000..712045d492
--- /dev/null
+++ b/tests/link/mlx/test_einsum.py
@@ -0,0 +1,54 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+mx = pytest.importorskip("mlx.core")
+
+
+def test_mlx_einsum():
+    subscripts = "ij, jk, kl -> il"
+    x = np.random.rand(3, 5)
+    y = np.random.rand(5, 2)
+    z = np.random.rand(2, 4)
+
+    shapes = {
+        "x": (3, 5),
+        "y": (5, 2),
+        "z": (2, 4),
+    }
+    x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
+    out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
+    compare_mlx_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
+
+
+def test_ellipsis_einsum():
+    subscripts = "...i,...i->..."
+    x = np.random.rand(2, 5)
+    y = np.random.rand(2, 5)
+
+    x_pt = pt.tensor("x", shape=x.shape)
+    y_pt = pt.tensor("y", shape=y.shape)
+    out = pt.einsum(subscripts, x_pt, y_pt)
+    compare_mlx_and_py([x_pt, y_pt], [out], [x, y])
+
+
+def test_einsum_trace():
+    subscripts = "ii->"
+    x_pt = pt.matrix("x")
+    x_val = np.random.rand(5, 5)
+    out = pt.einsum(subscripts, x_pt)
+    compare_mlx_and_py([x_pt], [out], [x_val])
+
+
+def test_einsum_batched_outer_product():
+    a = pt.matrix("a", dtype="float32")
+    b = pt.matrix("b", dtype="float32")
+    out = pt.einsum("bi,bj->bij", a, b)
+
+    a_val = np.random.normal(size=(5, 3)).astype("float32")
+    b_val = np.random.normal(size=(5, 2)).astype("float32")
+
+    compare_mlx_and_py([a, b], [out], [a_val, b_val])

From b3be034591fe8a7f301a93ee31747437ff86a47e Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 4 Apr 2026 19:28:45 -0400
Subject: [PATCH 2/2] Remove AbstractEinsum and exclude inline_einsum rewrite in mlx mode

---
 pytensor/link/mlx/dispatch/einsum.py | 12 +---------
 pytensor/link/mlx/linker.py          |  1 +
 pytensor/tensor/einsum.py            | 36 ----------------------------
 pytensor/tensor/rewriting/einsum.py  | 33 ++++-----------------------
 4 files changed, 6 insertions(+), 76 deletions(-)

diff --git a/pytensor/link/mlx/dispatch/einsum.py b/pytensor/link/mlx/dispatch/einsum.py
index 7264455689..1961a10ab6 100644
--- a/pytensor/link/mlx/dispatch/einsum.py
+++ b/pytensor/link/mlx/dispatch/einsum.py
@@ -1,7 +1,7 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch import mlx_funcify
-from pytensor.tensor.einsum import AbstractEinsum, Einsum
+from pytensor.tensor.einsum import Einsum
 
 
 @mlx_funcify.register(Einsum)
@@ -12,13 +12,3 @@ def einsum(*operands):
         return mx.einsum(subscripts, *operands)
 
     return einsum
-
-
-@mlx_funcify.register(AbstractEinsum)
-def mlx_funcify_AbstractEinsum(op, **kwargs):
-    subscripts = op.subscripts
-
-    def einsum(*operands):
-        return mx.einsum(subscripts, *operands)
-
-    return einsum
diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py
index fea4c73d5c..6ec9619ce6 100644
--- a/pytensor/link/mlx/linker.py
+++ b/pytensor/link/mlx/linker.py
@@ -10,6 +10,7 @@ class MLXLinker(JITLinker):
         "local_careduce_fusion",
         "inplace",
         "scan_save_mem_prealloc",
+        "inline_einsum",
     )
 
     def __init__(self, use_compile=True, *args, **kwargs):
diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index c4b1aecc2d..a6d5a358f1 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -13,9 +13,6 @@
 from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
 
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph.basic import Apply
-from pytensor.graph.op import Op
-from pytensor.scalar.basic import upcast
 from pytensor.tensor import TensorLike
 from pytensor.tensor.basic import (
     arange,
@@ -31,7 +28,6 @@
 from pytensor.tensor.functional import vectorize
 from pytensor.tensor.math import and_, eq, tensordot
 from pytensor.tensor.shape import shape_padright
-from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -39,38 +35,6 @@
 CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str]
 
 
-class AbstractEinsum(Op):
-    """Thin einsum Op that holds only the subscript string.
-
-    Unlike :class:`Einsum` (an ``OpFromGraph``), this Op has no inner graph.
-    Backends that natively support einsum (e.g. MLX, JAX) can dispatch it
-    directly to their own ``einsum`` implementation, avoiding decomposition
-    into lower-level ops that may not be supported.
-
-    ``perform`` falls back to :func:`numpy.einsum` so the Op is always
-    executable on the Python backend.
-    """
-
-    __props__ = ("subscripts", "out_ndim")
-
-    def __init__(self, subscripts: str, out_ndim: int):
-        self.subscripts = subscripts
-        self.out_ndim = out_ndim
-        super().__init__()
-
-    def make_node(self, *operands):
-        operands = [as_tensor(op) for op in operands]
-        dtype = upcast(*[op.dtype for op in operands])
-        out_type = TensorType(dtype=dtype, shape=(None,) * self.out_ndim)
-        return Apply(self, list(operands), [out_type()])
-
-    def perform(self, node, inputs, output_storage):
-        output_storage[0][0] = np.einsum(self.subscripts, *inputs)
-
-    def __str__(self):
-        return f"AbstractEinsum{{{self.subscripts}}}"
-
-
 class Einsum(OpFromGraph):
     """
     Wrapper Op for Einsum graphs
diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py
index c2c1fc95f0..5ee8acc92a 100644
--- a/pytensor/tensor/rewriting/einsum.py
+++ b/pytensor/tensor/rewriting/einsum.py
@@ -1,9 +1,8 @@
 from typing import cast
 
-from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
-from pytensor.tensor.einsum import AbstractEinsum, Einsum, einsum
+from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.tensor.einsum import Einsum, einsum
 from pytensor.tensor.rewriting.basic import register_specialize
 from pytensor.tensor.rewriting.ofg import inline_ofg_node
 from pytensor.tensor.variable import TensorVariable
@@ -37,14 +36,14 @@ def optimize_einsum_inner_graph(
     return [new_out]
 
 
-@register_specialize
+@register_specialize("inline_einsum")
 @node_rewriter([Einsum])
 def inline_optimized_einsum(
     fgraph: FunctionGraph, node: Apply
 ) -> list[TensorVariable] | None:
     """Inline einsums that are already optimized.
 
-    This allows the inner garph to be optimized with the rest of the graph, now that we got ordering right.
+    This allows the inner graph to be optimized with the rest of the graph, now that we got ordering right.
     """
 
     op: Einsum = node.op
@@ -52,27 +51,3 @@ def inline_optimized_einsum(
         return None
 
     return cast(list[TensorVariable], inline_ofg_node(node))
-
-
-@node_rewriter([Einsum])
-def einsum_to_abstract(
-    fgraph: FunctionGraph, node: Apply
-) -> list[TensorVariable] | None:
-    """Replace ``Einsum`` with ``AbstractEinsum``.
-
-    Backends that natively support einsum can dispatch ``AbstractEinsum`` to their
-    native implementation, rather than using the OpFromGraph defined by PyTensor.
-    """
-    op: Einsum = node.op
-    out_ndim = node.outputs[0].ndim
-    new_out = AbstractEinsum(subscripts=op.subscripts, out_ndim=out_ndim)(*node.inputs)
-    copy_stack_trace(node.outputs[0], new_out)
-    return [new_out]
-
-
-optdb.register(
-    "einsum_to_abstract",
-    dfs_rewriter(einsum_to_abstract),
-    "mlx",
-    position=1.9,  # Before specialize (2.0) which inlines the Einsum OFG
-)
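
Note on the net effect of the series: the MLX linker now excludes the inline_einsum rewrite, so the Einsum OpFromGraph survives rewriting intact and mlx_funcify_Einsum hands the whole contraction to mx.einsum in a single call. Below is a minimal usage sketch, not part of the patches; it assumes the MLX mode registered in pytensor/compile/mode.py is addressable by the string "MLX", and casts inputs to float32 since MLX has no float64 support.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Symbolic einsum; PyTensor wraps the contraction in an Einsum OpFromGraph.
    x = pt.tensor("x", shape=(3, 5), dtype="float32")
    y = pt.tensor("y", shape=(5, 2), dtype="float32")
    out = pt.einsum("ij,jk->ik", x, y)

    # With inline_einsum excluded, the Einsum Op reaches the MLX linker whole
    # and is dispatched directly to mx.einsum by mlx_funcify_Einsum.
    fn = pytensor.function([x, y], out, mode="MLX")
    x_val = np.random.rand(3, 5).astype("float32")
    y_val = np.random.rand(5, 2).astype("float32")
    print(fn(x_val, y_val))

This is also why patch 2 can delete AbstractEinsum outright: once the OpFromGraph is guaranteed not to be inlined before dispatch, a single registration on Einsum suffices, and no second Op or graph rewrite is needed to swap it in.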