Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(include=["fast_run"]),
RewriteDatabaseQuery(include=["fast_run", "mlx"]),
)

FAST_COMPILE = Mode(
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/mlx/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.tensor_basic
import pytensor.link.mlx.dispatch.einsum
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
Expand Down
14 changes: 14 additions & 0 deletions pytensor/link/mlx/dispatch/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.einsum import Einsum


@mlx_funcify.register(Einsum)
def mlx_funcify_Einsum(op, **kwargs):
    """Dispatch a PyTensor ``Einsum`` op to an MLX-backed callable.

    The subscripts string is read from the op once, at funcify time; the
    returned closure only forwards the runtime operands to ``mx.einsum``.
    """
    spec = op.subscripts

    def einsum(*operands):
        # MLX performs the full contraction; no reordering is done here.
        return mx.einsum(spec, *operands)

    return einsum
1 change: 1 addition & 0 deletions pytensor/link/mlx/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class MLXLinker(JITLinker):
"local_careduce_fusion",
"inplace",
"scan_save_mem_prealloc",
"inline_einsum",
)

def __init__(self, use_compile=True, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def optimize_einsum_inner_graph(
return [new_out]


@register_specialize
@register_specialize("inline_einsum")
@node_rewriter([Einsum])
def inline_optimized_einsum(
fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
"""Inline einsums that are already optimized.

This allows the inner garph to be optimized with the rest of the graph, now that we got ordering right.
This allows the inner graph to be optimized with the rest of the graph, now that we got ordering right.
"""
op: Einsum = node.op

Expand Down
54 changes: 54 additions & 0 deletions tests/link/mlx/test_einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import pytest

import pytensor.tensor as pt
from tests.link.mlx.test_basic import compare_mlx_and_py


mx = pytest.importorskip("mlx.core")


def test_mlx_einsum():
    """Chained matrix contraction ``ij,jk,kl->il`` through the MLX backend.

    The input shapes are declared once in ``shapes`` and used both to draw
    the test values and to build the symbolic tensors, so the two cannot
    drift apart (the original duplicated the shape literals).
    """
    subscripts = "ij, jk, kl -> il"
    shapes = {
        "x": (3, 5),
        "y": (5, 2),
        "z": (2, 4),
    }
    # Single source of truth: numeric values and symbolic variables are
    # both derived from the same shape mapping.
    x, y, z = (np.random.rand(*shape) for shape in shapes.values())
    x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
    out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
    compare_mlx_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])


def test_ellipsis_einsum():
    """Ellipsis subscripts: batched inner product reducing ``i`` away."""
    subscripts = "...i,...i->..."
    x_val = np.random.rand(2, 5)
    y_val = np.random.rand(2, 5)

    # Symbolic inputs mirror the shapes of the drawn values.
    x_sym = pt.tensor("x", shape=x_val.shape)
    y_sym = pt.tensor("y", shape=y_val.shape)
    result = pt.einsum(subscripts, x_sym, y_sym)
    compare_mlx_and_py([x_sym, y_sym], [result], [x_val, y_val])


def test_einsum_trace():
    """Trace via einsum: ``ii->`` sums the diagonal of a square matrix."""
    x = pt.matrix("x")
    result = pt.einsum("ii->", x)
    compare_mlx_and_py([x], [result], [np.random.rand(5, 5)])


def test_einsum_batched_outer_product():
    """Batched outer product ``bi,bj->bij`` over a shared batch axis."""
    lhs = pt.matrix("a", dtype="float32")
    rhs = pt.matrix("b", dtype="float32")
    result = pt.einsum("bi,bj->bij", lhs, rhs)

    # One (5, 3) and one (5, 2) batch of float32 vectors.
    vals = [
        np.random.normal(size=shape).astype("float32") for shape in [(5, 3), (5, 2)]
    ]
    compare_mlx_and_py([lhs, rhs], [result], vals)
Loading