From 34b41ddb6e7a2b06da65a1a1fc458d0d01d85a10 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 19 Apr 2026 00:12:44 -0500 Subject: [PATCH 1/7] Add MultiDot op and rewrites for optimal contraction --- pytensor/tensor/linalg/__init__.py | 3 + pytensor/tensor/linalg/products.py | 70 +++- pytensor/tensor/rewriting/linalg/products.py | 313 +++++++++++++++++- tests/tensor/linalg/test_products.py | 39 ++- .../tensor/rewriting/linalg/test_products.py | 53 ++- 5 files changed, 464 insertions(+), 14 deletions(-) diff --git a/pytensor/tensor/linalg/__init__.py b/pytensor/tensor/linalg/__init__.py index 9ad282a4ad..48449db3fc 100644 --- a/pytensor/tensor/linalg/__init__.py +++ b/pytensor/tensor/linalg/__init__.py @@ -44,10 +44,12 @@ from pytensor.tensor.linalg.products import ( Expm, KroneckerProduct, + MultiDot, expm, kron, matrix_dot, matrix_power, + multi_dot, ) from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import ( @@ -108,6 +110,7 @@ "lu_solve", "matrix_dot", "matrix_power", + "multi_dot", "norm", "ordqz", "pinv", diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 1dd41fe51b..4d1b8bf037 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -128,7 +128,7 @@ def matrix_dot(*args): """ rval = args[0] for a in args[1:]: - rval = ptm.dot(rval, a) + rval = ptm.matmul(rval, a) return rval @@ -174,3 +174,71 @@ def matrix_power(M, n): result = z if result is None else ptm.dot(result, z) return result + + +class MultiDot(OpFromGraph): + """Wrapper Op for a sequence of matrix multiplications. + + Used as a target for rewrites to re-order matrix multiplications for better performance. + """ + + +def multi_dot(matrices): + """Compute the dot product of two or more matrices, selecting the fastest evaluation order automatically. 
+ + The problem of optimal matrix multiplication ordering concerns how to place parenthesis in a sequence of matmuls + to make computation as efficient as possible. For a discussion, see the wikipedia page: https://en.wikipedia.org/wiki/Matrix_chain_multiplication + The following example from that page illustrates the point. Given 3 matrices A of shape (10, 30), B of shape + (30, 5), and C of shape (5, 60), the product X = A @ B @ C can be performed in two ways: (A @ B) @ C or A @ (B @ C). + + The first way requires 10*30*5 + 10*5*60 = 4500 multiplications, while the second way requires + 30*5*60 + 10*30*60 = 27000 multiplications. Thus, the first way is much more efficient, and multi_dot will + automatically select that way for you. + + The exact dynamic programming solution is used to find the optimal contraction path. + + Parameters + ---------- + matrices : list of TensorVariable + All inputs must be at least 2d, except for the first and last inputs, which can be 1d. If the first input is + 1d, it is treated as a row vector. If the last input is 1d, it is treated as a column vector. 
+ + Returns + ------- + result : TensorVariable + The dot product of the input matrices, from left to right + """ + if len(matrices) < 2: + raise ValueError("multi_dot requires at least 2 matrices") + + matrices = [as_tensor_variable(a) for a in matrices] + + for i, a in enumerate(matrices): + if a.ndim < 1: + raise ValueError(f"multi_dot: array {i} is a scalar, expected at least 1-D") + if a.ndim == 1 and 0 < i < len(matrices) - 1: + raise ValueError( + f"multi_dot: interior array {i} must be at least 2-D, got 1-D" + ) + + squeeze_first = matrices[0].ndim == 1 + squeeze_last = matrices[-1].ndim == 1 + if squeeze_first: + matrices[0] = pt.expand_dims(matrices[0], axis=0) + if squeeze_last: + matrices[-1] = pt.expand_dims(matrices[-1], axis=1) + + if len(matrices) == 2: + result = ptm.matmul(matrices[0], matrices[1]) + else: + # Build naive inner graph, wrap in MultiDot + inner_inputs = [a.type.clone()(name=f"i{i}") for i, a in enumerate(matrices)] + inner_output = matrix_dot(*inner_inputs) + result = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) + + if squeeze_first: + result = result[0] if result.ndim == 2 else result[..., 0, :] + if squeeze_last: + result = result[..., 0] + + return result diff --git a/pytensor/tensor/rewriting/linalg/products.py b/pytensor/tensor/rewriting/linalg/products.py index 4a76658887..8dcc9eeb6a 100644 --- a/pytensor/tensor/rewriting/linalg/products.py +++ b/pytensor/tensor/rewriting/linalg/products.py @@ -1,19 +1,42 @@ +from typing import TYPE_CHECKING + +import numpy as np + from pytensor.graph.rewriting.basic import ( + GraphRewriter, copy_stack_trace, node_rewriter, ) from pytensor.tensor.basic import ExtractDiag, concatenate, diag from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg.constructors import BlockDiagonal -from pytensor.tensor.linalg.products import KroneckerProduct +from pytensor.tensor.linalg.products import ( + KroneckerProduct, + MultiDot, + matrix_dot, +) from 
pytensor.tensor.linalg.summary import det -from pytensor.tensor.math import outer, prod +from pytensor.tensor.math import Dot, matmul, outer, prod from pytensor.tensor.rewriting.basic import ( register_canonicalize, + register_specialize, register_stabilize, ) +if TYPE_CHECKING: + from pytensor.tensor.type import TensorVariable + + +def _is_dot_node(node): + """Check if a node is a Dot or Blockwise(Dot)""" + if isinstance(node.op, Dot): + return True + if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot): + return True + return False + + @register_canonicalize @node_rewriter([BlockDiagonal]) def fuse_blockdiagonal(fgraph, node): @@ -152,3 +175,289 @@ def det_of_kronecker(fgraph, node): [dets[i] ** (prod_sizes / sizes[i]) for i in range(2)], axis=-1 ) return [det_final] + + +class MultiDotAbsorber(GraphRewriter): + """Scan the entire fgraph for chains of 3+ matmul ops and wrap them in MultiDot. + + Matches both bare Dot (2-D) and Blockwise(Dot) (batched) nodes. + Only absorbs nodes whose output has a single client (the next matmul in the chain). 
+ """ + + def apply(self, fgraph): + dot_nodes = [node for node in fgraph.toposort() if _is_dot_node(node)] + absorbed = set() + + for node in dot_nodes: + if node in absorbed: + continue + + chain_nodes = [node] + current = node + + # Walk backward: the left input might be another matmul + while True: + left_input = current.inputs[0] + if ( + left_input.owner + and _is_dot_node(left_input.owner) + and left_input.owner not in absorbed + and len(fgraph.clients[left_input]) == 1 + ): + current = left_input.owner + chain_nodes.insert(0, current) + else: + break + + # Walk forward: the output might feed into another matmul as left input + current = node + while True: + out = current.outputs[0] + clients = fgraph.clients[out] + if ( + len(clients) == 1 + and clients[0][0] != "output" + and _is_dot_node(clients[0][0]) + and clients[0][0].inputs[0] is out + and clients[0][0] not in absorbed + ): + current = clients[0][0] + chain_nodes.append(current) + else: + break + + if len(chain_nodes) < 2: + continue + + # Extract matrices: first node's left input, then each node's right input + matrices = [ + chain_nodes[0].inputs[0], + *(node.inputs[1] for node in chain_nodes), + ] + absorbed.update(chain_nodes) + + inner_inputs = [ + m.type.clone()(name=f"i{i}") for i, m in enumerate(matrices) + ] + + # Naive chained matmuls -- optimization is done later by a separate rewrite + inner_output = matrix_dot(*inner_inputs) + multi_dot_op = MultiDot(inputs=inner_inputs, outputs=[inner_output]) + new_out = multi_dot_op(*matrices) + copy_stack_trace(chain_nodes[-1].outputs[0], new_out) + fgraph.replace( + chain_nodes[-1].outputs[0], + new_out, + reason="dot_chain_to_multi_dot", + ) + + +multi_dot_absorber = MultiDotAbsorber() +multi_dot_absorber.name = "multi_dot_absorber" +register_canonicalize(multi_dot_absorber, "multi_dot", name="multi_dot_absorber") + + +@register_canonicalize("multi_dot") +@node_rewriter([MultiDot]) +def fuse_multi_dot_operands(fgraph, node): + """Fuse 
matmul(MultiDot(...), X) or matmul(X, MultiDot(...)) into a larger MultiDot. + + Also handles matmul(MultiDot(...), MultiDot(...)) by merging both. + """ + [multi_out] = node.outputs + clients = fgraph.clients[multi_out] + + for client_node, _ in clients: + if client_node == "output" or not _is_dot_node(client_node): + continue + + left, right = client_node.inputs + + # Only absorb a MultiDot whose output has a single client (this Dot), + # otherwise we'd duplicate the computation of its inner chain. + left_is_multi = ( + left.owner + and isinstance(left.owner.op, MultiDot) + and len(fgraph.clients[left]) == 1 + ) + right_is_multi = ( + right.owner + and isinstance(right.owner.op, MultiDot) + and len(fgraph.clients[right]) == 1 + ) + + if not left_is_multi and not right_is_multi: + continue + + matrices = [] + if left_is_multi: + matrices.extend(left.owner.inputs) + else: + matrices.append(left) + if right_is_multi: + matrices.extend(right.owner.inputs) + else: + matrices.append(right) + + inner_inputs = [m.type.clone()(name=f"i{i}") for i, m in enumerate(matrices)] + inner_output = matrix_dot(*inner_inputs) + new_out = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) + + copy_stack_trace(client_node.outputs[0], new_out) + + return {client_node.outputs[0]: new_out} + + return None + + +@register_canonicalize("multi_dot") +@node_rewriter([MultiDot]) +def flatten_nested_multi_dot(fgraph, node): + """Flatten nested MultiDot inputs into a single MultiDot. Nested MultiDots arise from fuse_multi_dot_operands, + when graphs like A @ B @ MultiDot(C, D, E, F) -> MultiDot(A, B, MultiDot(C, D, E, F)). 
This rewrite does a final + clean up: MultiDot([A, B, MultiDot([C, D, E, F])]) -> MultiDot([A, B, C, D, E, F]) + """ + inputs = node.inputs + has_nested = any( + inp.owner + and isinstance(inp.owner.op, MultiDot) + and len(fgraph.clients[inp]) == 1 + for inp in inputs + ) + + if not has_nested: + return None + + # Flatten inputs + matrices = [] + for inp in inputs: + if ( + inp.owner + and isinstance(inp.owner.op, MultiDot) + and len(fgraph.clients[inp]) == 1 + ): + matrices.extend(inp.owner.inputs) + else: + matrices.append(inp) + + inner_inputs = [] + for m in matrices: + if m in inner_inputs: + # If a matrix is repeated, we need to copy it to avoid issues with the OpFromGraph + inner_inputs.append(m.type(name=m.name)) + else: + inner_inputs.append(m) + + inner_output = matrix_dot(*inner_inputs) + new_out = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) + + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +def matrix_chain_split_points(dimensions: list[int]) -> np.ndarray: + """Return optimal split points for matrix-chain multiplication. + + Parameters + ---------- + dimensions : list[int] + Matrix dimensions encoded as a list of integers, where the i-th matrix has shape + (dimensions[i], dimensions[i+1]) + + Returns + ------- + split_at : np.ndarray + Integer array of shape (n, n), where n = len(dimensions) - 1. + + ``split_at[i, j]`` gives the index ``k`` such that the optimal way + to parenthesize the subchain A_i ... A_j is: + + (A_i ... A_k) @ (A_{k+1} ... A_j) + + Entries on and below the diagonal are unused. + """ + num_matrices = len(dimensions) - 1 + + min_cost = np.zeros((num_matrices, num_matrices), dtype=np.float64) + split_at = np.zeros((num_matrices, num_matrices), dtype=np.intp) + + # subchain_length is the number of matrices in the subchain minus 1: + # 1 means pairs (Ai..A{i+1}), 2 means triples, etc. 
+ for subchain_length in range(1, num_matrices): + for start in range(num_matrices - subchain_length): + end = start + subchain_length + + best_cost = np.inf + best_split = -1 + + for split in range(start, end): + candidate_cost = ( + min_cost[start, split] + + min_cost[split + 1, end] + + dimensions[start] * dimensions[split + 1] * dimensions[end + 1] + ) + + if candidate_cost < best_cost: + best_cost = candidate_cost + best_split = split + + min_cost[start, end] = best_cost + split_at[start, end] = best_split + + return split_at + + +def _build_optimal_matmul_tree( + matrices: list[TensorVariable], split_at: np.ndarray, start: int, end: int +): + """Return the optimal parenthesization of ``matrices[start:end+1]``. + + Parameters + ---------- + matrices : list of TensorVariable + Sequence of matrix expressions. + split_at : np.ndarray + DP split table where ``split_at[a, b]`` gives the index at which + to split the subchain ``matrices[a:b+1]``. + start, end : int + Inclusive bounds of the subchain to reconstruct. + """ + if start == end: + return matrices[start] + + split = split_at[start, end] + return matmul( + _build_optimal_matmul_tree(matrices, split_at, start, split), + _build_optimal_matmul_tree(matrices, split_at, split + 1, end), + ) + + +@register_specialize("multi_dot") +@node_rewriter([MultiDot]) +def lower_multi_dot(fgraph, node): + """Lower MultiDot to an optimally-ordered sequence of matmul ops. + + Core dimensions of each input are assumed to be the last two. All core dimensions must be statically known for + the optimization to fire. If anything is unknown, fall back to left-to-right ordering. + """ + matrices = node.inputs + n = len(matrices) + + # Core shapes are encoded in matrix dims so that input_1.core_shape = matrix_dims[[0, 1]], + # input_2.core_shape = matrix_dims[[1, 2]], and so on. 
+ matrix_dimensions = [ + matrices[0].type.shape[-2], + *[a.type.shape[-1] for a in matrices], + ] + if any(d is None for d in matrix_dimensions): + new_out = matrix_dot(*matrices) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + split_points = matrix_chain_split_points(matrix_dimensions) + new_out = _build_optimal_matmul_tree( + list(matrices), split_points, start=0, end=n - 1 + ) + copy_stack_trace(node.outputs[0], new_out) + + return [new_out] diff --git a/tests/tensor/linalg/test_products.py b/tests/tensor/linalg/test_products.py index cf601a3090..6b81a9e0e7 100644 --- a/tests/tensor/linalg/test_products.py +++ b/tests/tensor/linalg/test_products.py @@ -10,9 +10,9 @@ kron, matrix_dot, matrix_power, + multi_dot, pinv, ) -from pytensor.tensor.math import _allclose from pytensor.tensor.type import matrix, tensor, vector from tests import unittest_tools as utt @@ -20,19 +20,14 @@ def test_matrix_dot(): rng = np.random.default_rng(utt.fetch_seed()) n = rng.integers(4) + 2 - rs = [] - xs = [] - for k in range(n): - rs += [rng.standard_normal((4, 4)).astype(config.floatX)] - xs += [matrix()] + rs = [rng.normal(size=(4, 4)).astype(config.floatX) for _ in range(n)] + xs = [matrix() for _ in range(n)] sol = matrix_dot(*xs) pytensor_sol = function(xs, sol)(*rs) - numpy_sol = rs[0] - for r in rs[1:]: - numpy_sol = np.dot(numpy_sol, r) + numpy_sol = np.linalg.multi_dot(rs) - assert _allclose(numpy_sol, pytensor_sol) + np.testing.assert_allclose(numpy_sol, pytensor_sol) class TestMatrixPower: @@ -160,3 +155,27 @@ def test_expm_grad(mode): raise ValueError(f"Invalid mode: {mode}") utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5) + + +def test_multi_dot(): + rng = np.random.default_rng(utt.fetch_seed()) + + shapes_2d = [(10, 20), (20, 5), (5, 30), (30, 3)] + arrays_np = [rng.normal(size=s).astype(config.floatX) for s in shapes_2d] + arrays_pt = [matrix(f"M{i}", shape=s) for i, s in enumerate(shapes_2d)] + out = multi_dot(arrays_pt) + f = 
function(arrays_pt, out) + np.testing.assert_allclose(f(*arrays_np), np.linalg.multi_dot(arrays_np), rtol=1e-5) + + shapes_3d = [(7, 10, 20), (7, 20, 5), (7, 5, 30)] + arrays_np_3d = [rng.normal(size=s).astype(config.floatX) for s in shapes_3d] + arrays_pt_3d = [ + pytensor.tensor.tensor3(f"B{i}", shape=s) for i, s in enumerate(shapes_3d) + ] + out_3d = multi_dot(arrays_pt_3d) + f_3d = function(arrays_pt_3d, out_3d) + np.testing.assert_allclose( + f_3d(*arrays_np_3d), + arrays_np_3d[0] @ arrays_np_3d[1] @ arrays_np_3d[2], + rtol=1e-5, + ) diff --git a/tests/tensor/rewriting/linalg/test_products.py b/tests/tensor/rewriting/linalg/test_products.py index 5884e2e97e..8205a040ca 100644 --- a/tests/tensor/rewriting/linalg/test_products.py +++ b/tests/tensor/rewriting/linalg/test_products.py @@ -8,7 +8,9 @@ from pytensor.graph import FunctionGraph, ancestors from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor.linalg.constructors import BlockDiagonal -from pytensor.tensor.linalg.products import KroneckerProduct +from pytensor.tensor.linalg.products import KroneckerProduct, MultiDot +from pytensor.tensor.math import Dot +from tests.unittest_tools import assert_equal_computations def test_nested_blockdiag_fusion(): @@ -236,3 +238,52 @@ def test_slogdet_kronecker_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +class TestMultiDotRewrite: + def test_dot_chain_absorbed(self): + """A @ B @ C @ D should become MultiDot after canonicalize.""" + A = pt.matrix("A", shape=(10, 20)) + B = pt.matrix("B", shape=(20, 5)) + C = pt.matrix("C", shape=(5, 30)) + D = pt.matrix("D", shape=(30, 3)) + out = A @ B @ C @ D + + rewritten = rewrite_graph(out, include=("canonicalize",)) + assert rewritten.owner and isinstance(rewritten.owner.op, MultiDot) + assert len(rewritten.owner.inputs) == 4 + + def test_optimal_ordering(self): + """Chain ordering should minimize FLOPs, not use naive 
left-to-right.""" + # (100x2) @ (2x100) @ (100x100): optimal is A @ (B @ C) + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + + f = function([A, B, C], A @ B @ C) + dot_nodes = [n for n in f.maker.fgraph.toposort() if isinstance(n.op, Dot)] + # First dot should be B @ C (2x100), not A @ B (100x100) + assert dot_nodes[0].outputs[0].type.shape == (2, 100) + + def test_fuse_multi_dot_operands(self): + A = pt.matrix("A", shape=(10, 20)) + B = pt.matrix("B", shape=(20, 5)) + C = pt.matrix("C", shape=(5, 30)) + D = pt.matrix("D", shape=(30, 8)) + E = pt.matrix("E", shape=(8, 3)) + F = pt.matrix("F", shape=(3, 12)) + + out = pt.linalg.multi_dot([A, B, C]) @ pt.linalg.multi_dot([D, E, F]) + rewritten = rewrite_graph(out, include=("canonicalize",)) + expected = pt.linalg.multi_dot([A, B, C, D, E, F]) + assert_equal_computations([rewritten], [expected]) + + out = A @ B @ pt.linalg.multi_dot([C, D, E, F]) + rewritten = rewrite_graph(out, include=("canonicalize",)) + expected = pt.linalg.multi_dot([A, B, C, D, E, F]) + assert_equal_computations([rewritten], [expected]) + + out = pt.linalg.multi_dot([A, B, C, D]) @ E @ F + rewritten = rewrite_graph(out, include=("canonicalize",)) + expected = pt.linalg.multi_dot([A, B, C, D, E, F]) + assert_equal_computations([rewritten], [expected]) From 852ab79ae7b77d34f76ec3bc68fa26655b4192f5 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 19 Apr 2026 00:28:28 -0500 Subject: [PATCH 2/7] mypy --- pytensor/tensor/rewriting/linalg/products.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg/products.py b/pytensor/tensor/rewriting/linalg/products.py index 8dcc9eeb6a..3e53b5d763 100644 --- a/pytensor/tensor/rewriting/linalg/products.py +++ b/pytensor/tensor/rewriting/linalg/products.py @@ -254,7 +254,7 @@ def apply(self, fgraph): multi_dot_absorber = MultiDotAbsorber() multi_dot_absorber.name = "multi_dot_absorber" 
-register_canonicalize(multi_dot_absorber, "multi_dot", name="multi_dot_absorber") +register_canonicalize(multi_dot_absorber, "multi_dot", name="multi_dot_absorber") # type: ignore[arg-type] @register_canonicalize("multi_dot") From 39c0f8b2876f7049ddd6afc6701249c372007d34 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 30 Apr 2026 20:59:26 -0500 Subject: [PATCH 3/7] back to square 1 --- pytensor/tensor/linalg/__init__.py | 3 - pytensor/tensor/linalg/products.py | 68 ---- pytensor/tensor/rewriting/linalg/products.py | 318 +----------------- tests/tensor/linalg/test_products.py | 25 -- .../tensor/rewriting/linalg/test_products.py | 53 +-- 5 files changed, 4 insertions(+), 463 deletions(-) diff --git a/pytensor/tensor/linalg/__init__.py b/pytensor/tensor/linalg/__init__.py index 48449db3fc..9ad282a4ad 100644 --- a/pytensor/tensor/linalg/__init__.py +++ b/pytensor/tensor/linalg/__init__.py @@ -44,12 +44,10 @@ from pytensor.tensor.linalg.products import ( Expm, KroneckerProduct, - MultiDot, expm, kron, matrix_dot, matrix_power, - multi_dot, ) from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import ( @@ -110,7 +108,6 @@ "lu_solve", "matrix_dot", "matrix_power", - "multi_dot", "norm", "ordqz", "pinv", diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 4d1b8bf037..23c773bff3 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -174,71 +174,3 @@ def matrix_power(M, n): result = z if result is None else ptm.dot(result, z) return result - - -class MultiDot(OpFromGraph): - """Wrapper Op for a sequence of matrix multiplications. - - Used as a target for rewrites to re-order matrix multiplications for better performance. - """ - - -def multi_dot(matrices): - """Compute the dot product of two or more matrices, selecting the fastest evaluation order automatically. 
- - The problem of optimal matrix multiplication ordering concerns how to place parenthesis in a sequence of matmuls - to make computation as efficient as possible. For a discussion, see the wikipedia page: https://en.wikipedia.org/wiki/Matrix_chain_multiplication - The following example from that page illustrates the point. Given 3 matrices A of shape (10, 30), B of shape - (30, 5), and C of shape (5, 60), the product X = A @ B @ C can be performed in two ways: (A @ B) @ C or A @ (B @ C). - - The first way requires 10*30*5 + 10*5*60 = 4500 multiplications, while the second way requires - 30*5*60 + 10*30*60 = 27000 multiplications. Thus, the first way is much more efficient, and multi_dot will - automatically select that way for you. - - The exact dynamic programming solution is used to find the optimal contraction path. - - Parameters - ---------- - matrices : list of TensorVariable - All inputs must be at least 2d, except for the first and last inputs, which can be 1d. If the first input is - 1d, it is treated as a row vector. If the last input is 1d, it is treated as a column vector. 
- - Returns - ------- - result : TensorVariable - The dot product of the input matrices, from left to right - """ - if len(matrices) < 2: - raise ValueError("multi_dot requires at least 2 matrices") - - matrices = [as_tensor_variable(a) for a in matrices] - - for i, a in enumerate(matrices): - if a.ndim < 1: - raise ValueError(f"multi_dot: array {i} is a scalar, expected at least 1-D") - if a.ndim == 1 and 0 < i < len(matrices) - 1: - raise ValueError( - f"multi_dot: interior array {i} must be at least 2-D, got 1-D" - ) - - squeeze_first = matrices[0].ndim == 1 - squeeze_last = matrices[-1].ndim == 1 - if squeeze_first: - matrices[0] = pt.expand_dims(matrices[0], axis=0) - if squeeze_last: - matrices[-1] = pt.expand_dims(matrices[-1], axis=1) - - if len(matrices) == 2: - result = ptm.matmul(matrices[0], matrices[1]) - else: - # Build naive inner graph, wrap in MultiDot - inner_inputs = [a.type.clone()(name=f"i{i}") for i, a in enumerate(matrices)] - inner_output = matrix_dot(*inner_inputs) - result = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) - - if squeeze_first: - result = result[0] if result.ndim == 2 else result[..., 0, :] - if squeeze_last: - result = result[..., 0] - - return result diff --git a/pytensor/tensor/rewriting/linalg/products.py b/pytensor/tensor/rewriting/linalg/products.py index 3e53b5d763..1f69da680a 100644 --- a/pytensor/tensor/rewriting/linalg/products.py +++ b/pytensor/tensor/rewriting/linalg/products.py @@ -1,42 +1,16 @@ -from typing import TYPE_CHECKING - -import numpy as np - -from pytensor.graph.rewriting.basic import ( - GraphRewriter, - copy_stack_trace, - node_rewriter, -) +from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.tensor.basic import ExtractDiag, concatenate, diag from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg.constructors import BlockDiagonal -from pytensor.tensor.linalg.products import ( - KroneckerProduct, - MultiDot, - matrix_dot, -) 
+from pytensor.tensor.linalg.products import KroneckerProduct from pytensor.tensor.linalg.summary import det -from pytensor.tensor.math import Dot, matmul, outer, prod +from pytensor.tensor.math import outer, prod from pytensor.tensor.rewriting.basic import ( register_canonicalize, - register_specialize, register_stabilize, ) -if TYPE_CHECKING: - from pytensor.tensor.type import TensorVariable - - -def _is_dot_node(node): - """Check if a node is a Dot or Blockwise(Dot)""" - if isinstance(node.op, Dot): - return True - if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot): - return True - return False - - @register_canonicalize @node_rewriter([BlockDiagonal]) def fuse_blockdiagonal(fgraph, node): @@ -175,289 +149,3 @@ def det_of_kronecker(fgraph, node): [dets[i] ** (prod_sizes / sizes[i]) for i in range(2)], axis=-1 ) return [det_final] - - -class MultiDotAbsorber(GraphRewriter): - """Scan the entire fgraph for chains of 3+ matmul ops and wrap them in MultiDot. - - Matches both bare Dot (2-D) and Blockwise(Dot) (batched) nodes. - Only absorbs nodes whose output has a single client (the next matmul in the chain). 
- """ - - def apply(self, fgraph): - dot_nodes = [node for node in fgraph.toposort() if _is_dot_node(node)] - absorbed = set() - - for node in dot_nodes: - if node in absorbed: - continue - - chain_nodes = [node] - current = node - - # Walk backward: the left input might be another matmul - while True: - left_input = current.inputs[0] - if ( - left_input.owner - and _is_dot_node(left_input.owner) - and left_input.owner not in absorbed - and len(fgraph.clients[left_input]) == 1 - ): - current = left_input.owner - chain_nodes.insert(0, current) - else: - break - - # Walk forward: the output might feed into another matmul as left input - current = node - while True: - out = current.outputs[0] - clients = fgraph.clients[out] - if ( - len(clients) == 1 - and clients[0][0] != "output" - and _is_dot_node(clients[0][0]) - and clients[0][0].inputs[0] is out - and clients[0][0] not in absorbed - ): - current = clients[0][0] - chain_nodes.append(current) - else: - break - - if len(chain_nodes) < 2: - continue - - # Extract matrices: first node's left input, then each node's right input - matrices = [ - chain_nodes[0].inputs[0], - *(node.inputs[1] for node in chain_nodes), - ] - absorbed.update(chain_nodes) - - inner_inputs = [ - m.type.clone()(name=f"i{i}") for i, m in enumerate(matrices) - ] - - # Naive chained matmuls -- optimization is done later by a separate rewrite - inner_output = matrix_dot(*inner_inputs) - multi_dot_op = MultiDot(inputs=inner_inputs, outputs=[inner_output]) - new_out = multi_dot_op(*matrices) - copy_stack_trace(chain_nodes[-1].outputs[0], new_out) - fgraph.replace( - chain_nodes[-1].outputs[0], - new_out, - reason="dot_chain_to_multi_dot", - ) - - -multi_dot_absorber = MultiDotAbsorber() -multi_dot_absorber.name = "multi_dot_absorber" -register_canonicalize(multi_dot_absorber, "multi_dot", name="multi_dot_absorber") # type: ignore[arg-type] - - -@register_canonicalize("multi_dot") -@node_rewriter([MultiDot]) -def fuse_multi_dot_operands(fgraph, 
node): - """Fuse matmul(MultiDot(...), X) or matmul(X, MultiDot(...)) into a larger MultiDot. - - Also handles matmul(MultiDot(...), MultiDot(...)) by merging both. - """ - [multi_out] = node.outputs - clients = fgraph.clients[multi_out] - - for client_node, _ in clients: - if client_node == "output" or not _is_dot_node(client_node): - continue - - left, right = client_node.inputs - - # Only absorb a MultiDot whose output has a single client (this Dot), - # otherwise we'd duplicate the computation of its inner chain. - left_is_multi = ( - left.owner - and isinstance(left.owner.op, MultiDot) - and len(fgraph.clients[left]) == 1 - ) - right_is_multi = ( - right.owner - and isinstance(right.owner.op, MultiDot) - and len(fgraph.clients[right]) == 1 - ) - - if not left_is_multi and not right_is_multi: - continue - - matrices = [] - if left_is_multi: - matrices.extend(left.owner.inputs) - else: - matrices.append(left) - if right_is_multi: - matrices.extend(right.owner.inputs) - else: - matrices.append(right) - - inner_inputs = [m.type.clone()(name=f"i{i}") for i, m in enumerate(matrices)] - inner_output = matrix_dot(*inner_inputs) - new_out = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) - - copy_stack_trace(client_node.outputs[0], new_out) - - return {client_node.outputs[0]: new_out} - - return None - - -@register_canonicalize("multi_dot") -@node_rewriter([MultiDot]) -def flatten_nested_multi_dot(fgraph, node): - """Flatten nested MultiDot inputs into a single MultiDot. Nested MultiDots arise from fuse_multi_dot_operands, - when graphs like A @ B @ MultiDot(C, D, E, F) -> MultiDot(A, B, MultiDot(C, D, E, F)). 
This rewrite does a final - clean up: MultiDot([A, B, MultiDot([C, D, E, F])]) -> MultiDot([A, B, C, D, E, F]) - """ - inputs = node.inputs - has_nested = any( - inp.owner - and isinstance(inp.owner.op, MultiDot) - and len(fgraph.clients[inp]) == 1 - for inp in inputs - ) - - if not has_nested: - return None - - # Flatten inputs - matrices = [] - for inp in inputs: - if ( - inp.owner - and isinstance(inp.owner.op, MultiDot) - and len(fgraph.clients[inp]) == 1 - ): - matrices.extend(inp.owner.inputs) - else: - matrices.append(inp) - - inner_inputs = [] - for m in matrices: - if m in inner_inputs: - # If a matrix is repeated, we need to copy it to avoid issues with the OpFromGraph - inner_inputs.append(m.type(name=m.name)) - else: - inner_inputs.append(m) - - inner_output = matrix_dot(*inner_inputs) - new_out = MultiDot(inputs=inner_inputs, outputs=[inner_output])(*matrices) - - copy_stack_trace(node.outputs[0], new_out) - return [new_out] - - -def matrix_chain_split_points(dimensions: list[int]) -> np.ndarray: - """Return optimal split points for matrix-chain multiplication. - - Parameters - ---------- - dimensions : list[int] - Matrix dimensions encoded as a list of integers, where the i-th matrix has shape - (dimensions[i], dimensions[i+1]) - - Returns - ------- - split_at : np.ndarray - Integer array of shape (n, n), where n = len(dimensions) - 1. - - ``split_at[i, j]`` gives the index ``k`` such that the optimal way - to parenthesize the subchain A_i ... A_j is: - - (A_i ... A_k) @ (A_{k+1} ... A_j) - - Entries on and below the diagonal are unused. - """ - num_matrices = len(dimensions) - 1 - - min_cost = np.zeros((num_matrices, num_matrices), dtype=np.float64) - split_at = np.zeros((num_matrices, num_matrices), dtype=np.intp) - - # subchain_length is the number of matrices in the subchain minus 1: - # 1 means pairs (Ai..A{i+1}), 2 means triples, etc. 
- for subchain_length in range(1, num_matrices): - for start in range(num_matrices - subchain_length): - end = start + subchain_length - - best_cost = np.inf - best_split = -1 - - for split in range(start, end): - candidate_cost = ( - min_cost[start, split] - + min_cost[split + 1, end] - + dimensions[start] * dimensions[split + 1] * dimensions[end + 1] - ) - - if candidate_cost < best_cost: - best_cost = candidate_cost - best_split = split - - min_cost[start, end] = best_cost - split_at[start, end] = best_split - - return split_at - - -def _build_optimal_matmul_tree( - matrices: list[TensorVariable], split_at: np.ndarray, start: int, end: int -): - """Return the optimal parenthesization of ``matrices[start:end+1]``. - - Parameters - ---------- - matrices : list of TensorVariable - Sequence of matrix expressions. - split_at : np.ndarray - DP split table where ``split_at[a, b]`` gives the index at which - to split the subchain ``matrices[a:b+1]``. - start, end : int - Inclusive bounds of the subchain to reconstruct. - """ - if start == end: - return matrices[start] - - split = split_at[start, end] - return matmul( - _build_optimal_matmul_tree(matrices, split_at, start, split), - _build_optimal_matmul_tree(matrices, split_at, split + 1, end), - ) - - -@register_specialize("multi_dot") -@node_rewriter([MultiDot]) -def lower_multi_dot(fgraph, node): - """Lower MultiDot to an optimally-ordered sequence of matmul ops. - - Core dimensions of each input are assumed to be the last two. All core dimensions must be statically known for - the optimization to fire. If anything is unknown, fall back to left-to-right ordering. - """ - matrices = node.inputs - n = len(matrices) - - # Core shapes are encoded in matrix dims so that input_1.core_shape = matrix_dims[[0, 1]], - # input_2.core_shape = matrix_dims[[1, 2]], and so on. 
- matrix_dimensions = [ - matrices[0].type.shape[-2], - *[a.type.shape[-1] for a in matrices], - ] - if any(d is None for d in matrix_dimensions): - new_out = matrix_dot(*matrices) - copy_stack_trace(node.outputs[0], new_out) - return [new_out] - - split_points = matrix_chain_split_points(matrix_dimensions) - new_out = _build_optimal_matmul_tree( - list(matrices), split_points, start=0, end=n - 1 - ) - copy_stack_trace(node.outputs[0], new_out) - - return [new_out] diff --git a/tests/tensor/linalg/test_products.py b/tests/tensor/linalg/test_products.py index 6b81a9e0e7..2a0ee972dd 100644 --- a/tests/tensor/linalg/test_products.py +++ b/tests/tensor/linalg/test_products.py @@ -10,7 +10,6 @@ kron, matrix_dot, matrix_power, - multi_dot, pinv, ) from pytensor.tensor.type import matrix, tensor, vector @@ -155,27 +154,3 @@ def test_expm_grad(mode): raise ValueError(f"Invalid mode: {mode}") utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5) - - -def test_multi_dot(): - rng = np.random.default_rng(utt.fetch_seed()) - - shapes_2d = [(10, 20), (20, 5), (5, 30), (30, 3)] - arrays_np = [rng.normal(size=s).astype(config.floatX) for s in shapes_2d] - arrays_pt = [matrix(f"M{i}", shape=s) for i, s in enumerate(shapes_2d)] - out = multi_dot(arrays_pt) - f = function(arrays_pt, out) - np.testing.assert_allclose(f(*arrays_np), np.linalg.multi_dot(arrays_np), rtol=1e-5) - - shapes_3d = [(7, 10, 20), (7, 20, 5), (7, 5, 30)] - arrays_np_3d = [rng.normal(size=s).astype(config.floatX) for s in shapes_3d] - arrays_pt_3d = [ - pytensor.tensor.tensor3(f"B{i}", shape=s) for i, s in enumerate(shapes_3d) - ] - out_3d = multi_dot(arrays_pt_3d) - f_3d = function(arrays_pt_3d, out_3d) - np.testing.assert_allclose( - f_3d(*arrays_np_3d), - arrays_np_3d[0] @ arrays_np_3d[1] @ arrays_np_3d[2], - rtol=1e-5, - ) diff --git a/tests/tensor/rewriting/linalg/test_products.py b/tests/tensor/rewriting/linalg/test_products.py index 8205a040ca..5884e2e97e 100644 --- 
a/tests/tensor/rewriting/linalg/test_products.py +++ b/tests/tensor/rewriting/linalg/test_products.py @@ -8,9 +8,7 @@ from pytensor.graph import FunctionGraph, ancestors from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor.linalg.constructors import BlockDiagonal -from pytensor.tensor.linalg.products import KroneckerProduct, MultiDot -from pytensor.tensor.math import Dot -from tests.unittest_tools import assert_equal_computations +from pytensor.tensor.linalg.products import KroneckerProduct def test_nested_blockdiag_fusion(): @@ -238,52 +236,3 @@ def test_slogdet_kronecker_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) - - -class TestMultiDotRewrite: - def test_dot_chain_absorbed(self): - """A @ B @ C @ D should become MultiDot after canonicalize.""" - A = pt.matrix("A", shape=(10, 20)) - B = pt.matrix("B", shape=(20, 5)) - C = pt.matrix("C", shape=(5, 30)) - D = pt.matrix("D", shape=(30, 3)) - out = A @ B @ C @ D - - rewritten = rewrite_graph(out, include=("canonicalize",)) - assert rewritten.owner and isinstance(rewritten.owner.op, MultiDot) - assert len(rewritten.owner.inputs) == 4 - - def test_optimal_ordering(self): - """Chain ordering should minimize FLOPs, not use naive left-to-right.""" - # (100x2) @ (2x100) @ (100x100): optimal is A @ (B @ C) - A = pt.matrix("A", shape=(100, 2)) - B = pt.matrix("B", shape=(2, 100)) - C = pt.matrix("C", shape=(100, 100)) - - f = function([A, B, C], A @ B @ C) - dot_nodes = [n for n in f.maker.fgraph.toposort() if isinstance(n.op, Dot)] - # First dot should be B @ C (2x100), not A @ B (100x100) - assert dot_nodes[0].outputs[0].type.shape == (2, 100) - - def test_fuse_multi_dot_operands(self): - A = pt.matrix("A", shape=(10, 20)) - B = pt.matrix("B", shape=(20, 5)) - C = pt.matrix("C", shape=(5, 30)) - D = pt.matrix("D", shape=(30, 8)) - E = pt.matrix("E", shape=(8, 3)) - F = pt.matrix("F", shape=(3, 12)) - - out = 
pt.linalg.multi_dot([A, B, C]) @ pt.linalg.multi_dot([D, E, F]) - rewritten = rewrite_graph(out, include=("canonicalize",)) - expected = pt.linalg.multi_dot([A, B, C, D, E, F]) - assert_equal_computations([rewritten], [expected]) - - out = A @ B @ pt.linalg.multi_dot([C, D, E, F]) - rewritten = rewrite_graph(out, include=("canonicalize",)) - expected = pt.linalg.multi_dot([A, B, C, D, E, F]) - assert_equal_computations([rewritten], [expected]) - - out = pt.linalg.multi_dot([A, B, C, D]) @ E @ F - rewritten = rewrite_graph(out, include=("canonicalize",)) - expected = pt.linalg.multi_dot([A, B, C, D, E, F]) - assert_equal_computations([rewritten], [expected]) From d771280e6d390581f578c2fc6138b3008a22437d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 30 Apr 2026 21:58:13 -0500 Subject: [PATCH 4/7] re-implement multi_dot as reassociate rewrite --- pytensor/tensor/rewriting/linalg/__init__.py | 1 + .../rewriting/linalg/reassociate_matmul.py | 688 ++++++++++++++++++ .../linalg/test_reassociate_matmul.py | 499 +++++++++++++ 3 files changed, 1188 insertions(+) create mode 100644 pytensor/tensor/rewriting/linalg/reassociate_matmul.py create mode 100644 tests/tensor/rewriting/linalg/test_reassociate_matmul.py diff --git a/pytensor/tensor/rewriting/linalg/__init__.py b/pytensor/tensor/rewriting/linalg/__init__.py index 7e9d7b3393..3fbe4720e7 100644 --- a/pytensor/tensor/rewriting/linalg/__init__.py +++ b/pytensor/tensor/rewriting/linalg/__init__.py @@ -6,6 +6,7 @@ import pytensor.tensor.rewriting.linalg.decomposition import pytensor.tensor.rewriting.linalg.inverse import pytensor.tensor.rewriting.linalg.products +import pytensor.tensor.rewriting.linalg.reassociate_matmul import pytensor.tensor.rewriting.linalg.solvers import pytensor.tensor.rewriting.linalg.summary import pytensor.tensor.rewriting.linalg.utils diff --git a/pytensor/tensor/rewriting/linalg/reassociate_matmul.py b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py new file mode 100644 index 
0000000000..86cdf45649 --- /dev/null +++ b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py @@ -0,0 +1,688 @@ +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import cast + +from pytensor.compile.mode import optdb +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import GraphRewriter, copy_stack_trace +from pytensor.tensor.blas import BatchedDot, Dot22 +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import Dot, matmul + + +# A "dim entry" is an int (statically known size) or a Variable (a scalar shape var +# from ShapeFeature, treated as an opaque positive symbol >= 1). A CostExpr is a sum +# of monomials in those symbols, with literal-int factors folded into each monomial's +# coefficient. + +DimEntry = int | Variable +Shape = tuple[DimEntry, ...] + + +def _is_one(d: DimEntry) -> bool: + return isinstance(d, int) and d == 1 + + +def _sym_sort_key(sym: Variable) -> tuple[str, int]: + """Stable sort key for symbol ordering inside a monomial. Prefers `name` so the + same dim across runs sorts deterministically; falls back to id only as a tiebreak + between unnamed symbols within one process.""" + return (getattr(sym, "name", None) or "", id(sym)) + + +class _BailOutError(Exception): + """Internal signal: the rewriter should skip this chain (not raise to the user).""" + + +class CostExpr: + """Polynomial in positive dim symbols. + + Stores its monomials as ``{sorted_(symbol,exp)_tuple: int_coef}``. The FLOP cost + of a matmul is (approximately) the product of all unique dim lengths. Given + inputs A of shape ``(a, b)`` and B of shape ``(b, c)``, computing ``A @ B`` + costs ``a * b * c``. We call such a product a "monomial". A chain of matmuls + sums these monomials to give the total cost. 
+ """ + + __slots__ = ("monomials",) + + def __init__(self, monomials=None): + self.monomials: dict[tuple, int] = ( + dict(monomials) if monomials is not None else {} + ) + + @classmethod + def zero(cls) -> "CostExpr": + return cls() + + @classmethod + def from_dim_product(cls, dims: Sequence[DimEntry]) -> "CostExpr": + """Build a single monomial from the product of `dims`.""" + coef = 1 + sym_exps: dict[Variable, int] = defaultdict(int) + for d in dims: + if isinstance(d, int): + coef *= d + else: + sym_exps[d] += 1 + key = tuple(sorted(sym_exps.items(), key=lambda kv: _sym_sort_key(kv[0]))) + return cls({key: coef}) + + def __add__(self, other: "CostExpr") -> "CostExpr": + result = dict(self.monomials) + for k, v in other.monomials.items(): + result[k] = result.get(k, 0) + v + if result[k] == 0: + del result[k] + return CostExpr(result) + + def __repr__(self) -> str: + if not self.monomials: + return "CostExpr(0)" + terms = [] + for sym_exps, coef in self.monomials.items(): + parts = [str(coef)] if coef != 1 or not sym_exps else [] + for sym, exp in sym_exps: + name = getattr(sym, "name", None) or "" + parts.append(f"{name}^{exp}" if exp != 1 else name) + terms.append("*".join(parts) if parts else "1") + return "CostExpr(" + " + ".join(terms) + ")" + + +def _provably_less(a: CostExpr, b: CostExpr) -> bool: + """True iff a < b is provable assuming every dim symbol is >= 1. + + Match each a-monomial to a distinct b-monomial that dominates it (b's coef >= a's, + b's exponents >= a's). A complete matching is sound (a <= b); strict (a < b) iff + sum(b.coefs) > sum(a.coefs). Matching uses Kuhn's algorithm. Returns False for + "not provable"; never claims b <= a. 
+ """ + if not a.monomials: + return any(c > 0 for c in b.monomials.values()) + if not b.monomials: + return False + + a_terms = list(a.monomials.items()) + b_terms = list(b.monomials.items()) + n_a = len(a_terms) + n_b = len(b_terms) + if n_b < n_a: + return False + + # For each a-index, list the b-indices that can cover it. + candidates: list[list[int]] = [] + for a_key, a_coef in a_terms: + a_exp = dict(a_key) + row = [] + for b_idx, (b_key, b_coef) in enumerate(b_terms): + if b_coef < a_coef: + continue + b_exp = dict(b_key) + if all(b_exp.get(s, 0) >= e for s, e in a_exp.items()): + row.append(b_idx) + if not row: + return False + candidates.append(row) + + # Kuhn's algorithm: for each a-index, try to find an augmenting path. + matched_b_to_a = [-1] * n_b + + def try_assign(a_idx: int, seen: list[bool]) -> bool: + for b_idx in candidates[a_idx]: + if seen[b_idx]: + continue + seen[b_idx] = True + if matched_b_to_a[b_idx] == -1 or try_assign(matched_b_to_a[b_idx], seen): + matched_b_to_a[b_idx] = a_idx + return True + return False + + for a_idx in range(n_a): + if not try_assign(a_idx, [False] * n_b): + return False + + return sum(b.monomials.values()) > sum(a.monomials.values()) + + +def _operand_shape_raw(var: Variable, fgraph: FunctionGraph) -> Shape: + """Raw shape of `var` as a tuple of dim entries. + + Each entry is an int when statically known and a scalar Variable from + ShapeFeature otherwise. Falls back to ``var.shape[i]`` when ShapeFeature is absent. + """ + static = var.type.shape + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is not None and var in shape_feature.shape_of: + symbolic = shape_feature.shape_of[var] + else: + symbolic = tuple(var.shape[i] for i in range(var.type.ndim)) # type: ignore[attr-defined] + return tuple(int(s) if s is not None else symbolic[i] for i, s in enumerate(static)) + + +def _apply_lift_to_shape(lift: DimShuffle, shape: Shape) -> Shape: + """Apply a DimShuffle to a shape tuple. 
``'x'`` becomes literal 1; ints index `shape`. + + Raises ``_BailOutError`` when `lift` references an input dim outside `shape`. This + happens when the chain extender constructed `lift` against a wider matmul output + (e.g., a Blockwise with broadcasting between heterogeneous-ndim operands) and the + decomposition is now propagating it down to a narrower operand it can't legally + apply to. + """ + out: list[DimEntry] = [] + for x in lift.new_order: + if x == "x": + out.append(1) + continue + if not (0 <= x < len(shape)): + raise _BailOutError( + f"DimShuffle.new_order references index {x} outside operand shape " + f"of length {len(shape)}; lift cannot legally apply." + ) + out.append(shape[x]) + return tuple(out) + + +def _broadcast_batch(left_batch: Shape, right_batch: Shape) -> Shape: + """Right-align two batch tuples and broadcast, preferring the non-literal-1 side.""" + n = max(len(left_batch), len(right_batch)) + pad_l = (1,) * (n - len(left_batch)) + tuple(left_batch) + pad_r = (1,) * (n - len(right_batch)) + tuple(right_batch) + return tuple(b if _is_one(a) else a for a, b in zip(pad_l, pad_r)) + + +def _matmul_result_shape(left: Shape, right: Shape) -> Shape: + """Shape of ``left @ right``: ``(*broadcast_batch, m, n)`` where left ends in + ``(m, k)`` and right ends in ``(k, n)``.""" + batch = _broadcast_batch(left[:-2], right[:-2]) + return (*batch, left[-2], right[-1]) + + +def _contract_cost(left: Shape, right: Shape) -> CostExpr: + """FLOPs of ``left @ right``: ``prod(broadcast_batch) * m * k * n``.""" + result = _matmul_result_shape(left, right) + return CostExpr.from_dim_product([*result[:-2], left[-2], left[-1], right[-1]]) + + +def _classify_dimshuffle_lift(op: DimShuffle, input_ndim: int) -> tuple[bool, bool]: + """Classify a DimShuffle for matmul-lift purposes. 
+ + A DimShuffle commutes with matmul iff it touches only batch dimensions (the + leading ``input_ndim - 2``), or it performs a matrix-transpose (swap of the last + two dims) -- possibly combined with batch-only operations. Returns + ``(is_liftable, swaps_order)``, where ``swaps_order`` is True iff the lift swaps + operand order in the matmul: ``(L @ R).T = R.T @ L.T``. + + DimShuffle disallows duplicate input indices in `new_order`, so once the core dim + indices appear in the last two positions of the output they cannot appear + elsewhere -- earlier positions are batch-only. + """ + if input_ndim < 2: + return False, False + new_order = op.new_order + if len(new_order) < 2: + return False, False + last_two = (new_order[-2], new_order[-1]) + if last_two == (input_ndim - 2, input_ndim - 1): + return True, False + if last_two == (input_ndim - 1, input_ndim - 2): + return True, True + return False, False + + +def _is_chain_link(node: Apply) -> bool: + """True iff `node` is a 2-operand matmul we can rebuild via ``matmul()``.""" + op = node.op + if isinstance(op, Dot | Dot22 | BatchedDot): + return True + if isinstance(op, Blockwise) and isinstance(op.core_op, Dot): + return True + return False + + +def _find_chain_top(start: Apply, fgraph: FunctionGraph) -> Apply: + """Walk up to the topmost chain-link consumer along single-client edges. + + Follows single-client edges where the current output feeds the consumer's *left* + input, and walks through a single-client liftable DimShuffle when its output then + feeds a chain-link's left input. + """ + current = start + while True: + clients = fgraph.clients[current.outputs[0]] + if len(clients) != 1: + break + client_node, client_idx = clients[0] + # `client_node` may be the literal "output" sentinel for fgraph outputs; + # the type stub is too narrow to express this. 
+ if not isinstance(client_node, Apply): + break # type: ignore[unreachable] + + if _is_chain_link(client_node) and client_idx == 0: + current = client_node + continue + + if ( + isinstance(client_node.op, DimShuffle) + and client_idx == 0 + and len(fgraph.clients[client_node.outputs[0]]) == 1 + ): + ds_consumer, ds_consumer_idx = fgraph.clients[client_node.outputs[0]][0] + if ( + isinstance(ds_consumer, Apply) + and _is_chain_link(ds_consumer) + and ds_consumer_idx == 0 + ): + liftable, _ = _classify_dimshuffle_lift( + client_node.op, current.outputs[0].type.ndim + ) + if liftable: + current = ds_consumer + continue + + break + return current + + +def _decompose_operand( + root: Variable, + fgraph: FunctionGraph, + visited: set[Apply], + consumed: list[Apply], +) -> list[tuple[Variable, tuple[DimShuffle, ...]]]: + """Iteratively decompose `root` into chain leaves. + + Each leaf is ``(base_var, lifts)`` where `lifts` is a tuple of DimShuffle Ops to + apply at materialization time (outermost first). Replace recursion with an + explicit work stack to keep the Python stack bounded for deep matmul chains + (autodiff-generated graphs can produce hundreds of links). + + Two descent paths grow the chain: + + - Single-client chain-link matmul: descend into both inputs. + - Single-client liftable DimShuffle wrapping a single-client chain-link matmul: + descend into the inner matmul's two inputs with the DimShuffle prepended to the + inherited-lift list. For matrix-transpose lifts, swapping operand order: + (``(L @ R).T = R.T @ L.T``). + + A lift only descends when both inner-matmul operands have ndim equal to the + DimShuffle's ``input_ndim``; this guards against propagating a lift into a + Blockwise(Dot) where the operands have heterogeneous ndim (the lift's + ``new_order`` would reference indices that don't exist on the narrower operand). + + Append each chain-link Apply we descend into to `consumed` and add to `visited`. 
+ """ + out: list[tuple[Variable, tuple[DimShuffle, ...]]] = [] + # Push children right-then-left so the leftmost leaf comes out first (preserving + # the natural left-to-right operand order). + stack: list[tuple[Variable, tuple[DimShuffle, ...]]] = [(root, ())] + + while stack: + var, lifts = stack.pop() + owner = var.owner + if owner is None: + out.append((var, lifts)) + continue + + if isinstance(owner.op, DimShuffle) and len(fgraph.clients[var]) == 1: + ds_input = owner.inputs[0] + inner = ds_input.owner + if ( + inner is not None + and _is_chain_link(inner) + and inner not in visited + and len(fgraph.clients[ds_input]) == 1 + ): + ds_op = owner.op + liftable, swaps = _classify_dimshuffle_lift(ds_op, ds_input.type.ndim) + if liftable and all( + inp.type.ndim == ds_op.input_ndim for inp in inner.inputs + ): + visited.add(inner) + consumed.append(inner) + new_lifts = (ds_op, *lifts) + left, right = inner.inputs + if swaps: + left, right = right, left + stack.append((right, new_lifts)) + stack.append((left, new_lifts)) + continue + + if ( + _is_chain_link(owner) + and owner not in visited + and len(fgraph.clients[var]) == 1 + ): + visited.add(owner) + consumed.append(owner) + stack.append((owner.inputs[1], lifts)) + stack.append((owner.inputs[0], lifts)) + continue + + out.append((var, lifts)) + + return out + + +def _operand_shape( + operand: tuple[Variable, tuple[DimShuffle, ...]], fgraph: FunctionGraph +) -> Shape: + """Compute a chain operand's shape after applying its pending lifts.""" + base, lifts = operand + shape = _operand_shape_raw(base, fgraph) + # `lifts` is outermost-first; apply outermost-last so the innermost transformation + # touches the raw shape first. 
+ for lift in reversed(lifts): + shape = _apply_lift_to_shape(lift, shape) + return shape + + +def _build_unification( + chain_shapes: list[Shape], extra_shapes: Sequence[Shape] = () +) -> tuple[list[Shape], Callable[[Shape], Shape]]: + """Canonicalize dim entries across chain operand shapes via union-find. + + Returns ``(unified_chain_shapes, canonicalize)`` where ``canonicalize(shape)`` maps + any Shape (over the same dim entries) to its canonical representatives. Two + equality sources drive the union-find: + + - Adjacent contracting dims of the chain. + - For each right-aligned batch position, all non-literal-1 entries across *every* + operand must agree at runtime (broadcasting requires it); unioning them as one + class catches transitive equalities a 1 in the middle would otherwise mask. + + A literal-int conflict (``ra != rb`` for two ints in the same class) signals an + inconsistent input graph -- raise ``_BailOutError`` so the caller skips the + rewrite rather than aborting compilation. The unification also seeds `parent` + with `extra_shapes` so the caller can canonicalize shapes outside the chain + (e.g., raw inputs of consumed inner matmuls). + """ + parent: dict[DimEntry, DimEntry] = {} + for shape in (*chain_shapes, *extra_shapes): + for d in shape: + parent.setdefault(d, d) + + def find(k: DimEntry) -> DimEntry: + while parent[k] != k: + parent[k] = parent[parent[k]] + k = parent[k] + return k + + def union(a: DimEntry, b: DimEntry) -> None: + ra, rb = find(a), find(b) + if ra == rb: + return + if isinstance(ra, int) and isinstance(rb, int): + if ra != rb: + raise _BailOutError( + f"Conflicting static dims in matmul chain: {ra} != {rb}." 
+ ) + parent[rb] = ra + elif isinstance(ra, int): + parent[rb] = ra + elif isinstance(rb, int): + parent[ra] = rb + else: + parent[rb] = ra + + for i in range(len(chain_shapes) - 1): + union(chain_shapes[i][-1], chain_shapes[i + 1][-2]) + + # Batch broadcast: at each right-aligned position across the whole chain, all + # non-literal-1 entries must be equal. Unioning them as one class catches + # transitive equalities a literal-1 in the middle would otherwise mask. + max_batch = max((len(s) - 2 for s in chain_shapes), default=0) + for pos in range(max_batch): + anchor: DimEntry | None = None + for s in chain_shapes: + n_batch = len(s) - 2 + if pos >= n_batch: + continue + d = s[n_batch - 1 - pos] + if _is_one(d): + continue + if anchor is None: + anchor = d + else: + union(anchor, d) + + rep_entry: dict[DimEntry, DimEntry] = {} + for shape in (*chain_shapes, *extra_shapes): + for d in shape: + r = find(d) + if r not in rep_entry: + rep_entry[r] = d + elif isinstance(d, int) and not isinstance(rep_entry[r], int): + rep_entry[r] = d + + def canonicalize(shape: Shape) -> Shape: + return tuple(rep_entry[find(d)] if d in parent else d for d in shape) + + unified = [canonicalize(s) for s in chain_shapes] + return unified, canonicalize + + +def _solve_chain(shapes: list[Shape]) -> tuple[CostExpr, dict]: + """Standard matrix-chain DP over the symbolic operand shapes. + + Returns ``(best_total_cost, dp_table)`` where ``dp_table[(i, j)]`` is + ``(cost, split_k, result_shape)`` for the subchain ``shapes[i..j]`` inclusive. + Comparison uses ``_provably_less``; when neither candidate is provably less the + first-found candidate (lower split index) wins as a deterministic tie-break. 
+ """ + n = len(shapes) + dp: dict[tuple[int, int], tuple[CostExpr, int | None, Shape]] = {} + for i in range(n): + dp[(i, i)] = (CostExpr.zero(), None, shapes[i]) + + for length in range(2, n + 1): + for i in range(n - length + 1): + j = i + length - 1 + best: tuple[CostExpr, int, Shape] | None = None + for k in range(i, j): + lc, _, ls = dp[(i, k)] + rc, _, rs = dp[(k + 1, j)] + step = _contract_cost(ls, rs) + total = lc + rc + step + result = _matmul_result_shape(ls, rs) + if best is None or _provably_less(total, best[0]): + best = (total, k, result) + + # Length >= 2 always has at least one split -- invariant violated. + assert best is not None, f"DP found no candidate for subchain ({i},{j})" + + dp[(i, j)] = best + + return dp[(0, n - 1)][0], dp + + +def _existing_cost( + consumed: list[Apply], + fgraph: FunctionGraph, + canonicalize: Callable[[Shape], Shape], +) -> CostExpr: + """Total FLOPs of the user's existing chain. + + Walks consumed matmul nodes in topological order (reversed insertion order, since + ``_decompose_operand`` adds the top first then descends). Each step looks up its + input shapes in the running ``var_shape`` table -- chain leaves take shapes from + ``_operand_shape_raw + canonicalize``; intermediate matmul outputs come from + ``_matmul_result_shape``. Lifted DimShuffles preserve FLOPs (they only touch + size-1 batch dims or swap core dims), so the canonicalized raw-shape sum + compares directly to ``_solve_chain``'s symbolic cost. 
+ """ + var_shape: dict[Variable, Shape] = {} + total = CostExpr.zero() + for node in reversed(consumed): + l_input, r_input = node.inputs + if l_input not in var_shape: + var_shape[l_input] = canonicalize(_operand_shape_raw(l_input, fgraph)) + if r_input not in var_shape: + var_shape[r_input] = canonicalize(_operand_shape_raw(r_input, fgraph)) + l_shape = var_shape[l_input] + r_shape = var_shape[r_input] + total = total + _contract_cost(l_shape, r_shape) + var_shape[node.outputs[0]] = _matmul_result_shape(l_shape, r_shape) + return total + + +# Mirrors the dtype gate inside Dot22.make_node (pytensor/tensor/blas.py). If that +# list grows (e.g., bfloat16) update both -- there's no canonical export to import. +_BLAS_DTYPES = ("float16", "float32", "float64", "complex64", "complex128") + + +def _select_emit_op(left: Variable, right: Variable) -> Variable: + """Pick the cheapest op equivalent to ``left @ right`` *without changing semantics*. + + ``Dot22`` handles 2-D float/complex pairs safely (no broadcasting possible). + ``BatchedDot`` does **not** handle broadcasting -- its ``perform``/C path errors + when ``x.shape[0] != y.shape[0]``. Emit it only when both static batch dims are + known and equal. Anything else falls through to ``matmul()``, which lowers to + ``Blockwise(Dot)`` and broadcasts correctly. 
+ """ + l_dt, r_dt = left.type.dtype, right.type.dtype + if l_dt != r_dt or l_dt not in _BLAS_DTYPES: + return matmul(left, right) # type: ignore[arg-type,no-any-return] + if left.type.ndim == right.type.ndim == 2: + return cast(Variable, Dot22()(left, right)) + if left.type.ndim == right.type.ndim == 3: + l_batch, r_batch = left.type.shape[0], right.type.shape[0] + if l_batch is not None and r_batch is not None and l_batch == r_batch: + return cast(Variable, BatchedDot()(left, right)) + return matmul(left, right) # type: ignore[arg-type,no-any-return] + + +def _build_tree( + operands: list[tuple[Variable, tuple[DimShuffle, ...]]], + dp: dict, + i_top: int, + j_top: int, +) -> Variable: + """Materialize the optimal matmul tree from the DP table. + + Walks the DP split tree in post-order using an explicit work stack -- same + reason ``_decompose_operand`` avoids recursion: deep chains can blow the + Python stack. + """ + materialized: dict[tuple[int, int], Variable] = {} + work: list[tuple[int, int, bool]] = [(i_top, j_top, False)] + + while work: + i, j, ready = work.pop() + if i == j: + var, lifts = operands[i] + for lift in reversed(lifts): + var = cast(Variable, lift(var)) + materialized[(i, j)] = var + continue + _, split, _ = dp[(i, j)] + if not ready: + work.append((i, j, True)) + work.append((split + 1, j, False)) + work.append((i, split, False)) + else: + left = materialized.pop((i, split)) + right = materialized.pop((split + 1, j)) + materialized[(i, j)] = _select_emit_op(left, right) + + return materialized[(i_top, j_top)] + + +class ReassociateMatmulChain(GraphRewriter): + """Re-associate matmul chains when a strictly cheaper order can be proven. + + Runs after BLAS (1.7) and specialize (2.0). For each maximal chain of matmul + links in the fgraph, runs a DP over symbolic operand shapes. Replaces the chain + only when the new total cost provably beats the user's existing parenthesization + (under "every dim symbol is positive"). 
+ + Decomposes liftable single-client ``DimShuffle(matmul(L, R))`` patterns into + ``DimShuffle(L) @ DimShuffle(R)`` (swapping operand order for matrix-transpose) + to expose longer chains. The lift commits atomically with the reassociation: + ``_build_tree`` constructs the lifted vars and runs only when ``_provably_less`` + fires, so unsuccessful attempts add no graph nodes. + + The pass walks a snapshot toposort once: replacements inside the loop create new + chain-link nodes that this pass does not re-examine. For matrix chain + reassociation a single global optimum doesn't get better by re-running; the + trade-off is that the pass skips lift-exposed extensions of an already-rewritten + chain. + """ + + def apply(self, fgraph: FunctionGraph) -> None: + # `visited` only tracks nodes consumed by *committed* rewrites. Toposort + # visits leaves before consumers, so a leaf later absorbed into a longer + # chain (via the consumer) must remain decomposable when the consumer is + # reached. Without this, processing `C @ D` first in `(A @ B) @ (C @ D)` + # would mark it visited and block the outer matmul from seeing the full + # 4-element chain. + visited: set[Apply] = set() + + for node in list(fgraph.toposort()): + if node in visited or not _is_chain_link(node): + continue + + top = _find_chain_top(node, fgraph) + if top in visited: + continue + + # Local snapshot so the recursive decomposition can avoid re-entering + # the same chain link within this attempt without polluting `visited` + # until we commit. 
+ local_visited = set(visited) + local_visited.add(top) + consumed: list[Apply] = [top] + try: + left_ops = _decompose_operand( + top.inputs[0], fgraph, local_visited, consumed + ) + right_ops = _decompose_operand( + top.inputs[1], fgraph, local_visited, consumed + ) + operands = [*left_ops, *right_ops] + + if len(operands) < 3: + continue + + op_shapes = [_operand_shape(op, fgraph) for op in operands] + if any(len(s) < 2 for s in op_shapes): + continue + + # Pre-collect raw input shapes of every consumed matmul so + # unification canonicalizes those symbols too. `_existing_cost` + # then uses the same canonical reps as the DP, so the comparison + # can see equalities through ShapeFeature symbols on either side + # of the chain. + raw_extras = [ + _operand_shape_raw(inp, fgraph) + for c in consumed + for inp in c.inputs + ] + + unified, canonicalize = _build_unification(op_shapes, raw_extras) + new_cost, dp = _solve_chain(unified) + old_cost = _existing_cost(consumed, fgraph, canonicalize) + except _BailOutError: + continue + + if not _provably_less(new_cost, old_cost): + continue + + visited.update(consumed) + + new_out = _build_tree(operands, dp, 0, len(operands) - 1) + old_out = top.outputs[0] + copy_stack_trace(old_out, new_out) + fgraph.replace(old_out, new_out, reason="reassociate_matmul") + + +reassociate_matmul_chain = ReassociateMatmulChain() +reassociate_matmul_chain.name = "reassociate_matmul_chain" + +optdb.register( + "reassociate_matmul_chain", + reassociate_matmul_chain, + "fast_run", + position=2.5, +) diff --git a/tests/tensor/rewriting/linalg/test_reassociate_matmul.py b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py new file mode 100644 index 0000000000..0b6c58e3f2 --- /dev/null +++ b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py @@ -0,0 +1,499 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import function +from pytensor.configdefaults import config +from pytensor.graph.fg import 
FunctionGraph +from pytensor.tensor.blas import BatchedDot, Dot22, Gemm +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import Dot +from pytensor.tensor.rewriting.linalg.reassociate_matmul import ( + CostExpr, + _provably_less, + reassociate_matmul_chain, +) + + +def _matmul_nodes(fgraph): + """Apply nodes that perform a matmul-like contraction.""" + nodes = [] + for n in fgraph.toposort(): + op = n.op + if isinstance(op, Dot | Dot22 | Gemm | BatchedDot): + nodes.append(n) + elif isinstance(op, Blockwise) and isinstance(op.core_op, Dot): + nodes.append(n) + return nodes + + +def _matmul_output_shapes(fgraph): + """Static shapes of every matmul output (last entry is the final output).""" + return [n.outputs[0].type.shape for n in _matmul_nodes(fgraph)] + + +def _dimshuffle_count(fgraph): + return sum(1 for n in fgraph.toposort() if isinstance(n.op, DimShuffle)) + + +class TestProvablyLess: + def test_int_dominance(self): + a = CostExpr.from_dim_product([10, 20, 5]) # 1000 + b = CostExpr.from_dim_product([10, 20, 5]) + CostExpr.from_dim_product( + [1, 2, 3] + ) # 1006 + assert _provably_less(a, b) + assert not _provably_less(b, a) + assert not _provably_less(a, a) + + def test_zero_vs_positive(self): + a = CostExpr.zero() + b = CostExpr.from_dim_product([7]) + assert _provably_less(a, b) + assert not _provably_less(b, a) + assert not _provably_less(a, a) + + def test_symbolic_no_proof(self): + # m*k*n vs k*n*p: different monomials, no dominance either way. + m = pt.scalar("m") + k = pt.scalar("k") + n = pt.scalar("n") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k, n]) + b = CostExpr.from_dim_product([k, n, p]) + assert not _provably_less(a, b) + assert not _provably_less(b, a) + + def test_extra_unmatched_b_monomial(self): + # a = m*k, b = m*k + p: b strictly larger under p >= 1. 
+ m = pt.scalar("m") + k = pt.scalar("k") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k]) + b = CostExpr.from_dim_product([m, k]) + CostExpr.from_dim_product([p]) + assert _provably_less(a, b) + assert not _provably_less(b, a) + + def test_b_dominates_with_extra_symbol(self): + # a = m*k, b = m*k*p: b is sound but not strict (p == 1 makes them equal). + m = pt.scalar("m") + k = pt.scalar("k") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k]) + b = CostExpr.from_dim_product([m, k, p]) + assert not _provably_less(a, b) + + def test_greedy_fails_kuhn_succeeds(self): + # Greedy first-fit pairs a[0]=x with b's first dominator (x*y), leaving + # a[1]=x*y with only x*z, which doesn't dominate. The valid matching is + # x <-> x*z and x*y <-> x*y; the unmatched `w` makes the strict comparison + # succeed. + x, y, z, w = pt.scalar("x"), pt.scalar("y"), pt.scalar("z"), pt.scalar("w") + a = CostExpr.from_dim_product([x]) + CostExpr.from_dim_product([x, y]) + b = ( + CostExpr.from_dim_product([x, y]) + + CostExpr.from_dim_product([x, z]) + + CostExpr.from_dim_product([w]) + ) + assert _provably_less(a, b) + + +class TestReassociateChain: + def test_optimal_static_ordering(self): + # (100x2) @ (2x100) @ (100x100): optimal is A @ (B @ C) with B@C producing + # a (2, 100) intermediate, far cheaper than (A @ B) @ C. 
+ A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + + f = function([A, B, C], A @ B @ C) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, ( + f"Expected the optimal split to produce a (2, 100) intermediate, " + f"got {shapes}" + ) + + def test_value_equivalence_4_matrices(self): + rng = np.random.default_rng(0) + shapes = [(10, 20), (20, 5), (5, 30), (30, 3)] + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in shapes] + pt_inputs = [pt.matrix(f"M{i}", shape=s) for i, s in enumerate(shapes)] + out = pt_inputs[0] @ pt_inputs[1] @ pt_inputs[2] @ pt_inputs[3] + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 + ) + # Optimal split for these shapes is A @ (B @ (C @ D)) with cost 1350 vs. + # naive 3400. Pin the smallest internal intermediate as a sentinel that the + # rewriter chose this tree. + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (5, 3) in shapes_seen, shapes_seen + + def test_symbolic_shapes_preserve_user_order(self): + # When all dims are symbolic, no parenthesization is provably cheaper, so + # the rewriter must leave the user's order intact. + A = pt.matrix("A") + B = pt.matrix("B") + C = pt.matrix("C") + out = A @ (B @ C) + f = function([A, B, C], out) + + nodes = _matmul_nodes(f.maker.fgraph) + assert len(nodes) == 2 + # The first matmul (in topo order) should be the BC contraction -- its + # inputs are exactly B and C, not the result of A @ anything. + first = nodes[0] + assert set(first.inputs) == {B, C}, ( + f"Expected first matmul to be B @ C, got inputs {first.inputs}" + ) + + def test_already_optimal_no_rewrite(self): + # User-written order is already optimal: (A @ B) -> (2, 100) @ C -> (2, 5) + # at cost 21_000 vs. the alternative A @ (B @ C) at 51_000. The rewriter + # must not disturb this.
+ A = pt.matrix("A", shape=(2, 100)) + B = pt.matrix("B", shape=(100, 100)) + C = pt.matrix("C", shape=(100, 5)) + f = function([A, B, C], (A @ B) @ C) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert shapes == [(2, 100), (2, 5)], ( + f"Expected the user's two-step order to be preserved exactly; got {shapes}" + ) + + def test_multi_client_breakage(self): + # If an intermediate matmul output has more than one client, the chain + # extender must stop there -- absorbing past it would duplicate the + # intermediate's computation when emitting the new tree. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + D = pt.matrix("D", shape=(100, 5)) + ab = A @ B # (100, 100), used twice + out1 = ab @ C @ D + out2 = ab.sum() + f = function([A, B, C, D], [out1, out2]) + rng = np.random.default_rng(1) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 100)).astype(config.floatX) + d_v = rng.normal(size=(100, 5)).astype(config.floatX) + r1, r2 = f(a_v, b_v, c_v, d_v) + ab_v = a_v @ b_v + np.testing.assert_allclose(r1, ab_v @ c_v @ d_v, rtol=1e-5) + np.testing.assert_allclose(r2, ab_v.sum(), rtol=1e-5) + + def test_quadratic_form_static(self): + # A @ B @ A.T with A: (50, 100), B: (100, 100). The transpose makes the + # third operand opaquely (100, 50). Both parenthesizations cost 750_000: + # (A @ B) @ A.T: 50*100*100 + 50*100*50 + # A @ (B @ A.T): 100*100*50 + 50*100*50 + # The rewriter has no provable win and must leave the user's order alone. 
+ A = pt.matrix("A", shape=(50, 100)) + B = pt.matrix("B", shape=(100, 100)) + f = function([A, B], A @ B @ A.T) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert shapes == [(50, 100), (50, 50)], ( + f"Expected user's symmetric chain to be preserved; got {shapes}" + ) + + def test_batched_blockwise_dot_reorders(self): + # Batched chain shaped to force a reorder: A(7,100,2) B(7,2,100) C(7,100,5). + # Naive (A@B)@C: 7*100*2*100 + 7*100*100*5 = 490k. + # Optimal A@(B@C): 7*2*100*5 + 7*100*2*5 = 14k. + rng = np.random.default_rng(2) + shapes = [(7, 100, 2), (7, 2, 100), (7, 100, 5)] + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in shapes] + pt_inputs = [pt.tensor3(f"B{i}", shape=s) for i, s in enumerate(shapes)] + out = pt_inputs[0] @ pt_inputs[1] @ pt_inputs[2] + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), + np_arrays[0] @ np_arrays[1] @ np_arrays[2], + rtol=1e-5, + ) + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (7, 2, 5) in shapes_seen, shapes_seen + + def test_dot22_chain_post_blas(self): + # By the time the rewriter runs (post-BLAS), 2-D dots may have been + # promoted to Dot22. The chain extender must treat Dot22 as a chain link + # so the chain spans the full sequence AND so the rebuilt tree retains the + # BLAS-promoted form. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + f = function([A, B, C], A @ B @ C) + nodes = _matmul_nodes(f.maker.fgraph) + ops_seen = {type(n.op).__name__ for n in nodes} + assert ops_seen & {"Dot22", "Gemm"}, ( + f"Expected BLAS-promoted nodes to survive reassociation; got {ops_seen}" + ) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, shapes + + @pytest.mark.parametrize( + "chain_shapes,must_contain", + [ + # Classic textbook 3-matrix: A(10x30) B(30x5) C(5x60). Optimal is + # (A@B)@C with intermediate (10, 5). 
+ ([(10, 30), (30, 5), (5, 60)], {(10, 5)}), + # 5-matrix CLRS-style chain. Optimal is (A(BC))(DE) -- intermediates + # are (35, 5), (30, 5), (5, 20), (30, 20). Two of those are the + # required-present sentinels. + ([(30, 35), (35, 15), (15, 5), (5, 10), (10, 20)], {(35, 5), (5, 20)}), + ], + ) + def test_textbook_examples(self, chain_shapes, must_contain): + rng = np.random.default_rng(3) + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in chain_shapes] + pt_inputs = [pt.matrix(f"M{i}", shape=s) for i, s in enumerate(chain_shapes)] + out = pt_inputs[0] + for x in pt_inputs[1:]: + out = out @ x + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 + ) + shapes = set(_matmul_output_shapes(f.maker.fgraph)) + missing = must_contain - shapes + assert not missing, f"Missing expected intermediates {missing}; got {shapes}" + + def test_balanced_tree_decomposition(self): + # User's explicit `(A @ B) @ (C @ D)` is a non-linear tree. The recursive + # decomposer must still see the full 4-operand chain through both subtrees. + # Shapes are chosen so the user's order has three (100, 100) contractions + # costing ~1M FLOPs and the optimal `A @ ((B @ C) @ D)` avoids them + # entirely (~21k FLOPs). The unavoidable (100, 100) is the final output; + # any *internal* (100, 100) means the rewriter missed the cross-subtree + # chain.
+ A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 3)) + D = pt.matrix("D", shape=(3, 100)) + f = function([A, B, C, D], (A @ B) @ (C @ D)) + + rng = np.random.default_rng(7) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 3)).astype(config.floatX) + d_v = rng.normal(size=(3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), (a_v @ b_v) @ (c_v @ d_v), rtol=1e-5 + ) + + shapes = _matmul_output_shapes(f.maker.fgraph) + n_big = sum(1 for s in shapes if s == (100, 100)) + assert n_big <= 1, ( + f"Expected at most one (100,100) (the final output); got {shapes}" + ) + + +class TestLifting: + def test_lift_expand_dims_recovers_chain(self): + # `expand_dims(A @ B, 0) @ C @ D` only sees three operands without lifting, + # and the (A @ B) intermediate is forced to (100, 100). After lifting the + # expand_dims past A @ B, the chain becomes 4-element and the DP can avoid + # the large intermediate by routing through the small m=2/k=3 dims. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.tensor3("C", shape=(1, 100, 3)) + D = pt.tensor3("D", shape=(1, 3, 100)) + ab3 = pt.expand_dims(A @ B, 0) + f = function([A, B, C, D], ab3 @ C @ D) + + rng = np.random.default_rng(11) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(1, 100, 3)).astype(config.floatX) + d_v = rng.normal(size=(1, 3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), (a_v @ b_v)[None] @ c_v @ d_v, rtol=1e-5 + ) + + # The unavoidable (100, 100) (or (1, 100, 100)) is the final output. Any + # internal contraction with that shape means the lift didn't fire. 
+ shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (1, 100, 100) not in non_final and (100, 100) not in non_final, ( + f"Lift+reorder should avoid the (100,100) intermediate; got {shapes}" + ) + + def test_lift_matrix_transpose(self): + # `(L @ R).T @ C` lifts to `R.T @ L.T @ C` (operand order swapped). Shapes + # are chosen so the lifted form's `R.T @ (L.T @ C)` is dramatically cheaper + # than the user's `(L @ R).T @ C` -- verifies that matrix-transpose lift + # exposes a real reorder, not just a no-op restructuring. + # L (100, 2), R (2, 50), C (100, 100): + # user: 100*2*50 + 50*100*100 = 510_000 + # lifted [R.T (50,2), L.T (2,100), C (100,100)] with R.T @ (L.T @ C): + # L.T @ C = 2*100*100 = 20_000; R.T @ that = 50*2*100 = 10_000. + L = pt.matrix("L", shape=(100, 2)) + R = pt.matrix("R", shape=(2, 50)) + C = pt.matrix("C", shape=(100, 100)) + f = function([L, R, C], (L @ R).T @ C) + + rng = np.random.default_rng(13) + l_v = rng.normal(size=(100, 2)).astype(config.floatX) + r_v = rng.normal(size=(2, 50)).astype(config.floatX) + c_v = rng.normal(size=(100, 100)).astype(config.floatX) + np.testing.assert_allclose(f(l_v, r_v, c_v), (l_v @ r_v).T @ c_v, rtol=1e-5) + + # Optimal lifted tree's intermediate is L.T @ C = (2, 100). + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, ( + f"Expected matrix-transpose lift to enable a (2, 100) intermediate; " + f"got {shapes}" + ) + + def test_lift_atomic_gating_no_extra_nodes(self): + # All-symbolic shapes -> no provable win -> no rewrite. The lift must NOT + # add extra DimShuffle nodes when the rewrite isn't committed. Test the + # rewriter standalone on a fresh fgraph (no other passes) so the assertion + # isolates its behavior. 
+ L = pt.matrix("L") + R = pt.matrix("R") + C = pt.matrix("C") + D = pt.matrix("D") + out = pt.expand_dims(L @ R, 0) @ pt.expand_dims(C, 0) @ pt.expand_dims(D, 0) + + fgraph = FunctionGraph([L, R, C, D], [out], clone=False) + before = _dimshuffle_count(fgraph) + reassociate_matmul_chain.apply(fgraph) + after = _dimshuffle_count(fgraph) + assert after == before, ( + f"Speculative lift leaked DimShuffle nodes: before={before} after={after}" + ) + + def test_lift_squeeze(self): + # `squeeze(A @ B, 0) @ C @ D` where (A@B) has a leading-1 batch dim. After + # the squeeze lift the chain spans 4 operands and the DP can route through + # the small k=2/k=3 dims rather than the (100, 100) intermediate the + # original chain forced. + A = pt.tensor3("A", shape=(1, 100, 2)) + B = pt.tensor3("B", shape=(1, 2, 100)) + C = pt.matrix("C", shape=(100, 3)) + D = pt.matrix("D", shape=(3, 100)) + ab2 = pt.squeeze(A @ B, axis=0) + f = function([A, B, C, D], ab2 @ C @ D) + + rng = np.random.default_rng(23) + a_v = rng.normal(size=(1, 100, 2)).astype(config.floatX) + b_v = rng.normal(size=(1, 2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 3)).astype(config.floatX) + d_v = rng.normal(size=(3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), + (a_v @ b_v).squeeze(axis=0) @ c_v @ d_v, + rtol=1e-5, + ) + # Final output is (100, 100); any internal one means the lift didn't fire. + shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (100, 100) not in non_final, ( + f"Squeeze lift should expose a chain that avoids the (100,100) " + f"intermediate; got {shapes}" + ) + + def test_lift_through_heterogeneous_ndim_does_not_crash(self): + # A DimShuffle wrapping a Blockwise(Dot) whose operands have different + # ndims (one 2-D, one 3-D) cannot have its lift propagated to both inner + # operands -- the lift's `new_order` references indices the 2-D operand + # doesn't have. 
The rewriter must bail and treat the wrapper as opaque + # rather than crashing or producing a malformed graph. + L = pt.matrix("L", shape=(4, 5)) + R = pt.tensor3("R", shape=(7, 5, 6)) # broadcasts L to (7, 4, 6) + C = pt.tensor3("C", shape=(7, 6, 8)) + ds_inner = pt.expand_dims(L @ R, 0) + f = function([L, R, C], ds_inner @ pt.expand_dims(C, 0)) + + rng = np.random.default_rng(29) + l_v = rng.normal(size=(4, 5)).astype(config.floatX) + r_v = rng.normal(size=(7, 5, 6)).astype(config.floatX) + c_v = rng.normal(size=(7, 6, 8)).astype(config.floatX) + np.testing.assert_allclose( + f(l_v, r_v, c_v), + (l_v @ r_v)[None] @ c_v[None], + rtol=1e-5, + ) + + def test_no_batched_dot_when_batch_dims_could_broadcast(self): + # BatchedDot does not broadcast -- its perform/C path errors at runtime + # when batch dims differ. If the chain has one operand with static batch + # dim 1 and another with non-1 batch dim, the rewriter must NOT emit + # BatchedDot for those pairs; matmul/Blockwise(Dot) is the only safe + # choice. + A = pt.tensor3("A", shape=(1, 4, 5)) + B = pt.tensor3("B", shape=(7, 5, 6)) + C = pt.tensor3("C", shape=(7, 6, 4)) + f = function([A, B, C], A @ B @ C) + + # Any BatchedDot in the rewritten graph must have both operands' static + # batch dim known and equal. Anything else is a latent runtime crash. 
+ for node in _matmul_nodes(f.maker.fgraph): + if isinstance(node.op, BatchedDot): + lb = node.inputs[0].type.shape[0] + rb = node.inputs[1].type.shape[0] + assert lb is not None and rb is not None and lb == rb, ( + f"BatchedDot with broadcasting-incompatible batch dims: " + f"left={node.inputs[0].type.shape}, " + f"right={node.inputs[1].type.shape}" + ) + + rng = np.random.default_rng(31) + a_v = rng.normal(size=(1, 4, 5)).astype(config.floatX) + b_v = rng.normal(size=(7, 5, 6)).astype(config.floatX) + c_v = rng.normal(size=(7, 6, 4)).astype(config.floatX) + np.testing.assert_allclose(f(a_v, b_v, c_v), a_v @ b_v @ c_v, rtol=1e-5) + + def test_stacked_dimshuffle_lift_does_not_crash(self): + # `DimShuffle(DimShuffle(matmul))`: only the outer DimShuffle satisfies the + # "single-client wrapping a chain-link matmul" pattern -- the inner + # DimShuffle is not a chain link itself. The lift attempt on the outer + # must bail (treat the whole stack as opaque) and produce a correct + # graph. + A = pt.matrix("A", shape=(2, 3)) + B = pt.matrix("B", shape=(3, 4)) + C = pt.tensor4("C", shape=(1, 1, 4, 5)) + D = pt.tensor4("D", shape=(1, 1, 5, 6)) + ab_stacked = pt.expand_dims(pt.expand_dims(A @ B, 0), 0) + f = function([A, B, C, D], ab_stacked @ C @ D) + + rng = np.random.default_rng(37) + a_v = rng.normal(size=(2, 3)).astype(config.floatX) + b_v = rng.normal(size=(3, 4)).astype(config.floatX) + c_v = rng.normal(size=(1, 1, 4, 5)).astype(config.floatX) + d_v = rng.normal(size=(1, 1, 5, 6)).astype(config.floatX) + ab_stacked_v = (a_v @ b_v)[None, None] + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), ab_stacked_v @ c_v @ d_v, rtol=1e-5 + ) + + def test_lift_blocked_by_multi_client_dimshuffle(self): + # If the DimShuffle has multiple clients, lifting would duplicate work for + # the OTHER clients. The rewriter must NOT lift in this case. 
+ A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.tensor3("C", shape=(1, 100, 3)) + D = pt.tensor3("D", shape=(1, 3, 100)) + ab3 = pt.expand_dims(A @ B, 0) + out1 = ab3 @ C @ D + out2 = ab3 + f = function([A, B, C, D], [out1, out2]) + rng = np.random.default_rng(19) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(1, 100, 3)).astype(config.floatX) + d_v = rng.normal(size=(1, 3, 100)).astype(config.floatX) + ab3_v = (a_v @ b_v)[None] + r1, r2 = f(a_v, b_v, c_v, d_v) + np.testing.assert_allclose(r1, ab3_v @ c_v @ d_v, rtol=1e-5) + np.testing.assert_allclose(r2, ab3_v, rtol=1e-5) + # The original (A @ B) must still be in the graph because ab3 is also + # consumed by out2; if we wrongly lifted the expand_dims, A @ B would be + # gone. + nodes = _matmul_nodes(f.maker.fgraph) + assert any(n.outputs[0].type.shape == (100, 100) for n in nodes), ( + f"Expected A@B (100,100) to remain because ab3 has multiple clients; " + f"got {[n.outputs[0].type.shape for n in nodes]}" + ) From e89740a32a3318d6d285544e0b9c93cdd8a2c883 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 2 May 2026 18:33:48 -0500 Subject: [PATCH 5/7] Skip rewrite tests in FAST_COMPILE mode --- .../linalg/test_reassociate_matmul.py | 87 +++++++++++-------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/tests/tensor/rewriting/linalg/test_reassociate_matmul.py b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py index 0b6c58e3f2..87e6cfa6db 100644 --- a/tests/tensor/rewriting/linalg/test_reassociate_matmul.py +++ b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py @@ -16,6 +16,12 @@ ) +requires_fast_run = pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="reassociate_matmul_chain is only registered for FAST_RUN", +) + + def _matmul_nodes(fgraph): """Apply nodes that perform a matmul-like contraction.""" nodes = [] @@ -100,6 +106,7 @@ def 
test_greedy_fails_kuhn_succeeds(self): class TestReassociateChain: + @requires_fast_run def test_optimal_static_ordering(self): # (100x2) @ (2x100) @ (100x100): optimal is A @ (B @ C) with B@C producing # a (2, 100) intermediate, far cheaper than (A @ B) @ C. @@ -124,11 +131,12 @@ def test_value_equivalence_4_matrices(self): np.testing.assert_allclose( f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 ) - # Optimal split for these shapes is A @ (B @ (C @ D)) with cost 1350 vs. - # naive 3400. Pin the smallest internal intermediate as a sentinel that the - # rewriter chose this tree. - shapes_seen = _matmul_output_shapes(f.maker.fgraph) - assert (5, 3) in shapes_seen, shapes_seen + if config.mode != "FAST_COMPILE": + # Optimal split for these shapes is A @ (B @ (C @ D)) with cost 1350 + # vs. naive 3400. Pin the smallest internal intermediate as a + # sentinel that the rewriter chose this tree. + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (5, 3) in shapes_seen, shapes_seen def test_symbolic_shapes_preserve_user_order(self): # When all dims are symbolic, no parenthesization is provably cheaper, so @@ -212,9 +220,11 @@ def test_batched_blockwise_dot_reorders(self): np_arrays[0] @ np_arrays[1] @ np_arrays[2], rtol=1e-5, ) - shapes_seen = _matmul_output_shapes(f.maker.fgraph) - assert (7, 2, 5) in shapes_seen, shapes_seen + if config.mode != "FAST_COMPILE": + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (7, 2, 5) in shapes_seen, shapes_seen + @requires_fast_run def test_dot22_chain_post_blas(self): # By the time the rewriter runs (post-BLAS), 2-D dots may have been # promoted to Dot22. 
The chain extender must treat Dot22 as a chain link @@ -255,9 +265,12 @@ def test_textbook_examples(self, chain_shapes, must_contain): np.testing.assert_allclose( f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 ) - shapes = set(_matmul_output_shapes(f.maker.fgraph)) - missing = must_contain - shapes - assert not missing, f"Missing expected intermediates {missing}; got {shapes}" + if config.mode != "FAST_COMPILE": + shapes = set(_matmul_output_shapes(f.maker.fgraph)) + missing = must_contain - shapes + assert not missing, ( + f"Missing expected intermediates {missing}; got {shapes}" + ) def test_balanced_tree_decomposition(self): # User's explicit `(A @ B) @ (C @ D)` is a non-linear tree. The recursive @@ -282,11 +295,12 @@ def test_balanced_tree_decomposition(self): f(a_v, b_v, c_v, d_v), (a_v @ b_v) @ (c_v @ d_v), rtol=1e-5 ) - shapes = _matmul_output_shapes(f.maker.fgraph) - n_big = sum(1 for s in shapes if s == (100, 100)) - assert n_big <= 1, ( - f"Expected at most one (100,100) (the final output); got {shapes}" - ) + if config.mode != "FAST_COMPILE": + shapes = _matmul_output_shapes(f.maker.fgraph) + n_big = sum(1 for s in shapes if s == (100, 100)) + assert n_big <= 1, ( + f"Expected at most one (100,100) (the final output); got {shapes}" + ) class TestLifting: @@ -311,13 +325,14 @@ def test_lift_expand_dims_recovers_chain(self): f(a_v, b_v, c_v, d_v), (a_v @ b_v)[None] @ c_v @ d_v, rtol=1e-5 ) - # The unavoidable (100, 100) (or (1, 100, 100)) is the final output. Any - # internal contraction with that shape means the lift didn't fire. - shapes = _matmul_output_shapes(f.maker.fgraph) - non_final = shapes[:-1] if shapes else [] - assert (1, 100, 100) not in non_final and (100, 100) not in non_final, ( - f"Lift+reorder should avoid the (100,100) intermediate; got {shapes}" - ) + if config.mode != "FAST_COMPILE": + # The unavoidable (100, 100) (or (1, 100, 100)) is the final output. 
+ # Any internal contraction with that shape means the lift didn't fire. + shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (1, 100, 100) not in non_final and (100, 100) not in non_final, ( + f"Lift+reorder should avoid the (100,100) intermediate; got {shapes}" + ) def test_lift_matrix_transpose(self): # `(L @ R).T @ C` lifts to `R.T @ L.T @ C` (operand order swapped). Shapes @@ -339,12 +354,13 @@ def test_lift_matrix_transpose(self): c_v = rng.normal(size=(100, 100)).astype(config.floatX) np.testing.assert_allclose(f(l_v, r_v, c_v), (l_v @ r_v).T @ c_v, rtol=1e-5) - # Optimal lifted tree's intermediate is L.T @ C = (2, 100). - shapes = _matmul_output_shapes(f.maker.fgraph) - assert (2, 100) in shapes, ( - f"Expected matrix-transpose lift to enable a (2, 100) intermediate; " - f"got {shapes}" - ) + if config.mode != "FAST_COMPILE": + # Optimal lifted tree's intermediate is L.T @ C = (2, 100). + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, ( + f"Expected matrix-transpose lift to enable a (2, 100) " + f"intermediate; got {shapes}" + ) def test_lift_atomic_gating_no_extra_nodes(self): # All-symbolic shapes -> no provable win -> no rewrite. The lift must NOT @@ -387,13 +403,14 @@ def test_lift_squeeze(self): (a_v @ b_v).squeeze(axis=0) @ c_v @ d_v, rtol=1e-5, ) - # Final output is (100, 100); any internal one means the lift didn't fire. - shapes = _matmul_output_shapes(f.maker.fgraph) - non_final = shapes[:-1] if shapes else [] - assert (100, 100) not in non_final, ( - f"Squeeze lift should expose a chain that avoids the (100,100) " - f"intermediate; got {shapes}" - ) + if config.mode != "FAST_COMPILE": + # Final output is (100, 100); any internal one means the lift didn't fire. 
+ shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (100, 100) not in non_final, ( + f"Squeeze lift should expose a chain that avoids the (100,100) " + f"intermediate; got {shapes}" + ) def test_lift_through_heterogeneous_ndim_does_not_crash(self): # A DimShuffle wrapping a Blockwise(Dot) whose operands have different From 9d1a298806debc85816cb523b6b76302b90f37fd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 5 May 2026 22:39:26 -0500 Subject: [PATCH 6/7] cleanup --- .../rewriting/linalg/reassociate_matmul.py | 216 +++++++----------- 1 file changed, 87 insertions(+), 129 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg/reassociate_matmul.py b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py index 86cdf45649..6ae4ae2d8a 100644 --- a/pytensor/tensor/rewriting/linalg/reassociate_matmul.py +++ b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py @@ -21,21 +21,6 @@ Shape = tuple[DimEntry, ...] -def _is_one(d: DimEntry) -> bool: - return isinstance(d, int) and d == 1 - - -def _sym_sort_key(sym: Variable) -> tuple[str, int]: - """Stable sort key for symbol ordering inside a monomial. Prefers `name` so the - same dim across runs sorts deterministically; falls back to id only as a tiebreak - between unnamed symbols within one process.""" - return (getattr(sym, "name", None) or "", id(sym)) - - -class _BailOutError(Exception): - """Internal signal: the rewriter should skip this chain (not raise to the user).""" - - class CostExpr: """Polynomial in positive dim symbols. @@ -59,7 +44,12 @@ def zero(cls) -> "CostExpr": @classmethod def from_dim_product(cls, dims: Sequence[DimEntry]) -> "CostExpr": - """Build a single monomial from the product of `dims`.""" + """Build a single monomial from the product of `dims`. 
+ + Symbol ordering within the monomial sorts on ``name`` first so the same dim + across runs sorts deterministically, falling back to ``id`` only as a tiebreak + between unnamed symbols within one process. + """ coef = 1 sym_exps: dict[Variable, int] = defaultdict(int) for d in dims: @@ -67,7 +57,12 @@ def from_dim_product(cls, dims: Sequence[DimEntry]) -> "CostExpr": coef *= d else: sym_exps[d] += 1 - key = tuple(sorted(sym_exps.items(), key=lambda kv: _sym_sort_key(kv[0]))) + key = tuple( + sorted( + sym_exps.items(), + key=lambda kv: (getattr(kv[0], "name", None) or "", id(kv[0])), + ) + ) return cls({key: coef}) def __add__(self, other: "CostExpr") -> "CostExpr": @@ -161,41 +156,17 @@ def _operand_shape_raw(var: Variable, fgraph: FunctionGraph) -> Shape: return tuple(int(s) if s is not None else symbolic[i] for i, s in enumerate(static)) -def _apply_lift_to_shape(lift: DimShuffle, shape: Shape) -> Shape: - """Apply a DimShuffle to a shape tuple. ``'x'`` becomes literal 1; ints index `shape`. - - Raises ``_BailOutError`` when `lift` references an input dim outside `shape`. This - happens when the chain extender constructed `lift` against a wider matmul output - (e.g., a Blockwise with broadcasting between heterogeneous-ndim operands) and the - decomposition is now propagating it down to a narrower operand it can't legally - apply to. - """ - out: list[DimEntry] = [] - for x in lift.new_order: - if x == "x": - out.append(1) - continue - if not (0 <= x < len(shape)): - raise _BailOutError( - f"DimShuffle.new_order references index {x} outside operand shape " - f"of length {len(shape)}; lift cannot legally apply." 
- ) - out.append(shape[x]) - return tuple(out) - - -def _broadcast_batch(left_batch: Shape, right_batch: Shape) -> Shape: - """Right-align two batch tuples and broadcast, preferring the non-literal-1 side.""" +def _matmul_result_shape(left: Shape, right: Shape) -> Shape: + """Shape of ``left @ right``: right-align the leading batch tuples, broadcast + them (preferring the non-literal-1 side), then append ``(m, n)`` from + ``left[-2]`` and ``right[-1]``.""" + left_batch, right_batch = left[:-2], right[:-2] n = max(len(left_batch), len(right_batch)) pad_l = (1,) * (n - len(left_batch)) + tuple(left_batch) pad_r = (1,) * (n - len(right_batch)) + tuple(right_batch) - return tuple(b if _is_one(a) else a for a, b in zip(pad_l, pad_r)) - - -def _matmul_result_shape(left: Shape, right: Shape) -> Shape: - """Shape of ``left @ right``: ``(*broadcast_batch, m, n)`` where left ends in - ``(m, k)`` and right ends in ``(k, n)``.""" - batch = _broadcast_batch(left[:-2], right[:-2]) + batch = tuple( + b if (isinstance(a, int) and a == 1) else a for a, b in zip(pad_l, pad_r) + ) return (*batch, left[-2], right[-1]) @@ -300,16 +271,15 @@ def _decompose_operand( Two descent paths grow the chain: - - Single-client chain-link matmul: descend into both inputs. + - Single-client chain-link matmul: descend into both inputs. When the parent + carries inherited lifts, both children must share the parent's output ndim -- + otherwise an inherited lift would reference indices missing on a narrower + operand (``Blockwise(Dot)`` broadcasts heterogeneous-ndim operands). - Single-client liftable DimShuffle wrapping a single-client chain-link matmul: descend into the inner matmul's two inputs with the DimShuffle prepended to the inherited-lift list. For matrix-transpose lifts, swapping operand order: - (``(L @ R).T = R.T @ L.T``). 
- - A lift only descends when both inner-matmul operands have ndim equal to the - DimShuffle's ``input_ndim``; this guards against propagating a lift into a - Blockwise(Dot) where the operands have heterogeneous ndim (the lift's - ``new_order`` would reference indices that don't exist on the narrower operand). + (``(L @ R).T = R.T @ L.T``). Both inner-matmul operands must have ndim equal + to the DimShuffle's ``input_ndim`` for the same reason. Append each chain-link Apply we descend into to `consumed` and add to `visited`. """ @@ -353,6 +323,9 @@ def _decompose_operand( _is_chain_link(owner) and owner not in visited and len(fgraph.clients[var]) == 1 + and ( + not lifts or all(inp.type.ndim == var.type.ndim for inp in owner.inputs) + ) ): visited.add(owner) consumed.append(owner) @@ -365,19 +338,6 @@ def _decompose_operand( return out -def _operand_shape( - operand: tuple[Variable, tuple[DimShuffle, ...]], fgraph: FunctionGraph -) -> Shape: - """Compute a chain operand's shape after applying its pending lifts.""" - base, lifts = operand - shape = _operand_shape_raw(base, fgraph) - # `lifts` is outermost-first; apply outermost-last so the innermost transformation - # touches the raw shape first. - for lift in reversed(lifts): - shape = _apply_lift_to_shape(lift, shape) - return shape - - def _build_unification( chain_shapes: list[Shape], extra_shapes: Sequence[Shape] = () ) -> tuple[list[Shape], Callable[[Shape], Shape]]: @@ -392,11 +352,12 @@ def _build_unification( operand must agree at runtime (broadcasting requires it); unioning them as one class catches transitive equalities a 1 in the middle would otherwise mask. - A literal-int conflict (``ra != rb`` for two ints in the same class) signals an - inconsistent input graph -- raise ``_BailOutError`` so the caller skips the - rewrite rather than aborting compilation. 
The unification also seeds `parent` - with `extra_shapes` so the caller can canonicalize shapes outside the chain - (e.g., raw inputs of consumed inner matmuls). + The unification also seeds `parent` with `extra_shapes` so the caller can + canonicalize shapes outside the chain (e.g., raw inputs of consumed inner + matmuls). Two ints unifying to different values would mean the input graph is + ill-formed (matmul construction rejects mismatched contract dims and + non-broadcastable batch dims), so the union prefers the int representative + without checking for conflict. """ parent: dict[DimEntry, DimEntry] = {} for shape in (*chain_shapes, *extra_shapes): @@ -413,13 +374,7 @@ def union(a: DimEntry, b: DimEntry) -> None: ra, rb = find(a), find(b) if ra == rb: return - if isinstance(ra, int) and isinstance(rb, int): - if ra != rb: - raise _BailOutError( - f"Conflicting static dims in matmul chain: {ra} != {rb}." - ) - parent[rb] = ra - elif isinstance(ra, int): + if isinstance(ra, int): parent[rb] = ra elif isinstance(rb, int): parent[ra] = rb @@ -440,7 +395,7 @@ def union(a: DimEntry, b: DimEntry) -> None: if pos >= n_batch: continue d = s[n_batch - 1 - pos] - if _is_one(d): + if isinstance(d, int) and d == 1: continue if anchor is None: anchor = d @@ -504,13 +459,12 @@ def _existing_cost( ) -> CostExpr: """Total FLOPs of the user's existing chain. - Walks consumed matmul nodes in topological order (reversed insertion order, since - ``_decompose_operand`` adds the top first then descends). Each step looks up its - input shapes in the running ``var_shape`` table -- chain leaves take shapes from - ``_operand_shape_raw + canonicalize``; intermediate matmul outputs come from - ``_matmul_result_shape``. Lifted DimShuffles preserve FLOPs (they only touch - size-1 batch dims or swap core dims), so the canonicalized raw-shape sum - compares directly to ``_solve_chain``'s symbolic cost. + Walks consumed matmul nodes in topological order (reversed insertion order). 
Each step looks up its input shapes + in the running ``var_shape`` table. The chain leaves take shapes from ``_operand_shape_raw + canonicalize`` while + intermediate matmul outputs come from ``_matmul_result_shape``. + + Lifted DimShuffles preserve FLOPs (they only touch size-1 batch dims or swap core dims), so the canonicalized + raw-shape sum compares directly to ``_solve_chain``'s symbolic cost. """ var_shape: dict[Variable, Shape] = {} total = CostExpr.zero() @@ -533,13 +487,14 @@ def _existing_cost( def _select_emit_op(left: Variable, right: Variable) -> Variable: - """Pick the cheapest op equivalent to ``left @ right`` *without changing semantics*. + """Emit ``left @ right`` via the cheapest semantically-equivalent op. + + Routing: - ``Dot22`` handles 2-D float/complex pairs safely (no broadcasting possible). - ``BatchedDot`` does **not** handle broadcasting -- its ``perform``/C path errors - when ``x.shape[0] != y.shape[0]``. Emit it only when both static batch dims are - known and equal. Anything else falls through to ``matmul()``, which lowers to - ``Blockwise(Dot)`` and broadcasts correctly. + - 2-D float/complex pair: ``Dot22``. + - 3-D float/complex pair whose batch dims share static broadcastability (both statically ``1``, or neither + statically ``1``): ``BatchedDot``. + - Anything else: ``matmul()``, which lowers to ``Blockwise(Dot)``. 
""" l_dt, r_dt = left.type.dtype, right.type.dtype if l_dt != r_dt or l_dt not in _BLAS_DTYPES: @@ -547,8 +502,7 @@ def _select_emit_op(left: Variable, right: Variable) -> Variable: if left.type.ndim == right.type.ndim == 2: return cast(Variable, Dot22()(left, right)) if left.type.ndim == right.type.ndim == 3: - l_batch, r_batch = left.type.shape[0], right.type.shape[0] - if l_batch is not None and r_batch is not None and l_batch == r_batch: + if (left.type.shape[0] == 1) == (right.type.shape[0] == 1): return cast(Variable, BatchedDot()(left, right)) return matmul(left, right) # type: ignore[arg-type,no-any-return] @@ -561,9 +515,8 @@ def _build_tree( ) -> Variable: """Materialize the optimal matmul tree from the DP table. - Walks the DP split tree in post-order using an explicit work stack -- same - reason ``_decompose_operand`` avoids recursion: deep chains can blow the - Python stack. + Walks the DP split tree in post-order using an explicit work stack. Like ``_decompose_operand``, this function + avoids recursion because deep chains can blow the Python stack. 
""" materialized: dict[tuple[int, int], Variable] = {} work: list[tuple[int, int, bool]] = [(i_top, j_top, False)] @@ -633,39 +586,44 @@ def apply(self, fgraph: FunctionGraph) -> None: local_visited = set(visited) local_visited.add(top) consumed: list[Apply] = [top] - try: - left_ops = _decompose_operand( - top.inputs[0], fgraph, local_visited, consumed - ) - right_ops = _decompose_operand( - top.inputs[1], fgraph, local_visited, consumed - ) - operands = [*left_ops, *right_ops] - - if len(operands) < 3: - continue + left_ops = _decompose_operand( + top.inputs[0], fgraph, local_visited, consumed + ) + right_ops = _decompose_operand( + top.inputs[1], fgraph, local_visited, consumed + ) + operands = [*left_ops, *right_ops] - op_shapes = [_operand_shape(op, fgraph) for op in operands] - if any(len(s) < 2 for s in op_shapes): - continue + if len(operands) < 3: + continue - # Pre-collect raw input shapes of every consumed matmul so - # unification canonicalizes those symbols too. `_existing_cost` - # then uses the same canonical reps as the DP, so the comparison - # can see equalities through ShapeFeature symbols on either side - # of the chain. - raw_extras = [ - _operand_shape_raw(inp, fgraph) - for c in consumed - for inp in c.inputs - ] - - unified, canonicalize = _build_unification(op_shapes, raw_extras) - new_cost, dp = _solve_chain(unified) - old_cost = _existing_cost(consumed, fgraph, canonicalize) - except _BailOutError: + # Compute each operand's shape after applying its pending lifts. `lifts` + # is outermost-first; apply in reverse so the innermost transformation + # touches the raw shape first. ``_decompose_operand`` only propagates a + # lift through a chain-link whose operand ndim matches the lift's input + # ndim, so indexing into ``shape`` by ``lift.new_order`` is in-bounds. 
+ op_shapes: list[Shape] = [] + for base, lifts in operands: + shape: Shape = _operand_shape_raw(base, fgraph) + for lift in reversed(lifts): + shape = tuple(1 if x == "x" else shape[x] for x in lift.new_order) + op_shapes.append(shape) + + if any(len(s) < 2 for s in op_shapes): continue + # Pre-collect raw input shapes of every consumed matmul so unification + # canonicalizes those symbols too. `_existing_cost` then uses the same + # canonical reps as the DP, so the comparison can see equalities through + # ShapeFeature symbols on either side of the chain. + raw_extras = [ + _operand_shape_raw(inp, fgraph) for c in consumed for inp in c.inputs + ] + + unified, canonicalize = _build_unification(op_shapes, raw_extras) + new_cost, dp = _solve_chain(unified) + old_cost = _existing_cost(consumed, fgraph, canonicalize) + if not _provably_less(new_cost, old_cost): continue From a289913c25e40aae4e8bf662729d76c5a9597e74 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 5 May 2026 22:44:24 -0500 Subject: [PATCH 7/7] deprecate matrix_dot --- pytensor/tensor/linalg/products.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 23c773bff3..aec9f5230b 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -1,3 +1,5 @@ +import warnings + import scipy.linalg as scipy_linalg from pytensor import tensor as pt @@ -126,6 +128,12 @@ def matrix_dot(*args): :math:`A_0 \cdot A_1 \cdot A_2 \cdot .. \cdot A_N`. """ + warnings.warn( + "matrix_dot is deprecated and will be removed in future version.", + DeprecationWarning, + stacklevel=2, + ) + rval = args[0] for a in args[1:]: rval = ptm.matmul(rval, a)