diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 1dd41fe51b..aec9f5230b 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -1,3 +1,5 @@ +import warnings + import scipy.linalg as scipy_linalg from pytensor import tensor as pt @@ -126,9 +128,15 @@ def matrix_dot(*args): :math:`A_0 \cdot A_1 \cdot A_2 \cdot .. \cdot A_N`. """ + warnings.warn( + "matrix_dot is deprecated and will be removed in future version.", + DeprecationWarning, + stacklevel=2, + ) + rval = args[0] for a in args[1:]: - rval = ptm.dot(rval, a) + rval = ptm.matmul(rval, a) return rval diff --git a/pytensor/tensor/rewriting/linalg/__init__.py b/pytensor/tensor/rewriting/linalg/__init__.py index 7e9d7b3393..3fbe4720e7 100644 --- a/pytensor/tensor/rewriting/linalg/__init__.py +++ b/pytensor/tensor/rewriting/linalg/__init__.py @@ -6,6 +6,7 @@ import pytensor.tensor.rewriting.linalg.decomposition import pytensor.tensor.rewriting.linalg.inverse import pytensor.tensor.rewriting.linalg.products +import pytensor.tensor.rewriting.linalg.reassociate_matmul import pytensor.tensor.rewriting.linalg.solvers import pytensor.tensor.rewriting.linalg.summary import pytensor.tensor.rewriting.linalg.utils diff --git a/pytensor/tensor/rewriting/linalg/products.py b/pytensor/tensor/rewriting/linalg/products.py index 4a76658887..1f69da680a 100644 --- a/pytensor/tensor/rewriting/linalg/products.py +++ b/pytensor/tensor/rewriting/linalg/products.py @@ -1,7 +1,4 @@ -from pytensor.graph.rewriting.basic import ( - copy_stack_trace, - node_rewriter, -) +from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.tensor.basic import ExtractDiag, concatenate, diag from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg.constructors import BlockDiagonal diff --git a/pytensor/tensor/rewriting/linalg/reassociate_matmul.py b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py new file mode 100644 index 0000000000..6ae4ae2d8a --- /dev/null +++ b/pytensor/tensor/rewriting/linalg/reassociate_matmul.py @@ -0,0 +1,646 @@ +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import cast + +from pytensor.compile.mode import optdb +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import GraphRewriter, copy_stack_trace +from pytensor.tensor.blas import BatchedDot, Dot22 +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import Dot, matmul + + +# A "dim entry" is an int (statically known size) or a Variable (a scalar shape var +# from ShapeFeature, treated as an opaque positive symbol >= 1). A CostExpr is a sum +# of monomials in those symbols, with literal-int factors folded into each monomial's +# coefficient. + +DimEntry = int | Variable +Shape = tuple[DimEntry, ...] + + +class CostExpr: + """Polynomial in positive dim symbols. + + Stores its monomials as ``{sorted_(symbol,exp)_tuple: int_coef}``. The FLOP cost + of a matmul is (approximately) the product of all unique dim lengths. Given + inputs A of shape ``(a, b)`` and B of shape ``(b, c)``, computing ``A @ B`` + costs ``a * b * c``. We call such a product a "monomial". A chain of matmuls + sums these monomials to give the total cost. + """ + + __slots__ = ("monomials",) + + def __init__(self, monomials=None): + self.monomials: dict[tuple, int] = ( + dict(monomials) if monomials is not None else {} + ) + + @classmethod + def zero(cls) -> "CostExpr": + return cls() + + @classmethod + def from_dim_product(cls, dims: Sequence[DimEntry]) -> "CostExpr": + """Build a single monomial from the product of `dims`. + + Symbol ordering within the monomial sorts on ``name`` first so the same dim + across runs sorts deterministically, falling back to ``id`` only as a tiebreak + between unnamed symbols within one process. + """ + coef = 1 + sym_exps: dict[Variable, int] = defaultdict(int) + for d in dims: + if isinstance(d, int): + coef *= d + else: + sym_exps[d] += 1 + key = tuple( + sorted( + sym_exps.items(), + key=lambda kv: (getattr(kv[0], "name", None) or "", id(kv[0])), + ) + ) + return cls({key: coef}) + + def __add__(self, other: "CostExpr") -> "CostExpr": + result = dict(self.monomials) + for k, v in other.monomials.items(): + result[k] = result.get(k, 0) + v + if result[k] == 0: + del result[k] + return CostExpr(result) + + def __repr__(self) -> str: + if not self.monomials: + return "CostExpr(0)" + terms = [] + for sym_exps, coef in self.monomials.items(): + parts = [str(coef)] if coef != 1 or not sym_exps else [] + for sym, exp in sym_exps: + name = getattr(sym, "name", None) or "" + parts.append(f"{name}^{exp}" if exp != 1 else name) + terms.append("*".join(parts) if parts else "1") + return "CostExpr(" + " + ".join(terms) + ")" + + +def _provably_less(a: CostExpr, b: CostExpr) -> bool: + """True iff a < b is provable assuming every dim symbol is >= 1. + + Match each a-monomial to a distinct b-monomial that dominates it (b's coef >= a's, + b's exponents >= a's). A complete matching is sound (a <= b); strict (a < b) iff + sum(b.coefs) > sum(a.coefs). Matching uses Kuhn's algorithm. Returns False for + "not provable"; never claims b <= a. + """ + if not a.monomials: + return any(c > 0 for c in b.monomials.values()) + if not b.monomials: + return False + + a_terms = list(a.monomials.items()) + b_terms = list(b.monomials.items()) + n_a = len(a_terms) + n_b = len(b_terms) + if n_b < n_a: + return False + + # For each a-index, list the b-indices that can cover it. + candidates: list[list[int]] = [] + for a_key, a_coef in a_terms: + a_exp = dict(a_key) + row = [] + for b_idx, (b_key, b_coef) in enumerate(b_terms): + if b_coef < a_coef: + continue + b_exp = dict(b_key) + if all(b_exp.get(s, 0) >= e for s, e in a_exp.items()): + row.append(b_idx) + if not row: + return False + candidates.append(row) + + # Kuhn's algorithm: for each a-index, try to find an augmenting path. + matched_b_to_a = [-1] * n_b + + def try_assign(a_idx: int, seen: list[bool]) -> bool: + for b_idx in candidates[a_idx]: + if seen[b_idx]: + continue + seen[b_idx] = True + if matched_b_to_a[b_idx] == -1 or try_assign(matched_b_to_a[b_idx], seen): + matched_b_to_a[b_idx] = a_idx + return True + return False + + for a_idx in range(n_a): + if not try_assign(a_idx, [False] * n_b): + return False + + return sum(b.monomials.values()) > sum(a.monomials.values()) + + +def _operand_shape_raw(var: Variable, fgraph: FunctionGraph) -> Shape: + """Raw shape of `var` as a tuple of dim entries. + + Each entry is an int when statically known and a scalar Variable from + ShapeFeature otherwise. Falls back to ``var.shape[i]`` when ShapeFeature is absent. + """ + static = var.type.shape + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is not None and var in shape_feature.shape_of: + symbolic = shape_feature.shape_of[var] + else: + symbolic = tuple(var.shape[i] for i in range(var.type.ndim)) # type: ignore[attr-defined] + return tuple(int(s) if s is not None else symbolic[i] for i, s in enumerate(static)) + + +def _matmul_result_shape(left: Shape, right: Shape) -> Shape: + """Shape of ``left @ right``: right-align the leading batch tuples, broadcast + them (preferring the non-literal-1 side), then append ``(m, n)`` from + ``left[-2]`` and ``right[-1]``.""" + left_batch, right_batch = left[:-2], right[:-2] + n = max(len(left_batch), len(right_batch)) + pad_l = (1,) * (n - len(left_batch)) + tuple(left_batch) + pad_r = (1,) * (n - len(right_batch)) + tuple(right_batch) + batch = tuple( + b if (isinstance(a, int) and a == 1) else a for a, b in zip(pad_l, pad_r) + ) + return (*batch, left[-2], right[-1]) + + +def _contract_cost(left: Shape, right: Shape) -> CostExpr: + """FLOPs of ``left @ right``: ``prod(broadcast_batch) * m * k * n``.""" + result = _matmul_result_shape(left, right) + return CostExpr.from_dim_product([*result[:-2], left[-2], left[-1], right[-1]]) + + +def _classify_dimshuffle_lift(op: DimShuffle, input_ndim: int) -> tuple[bool, bool]: + """Classify a DimShuffle for matmul-lift purposes. + + A DimShuffle commutes with matmul iff it touches only batch dimensions (the + leading ``input_ndim - 2``), or it performs a matrix-transpose (swap of the last + two dims) -- possibly combined with batch-only operations. Returns + ``(is_liftable, swaps_order)``, where ``swaps_order`` is True iff the lift swaps + operand order in the matmul: ``(L @ R).T = R.T @ L.T``. + + DimShuffle disallows duplicate input indices in `new_order`, so once the core dim + indices appear in the last two positions of the output they cannot appear + elsewhere -- earlier positions are batch-only. + """ + if input_ndim < 2: + return False, False + new_order = op.new_order + if len(new_order) < 2: + return False, False + last_two = (new_order[-2], new_order[-1]) + if last_two == (input_ndim - 2, input_ndim - 1): + return True, False + if last_two == (input_ndim - 1, input_ndim - 2): + return True, True + return False, False + + +def _is_chain_link(node: Apply) -> bool: + """True iff `node` is a 2-operand matmul we can rebuild via ``matmul()``.""" + op = node.op + if isinstance(op, Dot | Dot22 | BatchedDot): + return True + if isinstance(op, Blockwise) and isinstance(op.core_op, Dot): + return True + return False + + +def _find_chain_top(start: Apply, fgraph: FunctionGraph) -> Apply: + """Walk up to the topmost chain-link consumer along single-client edges. + + Follows single-client edges where the current output feeds the consumer's *left* + input, and walks through a single-client liftable DimShuffle when its output then + feeds a chain-link's left input. + """ + current = start + while True: + clients = fgraph.clients[current.outputs[0]] + if len(clients) != 1: + break + client_node, client_idx = clients[0] + # `client_node` may be the literal "output" sentinel for fgraph outputs; + # the type stub is too narrow to express this. + if not isinstance(client_node, Apply): + break # type: ignore[unreachable] + + if _is_chain_link(client_node) and client_idx == 0: + current = client_node + continue + + if ( + isinstance(client_node.op, DimShuffle) + and client_idx == 0 + and len(fgraph.clients[client_node.outputs[0]]) == 1 + ): + ds_consumer, ds_consumer_idx = fgraph.clients[client_node.outputs[0]][0] + if ( + isinstance(ds_consumer, Apply) + and _is_chain_link(ds_consumer) + and ds_consumer_idx == 0 + ): + liftable, _ = _classify_dimshuffle_lift( + client_node.op, current.outputs[0].type.ndim + ) + if liftable: + current = ds_consumer + continue + + break + return current + + +def _decompose_operand( + root: Variable, + fgraph: FunctionGraph, + visited: set[Apply], + consumed: list[Apply], +) -> list[tuple[Variable, tuple[DimShuffle, ...]]]: + """Iteratively decompose `root` into chain leaves. + + Each leaf is ``(base_var, lifts)`` where `lifts` is a tuple of DimShuffle Ops to + apply at materialization time (outermost first). Replace recursion with an + explicit work stack to keep the Python stack bounded for deep matmul chains + (autodiff-generated graphs can produce hundreds of links). + + Two descent paths grow the chain: + + - Single-client chain-link matmul: descend into both inputs. When the parent + carries inherited lifts, both children must share the parent's output ndim -- + otherwise an inherited lift would reference indices missing on a narrower + operand (``Blockwise(Dot)`` broadcasts heterogeneous-ndim operands). + - Single-client liftable DimShuffle wrapping a single-client chain-link matmul: + descend into the inner matmul's two inputs with the DimShuffle prepended to the + inherited-lift list. For matrix-transpose lifts, swapping operand order: + (``(L @ R).T = R.T @ L.T``). Both inner-matmul operands must have ndim equal + to the DimShuffle's ``input_ndim`` for the same reason. + + Append each chain-link Apply we descend into to `consumed` and add to `visited`. + """ + out: list[tuple[Variable, tuple[DimShuffle, ...]]] = [] + # Push children right-then-left so the leftmost leaf comes out first (preserving + # the natural left-to-right operand order). + stack: list[tuple[Variable, tuple[DimShuffle, ...]]] = [(root, ())] + + while stack: + var, lifts = stack.pop() + owner = var.owner + if owner is None: + out.append((var, lifts)) + continue + + if isinstance(owner.op, DimShuffle) and len(fgraph.clients[var]) == 1: + ds_input = owner.inputs[0] + inner = ds_input.owner + if ( + inner is not None + and _is_chain_link(inner) + and inner not in visited + and len(fgraph.clients[ds_input]) == 1 + ): + ds_op = owner.op + liftable, swaps = _classify_dimshuffle_lift(ds_op, ds_input.type.ndim) + if liftable and all( + inp.type.ndim == ds_op.input_ndim for inp in inner.inputs + ): + visited.add(inner) + consumed.append(inner) + new_lifts = (ds_op, *lifts) + left, right = inner.inputs + if swaps: + left, right = right, left + stack.append((right, new_lifts)) + stack.append((left, new_lifts)) + continue + + if ( + _is_chain_link(owner) + and owner not in visited + and len(fgraph.clients[var]) == 1 + and ( + not lifts or all(inp.type.ndim == var.type.ndim for inp in owner.inputs) + ) + ): + visited.add(owner) + consumed.append(owner) + stack.append((owner.inputs[1], lifts)) + stack.append((owner.inputs[0], lifts)) + continue + + out.append((var, lifts)) + + return out + + +def _build_unification( + chain_shapes: list[Shape], extra_shapes: Sequence[Shape] = () +) -> tuple[list[Shape], Callable[[Shape], Shape]]: + """Canonicalize dim entries across chain operand shapes via union-find. + + Returns ``(unified_chain_shapes, canonicalize)`` where ``canonicalize(shape)`` maps + any Shape (over the same dim entries) to its canonical representatives. Two + equality sources drive the union-find: + + - Adjacent contracting dims of the chain. + - For each right-aligned batch position, all non-literal-1 entries across *every* + operand must agree at runtime (broadcasting requires it); unioning them as one + class catches transitive equalities a 1 in the middle would otherwise mask. + + The unification also seeds `parent` with `extra_shapes` so the caller can + canonicalize shapes outside the chain (e.g., raw inputs of consumed inner + matmuls). Two ints unifying to different values would mean the input graph is + ill-formed (matmul construction rejects mismatched contract dims and + non-broadcastable batch dims), so the union prefers the int representative + without checking for conflict. + """ + parent: dict[DimEntry, DimEntry] = {} + for shape in (*chain_shapes, *extra_shapes): + for d in shape: + parent.setdefault(d, d) + + def find(k: DimEntry) -> DimEntry: + while parent[k] != k: + parent[k] = parent[parent[k]] + k = parent[k] + return k + + def union(a: DimEntry, b: DimEntry) -> None: + ra, rb = find(a), find(b) + if ra == rb: + return + if isinstance(ra, int): + parent[rb] = ra + elif isinstance(rb, int): + parent[ra] = rb + else: + parent[rb] = ra + + for i in range(len(chain_shapes) - 1): + union(chain_shapes[i][-1], chain_shapes[i + 1][-2]) + + # Batch broadcast: at each right-aligned position across the whole chain, all + # non-literal-1 entries must be equal. Unioning them as one class catches + # transitive equalities a literal-1 in the middle would otherwise mask. + max_batch = max((len(s) - 2 for s in chain_shapes), default=0) + for pos in range(max_batch): + anchor: DimEntry | None = None + for s in chain_shapes: + n_batch = len(s) - 2 + if pos >= n_batch: + continue + d = s[n_batch - 1 - pos] + if isinstance(d, int) and d == 1: + continue + if anchor is None: + anchor = d + else: + union(anchor, d) + + rep_entry: dict[DimEntry, DimEntry] = {} + for shape in (*chain_shapes, *extra_shapes): + for d in shape: + r = find(d) + if r not in rep_entry: + rep_entry[r] = d + elif isinstance(d, int) and not isinstance(rep_entry[r], int): + rep_entry[r] = d + + def canonicalize(shape: Shape) -> Shape: + return tuple(rep_entry[find(d)] if d in parent else d for d in shape) + + unified = [canonicalize(s) for s in chain_shapes] + return unified, canonicalize + + +def _solve_chain(shapes: list[Shape]) -> tuple[CostExpr, dict]: + """Standard matrix-chain DP over the symbolic operand shapes. + + Returns ``(best_total_cost, dp_table)`` where ``dp_table[(i, j)]`` is + ``(cost, split_k, result_shape)`` for the subchain ``shapes[i..j]`` inclusive. + Comparison uses ``_provably_less``; when neither candidate is provably less the + first-found candidate (lower split index) wins as a deterministic tie-break. + """ + n = len(shapes) + dp: dict[tuple[int, int], tuple[CostExpr, int | None, Shape]] = {} + for i in range(n): + dp[(i, i)] = (CostExpr.zero(), None, shapes[i]) + + for length in range(2, n + 1): + for i in range(n - length + 1): + j = i + length - 1 + best: tuple[CostExpr, int, Shape] | None = None + for k in range(i, j): + lc, _, ls = dp[(i, k)] + rc, _, rs = dp[(k + 1, j)] + step = _contract_cost(ls, rs) + total = lc + rc + step + result = _matmul_result_shape(ls, rs) + if best is None or _provably_less(total, best[0]): + best = (total, k, result) + + # Length >= 2 always has at least one split -- invariant violated. + assert best is not None, f"DP found no candidate for subchain ({i},{j})" + + dp[(i, j)] = best + + return dp[(0, n - 1)][0], dp + + +def _existing_cost( + consumed: list[Apply], + fgraph: FunctionGraph, + canonicalize: Callable[[Shape], Shape], +) -> CostExpr: + """Total FLOPs of the user's existing chain. + + Walks consumed matmul nodes in topological order (reversed insertion order). Each step looks up its input shapes + in the running ``var_shape`` table. The chain leaves take shapes from ``_operand_shape_raw + canonicalize`` while + intermediate matmul outputs come from ``_matmul_result_shape``. + + Lifted DimShuffles preserve FLOPs (they only touch size-1 batch dims or swap core dims), so the canonicalized + raw-shape sum compares directly to ``_solve_chain``'s symbolic cost. + """ + var_shape: dict[Variable, Shape] = {} + total = CostExpr.zero() + for node in reversed(consumed): + l_input, r_input = node.inputs + if l_input not in var_shape: + var_shape[l_input] = canonicalize(_operand_shape_raw(l_input, fgraph)) + if r_input not in var_shape: + var_shape[r_input] = canonicalize(_operand_shape_raw(r_input, fgraph)) + l_shape = var_shape[l_input] + r_shape = var_shape[r_input] + total = total + _contract_cost(l_shape, r_shape) + var_shape[node.outputs[0]] = _matmul_result_shape(l_shape, r_shape) + return total + + +# Mirrors the dtype gate inside Dot22.make_node (pytensor/tensor/blas.py). If that +# list grows (e.g., bfloat16) update both -- there's no canonical export to import. +_BLAS_DTYPES = ("float16", "float32", "float64", "complex64", "complex128") + + +def _select_emit_op(left: Variable, right: Variable) -> Variable: + """Emit ``left @ right`` via the cheapest semantically-equivalent op. + + Routing: + + - 2-D float/complex pair: ``Dot22``. + - 3-D float/complex pair whose batch dims share static broadcastability (both statically ``1``, or neither + statically ``1``): ``BatchedDot``. + - Anything else: ``matmul()``, which lowers to ``Blockwise(Dot)``. + """ + l_dt, r_dt = left.type.dtype, right.type.dtype + if l_dt != r_dt or l_dt not in _BLAS_DTYPES: + return matmul(left, right) # type: ignore[arg-type,no-any-return] + if left.type.ndim == right.type.ndim == 2: + return cast(Variable, Dot22()(left, right)) + if left.type.ndim == right.type.ndim == 3: + if (left.type.shape[0] == 1) == (right.type.shape[0] == 1): + return cast(Variable, BatchedDot()(left, right)) + return matmul(left, right) # type: ignore[arg-type,no-any-return] + + +def _build_tree( + operands: list[tuple[Variable, tuple[DimShuffle, ...]]], + dp: dict, + i_top: int, + j_top: int, +) -> Variable: + """Materialize the optimal matmul tree from the DP table. + + Walks the DP split tree in post-order using an explicit work stack. Like ``_decompose_operand``, this function + avoids recursion because deep chains can blow the Python stack. + """ + materialized: dict[tuple[int, int], Variable] = {} + work: list[tuple[int, int, bool]] = [(i_top, j_top, False)] + + while work: + i, j, ready = work.pop() + if i == j: + var, lifts = operands[i] + for lift in reversed(lifts): + var = cast(Variable, lift(var)) + materialized[(i, j)] = var + continue + _, split, _ = dp[(i, j)] + if not ready: + work.append((i, j, True)) + work.append((split + 1, j, False)) + work.append((i, split, False)) + else: + left = materialized.pop((i, split)) + right = materialized.pop((split + 1, j)) + materialized[(i, j)] = _select_emit_op(left, right) + + return materialized[(i_top, j_top)] + + +class ReassociateMatmulChain(GraphRewriter): + """Re-associate matmul chains when a strictly cheaper order can be proven. + + Runs after BLAS (1.7) and specialize (2.0). For each maximal chain of matmul + links in the fgraph, runs a DP over symbolic operand shapes. Replaces the chain + only when the new total cost provably beats the user's existing parenthesization + (under "every dim symbol is positive"). + + Decomposes liftable single-client ``DimShuffle(matmul(L, R))`` patterns into + ``DimShuffle(L) @ DimShuffle(R)`` (swapping operand order for matrix-transpose) + to expose longer chains. The lift commits atomically with the reassociation: + ``_build_tree`` constructs the lifted vars and runs only when ``_provably_less`` + fires, so unsuccessful attempts add no graph nodes. + + The pass walks a snapshot toposort once: replacements inside the loop create new + chain-link nodes that this pass does not re-examine. For matrix chain + reassociation a single global optimum doesn't get better by re-running; the + trade-off is that the pass skips lift-exposed extensions of an already-rewritten + chain. + """ + + def apply(self, fgraph: FunctionGraph) -> None: + # `visited` only tracks nodes consumed by *committed* rewrites. Toposort + # visits leaves before consumers, so a leaf later absorbed into a longer + # chain (via the consumer) must remain decomposable when the consumer is + # reached. Without this, processing `C @ D` first in `(A @ B) @ (C @ D)` + # would mark it visited and block the outer matmul from seeing the full + # 4-element chain. + visited: set[Apply] = set() + + for node in list(fgraph.toposort()): + if node in visited or not _is_chain_link(node): + continue + + top = _find_chain_top(node, fgraph) + if top in visited: + continue + + # Local snapshot so the recursive decomposition can avoid re-entering + # the same chain link within this attempt without polluting `visited` + # until we commit. + local_visited = set(visited) + local_visited.add(top) + consumed: list[Apply] = [top] + left_ops = _decompose_operand( + top.inputs[0], fgraph, local_visited, consumed + ) + right_ops = _decompose_operand( + top.inputs[1], fgraph, local_visited, consumed + ) + operands = [*left_ops, *right_ops] + + if len(operands) < 3: + continue + + # Compute each operand's shape after applying its pending lifts. `lifts` + # is outermost-first; apply in reverse so the innermost transformation + # touches the raw shape first. ``_decompose_operand`` only propagates a + # lift through a chain-link whose operand ndim matches the lift's input + # ndim, so indexing into ``shape`` by ``lift.new_order`` is in-bounds. + op_shapes: list[Shape] = [] + for base, lifts in operands: + shape: Shape = _operand_shape_raw(base, fgraph) + for lift in reversed(lifts): + shape = tuple(1 if x == "x" else shape[x] for x in lift.new_order) + op_shapes.append(shape) + + if any(len(s) < 2 for s in op_shapes): + continue + + # Pre-collect raw input shapes of every consumed matmul so unification + # canonicalizes those symbols too. `_existing_cost` then uses the same + # canonical reps as the DP, so the comparison can see equalities through + # ShapeFeature symbols on either side of the chain. + raw_extras = [ + _operand_shape_raw(inp, fgraph) for c in consumed for inp in c.inputs + ] + + unified, canonicalize = _build_unification(op_shapes, raw_extras) + new_cost, dp = _solve_chain(unified) + old_cost = _existing_cost(consumed, fgraph, canonicalize) + + if not _provably_less(new_cost, old_cost): + continue + + visited.update(consumed) + + new_out = _build_tree(operands, dp, 0, len(operands) - 1) + old_out = top.outputs[0] + copy_stack_trace(old_out, new_out) + fgraph.replace(old_out, new_out, reason="reassociate_matmul") + + +reassociate_matmul_chain = ReassociateMatmulChain() +reassociate_matmul_chain.name = "reassociate_matmul_chain" + +optdb.register( + "reassociate_matmul_chain", + reassociate_matmul_chain, + "fast_run", + position=2.5, +) diff --git a/tests/tensor/linalg/test_products.py b/tests/tensor/linalg/test_products.py index cf601a3090..2a0ee972dd 100644 --- a/tests/tensor/linalg/test_products.py +++ b/tests/tensor/linalg/test_products.py @@ -12,7 +12,6 @@ matrix_power, pinv, ) -from pytensor.tensor.math import _allclose from pytensor.tensor.type import matrix, tensor, vector from tests import unittest_tools as utt @@ -20,19 +19,14 @@ def test_matrix_dot(): rng = np.random.default_rng(utt.fetch_seed()) n = rng.integers(4) + 2 - rs = [] - xs = [] - for k in range(n): - rs += [rng.standard_normal((4, 4)).astype(config.floatX)] - xs += [matrix()] + rs = [rng.normal(size=(4, 4)).astype(config.floatX) for _ in range(n)] + xs = [matrix() for _ in range(n)] sol = matrix_dot(*xs) pytensor_sol = function(xs, sol)(*rs) - numpy_sol = rs[0] - for r in rs[1:]: - numpy_sol = np.dot(numpy_sol, r) + numpy_sol = np.linalg.multi_dot(rs) - assert _allclose(numpy_sol, pytensor_sol) + np.testing.assert_allclose(numpy_sol, pytensor_sol) class TestMatrixPower: diff --git a/tests/tensor/rewriting/linalg/test_reassociate_matmul.py b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py new file mode 100644 index 0000000000..87e6cfa6db --- /dev/null +++ b/tests/tensor/rewriting/linalg/test_reassociate_matmul.py @@ -0,0 +1,516 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import function +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.blas import BatchedDot, Dot22, Gemm +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.math import Dot +from pytensor.tensor.rewriting.linalg.reassociate_matmul import ( + CostExpr, + _provably_less, + reassociate_matmul_chain, +) + + +requires_fast_run = pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="reassociate_matmul_chain is only registered for FAST_RUN", +) + + +def _matmul_nodes(fgraph): + """Apply nodes that perform a matmul-like contraction.""" + nodes = [] + for n in fgraph.toposort(): + op = n.op + if isinstance(op, Dot | Dot22 | Gemm | BatchedDot): + nodes.append(n) + elif isinstance(op, Blockwise) and isinstance(op.core_op, Dot): + nodes.append(n) + return nodes + + +def _matmul_output_shapes(fgraph): + """Static shapes of every matmul output (last entry is the final output).""" + return [n.outputs[0].type.shape for n in _matmul_nodes(fgraph)] + + +def _dimshuffle_count(fgraph): + return sum(1 for n in fgraph.toposort() if isinstance(n.op, DimShuffle)) + + +class TestProvablyLess: + def test_int_dominance(self): + a = CostExpr.from_dim_product([10, 20, 5]) # 1000 + b = CostExpr.from_dim_product([10, 20, 5]) + CostExpr.from_dim_product( + [1, 2, 3] + ) # 1006 + assert _provably_less(a, b) + assert not _provably_less(b, a) + assert not _provably_less(a, a) + + def test_zero_vs_positive(self): + a = CostExpr.zero() + b = CostExpr.from_dim_product([7]) + assert _provably_less(a, b) + assert not _provably_less(b, a) + assert not _provably_less(a, a) + + def test_symbolic_no_proof(self): + # m*k*n vs k*n*p: different monomials, no dominance either way. + m = pt.scalar("m") + k = pt.scalar("k") + n = pt.scalar("n") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k, n]) + b = CostExpr.from_dim_product([k, n, p]) + assert not _provably_less(a, b) + assert not _provably_less(b, a) + + def test_extra_unmatched_b_monomial(self): + # a = m*k, b = m*k + p: b strictly larger under p >= 1. + m = pt.scalar("m") + k = pt.scalar("k") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k]) + b = CostExpr.from_dim_product([m, k]) + CostExpr.from_dim_product([p]) + assert _provably_less(a, b) + assert not _provably_less(b, a) + + def test_b_dominates_with_extra_symbol(self): + # a = m*k, b = m*k*p: b is sound but not strict (p == 1 makes them equal). + m = pt.scalar("m") + k = pt.scalar("k") + p = pt.scalar("p") + a = CostExpr.from_dim_product([m, k]) + b = CostExpr.from_dim_product([m, k, p]) + assert not _provably_less(a, b) + + def test_greedy_fails_kuhn_succeeds(self): + # Greedy first-fit pairs a[0]=x with b's first dominator (x*y), leaving + # a[1]=x*y with only x*z, which doesn't dominate. The valid matching is + # x <-> x*z and x*y <-> x*y; the unmatched `w` makes the strict comparison + # succeed. + x, y, z, w = pt.scalar("x"), pt.scalar("y"), pt.scalar("z"), pt.scalar("w") + a = CostExpr.from_dim_product([x]) + CostExpr.from_dim_product([x, y]) + b = ( + CostExpr.from_dim_product([x, y]) + + CostExpr.from_dim_product([x, z]) + + CostExpr.from_dim_product([w]) + ) + assert _provably_less(a, b) + + +class TestReassociateChain: + @requires_fast_run + def test_optimal_static_ordering(self): + # (100x2) @ (2x100) @ (100x100): optimal is A @ (B @ C) with B@C producing + # a (2, 100) intermediate, far cheaper than (A @ B) @ C. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + + f = function([A, B, C], A @ B @ C) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, ( + f"Expected the optimal split to produce a (2, 100) intermediate, " + f"got {shapes}" + ) + + def test_value_equivalence_4_matrices(self): + rng = np.random.default_rng(0) + shapes = [(10, 20), (20, 5), (5, 30), (30, 3)] + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in shapes] + pt_inputs = [pt.matrix(f"M{i}", shape=s) for i, s in enumerate(shapes)] + out = pt_inputs[0] @ pt_inputs[1] @ pt_inputs[2] @ pt_inputs[3] + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 + ) + if config.mode != "FAST_COMPILE": + # Optimal split for these shapes is A @ (B @ (C @ D)) with cost 1350 + # vs. naive 3400. Pin the smallest internal intermediate as a + # sentinel that the rewriter chose this tree. + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (5, 3) in shapes_seen, shapes_seen + + def test_symbolic_shapes_preserve_user_order(self): + # When all dims are symbolic, no parenthesization is provably cheaper, so + # the rewriter must leave the user's order intact. + A = pt.matrix("A") + B = pt.matrix("B") + C = pt.matrix("C") + out = A @ (B @ C) + f = function([A, B, C], out) + + nodes = _matmul_nodes(f.maker.fgraph) + assert len(nodes) == 2 + # The first matmul (in topo order) should be the BC contraction -- its + # inputs are exactly B and C, not the result of A @ anything. + first = nodes[0] + assert set(first.inputs) == {B, C}, ( + f"Expected first matmul to be B @ C, got inputs {first.inputs}" + ) + + def test_already_optimal_no_rewrite(self): + # User-written order is already optimal: (A @ B) -> (2, 100) @ C -> (2, 5) + # at cost 20_000 vs. the alternative A @ (B @ C) at 50_000. The rewriter + # must not disturb this. + A = pt.matrix("A", shape=(2, 100)) + B = pt.matrix("B", shape=(100, 100)) + C = pt.matrix("C", shape=(100, 5)) + f = function([A, B, C], (A @ B) @ C) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert shapes == [(2, 100), (2, 5)], ( + f"Expected the user's two-step order to be preserved exactly; got {shapes}" + ) + + def test_multi_client_breakage(self): + # If an intermediate matmul output has more than one client, the chain + # extender must stop there -- absorbing past it would duplicate the + # intermediate's computation when emitting the new tree. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + D = pt.matrix("D", shape=(100, 5)) + ab = A @ B # (100, 100), used twice + out1 = ab @ C @ D + out2 = ab.sum() + f = function([A, B, C, D], [out1, out2]) + rng = np.random.default_rng(1) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 100)).astype(config.floatX) + d_v = rng.normal(size=(100, 5)).astype(config.floatX) + r1, r2 = f(a_v, b_v, c_v, d_v) + ab_v = a_v @ b_v + np.testing.assert_allclose(r1, ab_v @ c_v @ d_v, rtol=1e-5) + np.testing.assert_allclose(r2, ab_v.sum(), rtol=1e-5) + + def test_quadratic_form_static(self): + # A @ B @ A.T with A: (50, 100), B: (100, 100). The transpose makes the + # third operand opaquely (100, 50). Both parenthesizations cost 750_000: + # (A @ B) @ A.T: 50*100*100 + 50*100*50 + # A @ (B @ A.T): 100*100*50 + 50*100*50 + # The rewriter has no provable win and must leave the user's order alone. + A = pt.matrix("A", shape=(50, 100)) + B = pt.matrix("B", shape=(100, 100)) + f = function([A, B], A @ B @ A.T) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert shapes == [(50, 100), (50, 50)], ( + f"Expected user's symmetric chain to be preserved; got {shapes}" + ) + + def test_batched_blockwise_dot_reorders(self): + # Batched chain shaped to force a reorder: A(7,100,2) B(7,2,100) C(7,100,5). + # Naive (A@B)@C: 7*100*2*100 + 7*100*100*5 = 490k. + # Optimal A@(B@C): 7*2*100*5 + 7*100*2*5 = 14k. + rng = np.random.default_rng(2) + shapes = [(7, 100, 2), (7, 2, 100), (7, 100, 5)] + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in shapes] + pt_inputs = [pt.tensor3(f"B{i}", shape=s) for i, s in enumerate(shapes)] + out = pt_inputs[0] @ pt_inputs[1] @ pt_inputs[2] + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), + np_arrays[0] @ np_arrays[1] @ np_arrays[2], + rtol=1e-5, + ) + if config.mode != "FAST_COMPILE": + shapes_seen = _matmul_output_shapes(f.maker.fgraph) + assert (7, 2, 5) in shapes_seen, shapes_seen + + @requires_fast_run + def test_dot22_chain_post_blas(self): + # By the time the rewriter runs (post-BLAS), 2-D dots may have been + # promoted to Dot22. The chain extender must treat Dot22 as a chain link + # so the chain spans the full sequence AND so the rebuilt tree retains the + # BLAS-promoted form. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 100)) + f = function([A, B, C], A @ B @ C) + nodes = _matmul_nodes(f.maker.fgraph) + ops_seen = {type(n.op).__name__ for n in nodes} + assert ops_seen & {"Dot22", "Gemm"}, ( + f"Expected BLAS-promoted nodes to survive reassociation; got {ops_seen}" + ) + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, shapes + + @pytest.mark.parametrize( + "chain_shapes,must_contain", + [ + # Classic textbook 3-matrix: A(10x30) B(30x5) C(5x60). Optimal is + # (A@B)@C with intermediate (10, 5). + ([(10, 30), (30, 5), (5, 60)], {(10, 5)}), + # 5-matrix CLRS-style chain. Optimal is (A(BC))(DE) -- intermediates + # are (35, 5), (30, 5), (5, 20), (30, 20). Two of those are the + # required-present sentinels. + ([(30, 35), (35, 15), (15, 5), (5, 10), (10, 20)], {(35, 5), (5, 20)}), + ], + ) + def test_textbook_examples(self, chain_shapes, must_contain): + rng = np.random.default_rng(3) + np_arrays = [rng.normal(size=s).astype(config.floatX) for s in chain_shapes] + pt_inputs = [pt.matrix(f"M{i}", shape=s) for i, s in enumerate(chain_shapes)] + out = pt_inputs[0] + for x in pt_inputs[1:]: + out = out @ x + f = function(pt_inputs, out) + np.testing.assert_allclose( + f(*np_arrays), np.linalg.multi_dot(np_arrays), rtol=1e-5 + ) + if config.mode != "FAST_COMPILE": + shapes = set(_matmul_output_shapes(f.maker.fgraph)) + missing = must_contain - shapes + assert not missing, ( + f"Missing expected intermediates {missing}; got {shapes}" + ) + + def test_balanced_tree_decomposition(self): + # User's explicit `(A @ B) @ (C @ D)` is a non-linear tree. The recursive + # decomposer must still see the full 4-operand chain through both subtrees. + # Shapes are chosen so the user's order has three (100, 100) contractions + # costing ~1M FLOPs and the optimal `A @ (B @ (C @ D))` avoids them + # entirely (~21k FLOPs). The unavoidable (100, 100) is the final output; + # any *internal* (100, 100) means the rewriter missed the cross-subtree + # chain. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.matrix("C", shape=(100, 3)) + D = pt.matrix("D", shape=(3, 100)) + f = function([A, B, C, D], (A @ B) @ (C @ D)) + + rng = np.random.default_rng(7) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 3)).astype(config.floatX) + d_v = rng.normal(size=(3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), (a_v @ b_v) @ (c_v @ d_v), rtol=1e-5 + ) + + if config.mode != "FAST_COMPILE": + shapes = _matmul_output_shapes(f.maker.fgraph) + n_big = sum(1 for s in shapes if s == (100, 100)) + assert n_big <= 1, ( + f"Expected at most one (100,100) (the final output); got {shapes}" + ) + + +class TestLifting: + def test_lift_expand_dims_recovers_chain(self): + # `expand_dims(A @ B, 0) @ C @ D` only sees three operands without lifting, + # and the (A @ B) intermediate is forced to (100, 100). After lifting the + # expand_dims past A @ B, the chain becomes 4-element and the DP can avoid + # the large intermediate by routing through the small m=2/k=3 dims. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.tensor3("C", shape=(1, 100, 3)) + D = pt.tensor3("D", shape=(1, 3, 100)) + ab3 = pt.expand_dims(A @ B, 0) + f = function([A, B, C, D], ab3 @ C @ D) + + rng = np.random.default_rng(11) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(1, 100, 3)).astype(config.floatX) + d_v = rng.normal(size=(1, 3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), (a_v @ b_v)[None] @ c_v @ d_v, rtol=1e-5 + ) + + if config.mode != "FAST_COMPILE": + # The unavoidable (100, 100) (or (1, 100, 100)) is the final output. + # Any internal contraction with that shape means the lift didn't fire. + shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (1, 100, 100) not in non_final and (100, 100) not in non_final, ( + f"Lift+reorder should avoid the (100,100) intermediate; got {shapes}" + ) + + def test_lift_matrix_transpose(self): + # `(L @ R).T @ C` lifts to `R.T @ L.T @ C` (operand order swapped). Shapes + # are chosen so the lifted form's `R.T @ (L.T @ C)` is dramatically cheaper + # than the user's `(L @ R).T @ C` -- verifies that matrix-transpose lift + # exposes a real reorder, not just a no-op restructuring. + # L (100, 2), R (2, 50), C (100, 100): + # user: 100*2*50 + 50*100*100 = 510_000 + # lifted [R.T (50,2), L.T (2,100), C (100,100)] with R.T @ (L.T @ C): + # L.T @ C = 2*100*100 = 20_000; R.T @ that = 50*2*100 = 10_000. + L = pt.matrix("L", shape=(100, 2)) + R = pt.matrix("R", shape=(2, 50)) + C = pt.matrix("C", shape=(100, 100)) + f = function([L, R, C], (L @ R).T @ C) + + rng = np.random.default_rng(13) + l_v = rng.normal(size=(100, 2)).astype(config.floatX) + r_v = rng.normal(size=(2, 50)).astype(config.floatX) + c_v = rng.normal(size=(100, 100)).astype(config.floatX) + np.testing.assert_allclose(f(l_v, r_v, c_v), (l_v @ r_v).T @ c_v, rtol=1e-5) + + if config.mode != "FAST_COMPILE": + # Optimal lifted tree's intermediate is L.T @ C = (2, 100). + shapes = _matmul_output_shapes(f.maker.fgraph) + assert (2, 100) in shapes, ( + f"Expected matrix-transpose lift to enable a (2, 100) " + f"intermediate; got {shapes}" + ) + + def test_lift_atomic_gating_no_extra_nodes(self): + # All-symbolic shapes -> no provable win -> no rewrite. The lift must NOT + # add extra DimShuffle nodes when the rewrite isn't committed. Test the + # rewriter standalone on a fresh fgraph (no other passes) so the assertion + # isolates its behavior. + L = pt.matrix("L") + R = pt.matrix("R") + C = pt.matrix("C") + D = pt.matrix("D") + out = pt.expand_dims(L @ R, 0) @ pt.expand_dims(C, 0) @ pt.expand_dims(D, 0) + + fgraph = FunctionGraph([L, R, C, D], [out], clone=False) + before = _dimshuffle_count(fgraph) + reassociate_matmul_chain.apply(fgraph) + after = _dimshuffle_count(fgraph) + assert after == before, ( + f"Speculative lift leaked DimShuffle nodes: before={before} after={after}" + ) + + def test_lift_squeeze(self): + # `squeeze(A @ B, 0) @ C @ D` where (A@B) has a leading-1 batch dim. After + # the squeeze lift the chain spans 4 operands and the DP can route through + # the small k=2/k=3 dims rather than the (100, 100) intermediate the + # original chain forced. + A = pt.tensor3("A", shape=(1, 100, 2)) + B = pt.tensor3("B", shape=(1, 2, 100)) + C = pt.matrix("C", shape=(100, 3)) + D = pt.matrix("D", shape=(3, 100)) + ab2 = pt.squeeze(A @ B, axis=0) + f = function([A, B, C, D], ab2 @ C @ D) + + rng = np.random.default_rng(23) + a_v = rng.normal(size=(1, 100, 2)).astype(config.floatX) + b_v = rng.normal(size=(1, 2, 100)).astype(config.floatX) + c_v = rng.normal(size=(100, 3)).astype(config.floatX) + d_v = rng.normal(size=(3, 100)).astype(config.floatX) + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), + (a_v @ b_v).squeeze(axis=0) @ c_v @ d_v, + rtol=1e-5, + ) + if config.mode != "FAST_COMPILE": + # Final output is (100, 100); any internal one means the lift didn't fire. + shapes = _matmul_output_shapes(f.maker.fgraph) + non_final = shapes[:-1] if shapes else [] + assert (100, 100) not in non_final, ( + f"Squeeze lift should expose a chain that avoids the (100,100) " + f"intermediate; got {shapes}" + ) + + def test_lift_through_heterogeneous_ndim_does_not_crash(self): + # A DimShuffle wrapping a Blockwise(Dot) whose operands have different + # ndims (one 2-D, one 3-D) cannot have its lift propagated to both inner + # operands -- the lift's `new_order` references indices the 2-D operand + # doesn't have. The rewriter must bail and treat the wrapper as opaque + # rather than crashing or producing a malformed graph. + L = pt.matrix("L", shape=(4, 5)) + R = pt.tensor3("R", shape=(7, 5, 6)) # broadcasts L to (7, 4, 6) + C = pt.tensor3("C", shape=(7, 6, 8)) + ds_inner = pt.expand_dims(L @ R, 0) + f = function([L, R, C], ds_inner @ pt.expand_dims(C, 0)) + + rng = np.random.default_rng(29) + l_v = rng.normal(size=(4, 5)).astype(config.floatX) + r_v = rng.normal(size=(7, 5, 6)).astype(config.floatX) + c_v = rng.normal(size=(7, 6, 8)).astype(config.floatX) + np.testing.assert_allclose( + f(l_v, r_v, c_v), + (l_v @ r_v)[None] @ c_v[None], + rtol=1e-5, + ) + + def test_no_batched_dot_when_batch_dims_could_broadcast(self): + # BatchedDot does not broadcast -- its perform/C path errors at runtime + # when batch dims differ. If the chain has one operand with static batch + # dim 1 and another with non-1 batch dim, the rewriter must NOT emit + # BatchedDot for those pairs; matmul/Blockwise(Dot) is the only safe + # choice. + A = pt.tensor3("A", shape=(1, 4, 5)) + B = pt.tensor3("B", shape=(7, 5, 6)) + C = pt.tensor3("C", shape=(7, 6, 4)) + f = function([A, B, C], A @ B @ C) + + # Any BatchedDot in the rewritten graph must have both operands' static + # batch dim known and equal. Anything else is a latent runtime crash. + for node in _matmul_nodes(f.maker.fgraph): + if isinstance(node.op, BatchedDot): + lb = node.inputs[0].type.shape[0] + rb = node.inputs[1].type.shape[0] + assert lb is not None and rb is not None and lb == rb, ( + f"BatchedDot with broadcasting-incompatible batch dims: " + f"left={node.inputs[0].type.shape}, " + f"right={node.inputs[1].type.shape}" + ) + + rng = np.random.default_rng(31) + a_v = rng.normal(size=(1, 4, 5)).astype(config.floatX) + b_v = rng.normal(size=(7, 5, 6)).astype(config.floatX) + c_v = rng.normal(size=(7, 6, 4)).astype(config.floatX) + np.testing.assert_allclose(f(a_v, b_v, c_v), a_v @ b_v @ c_v, rtol=1e-5) + + def test_stacked_dimshuffle_lift_does_not_crash(self): + # `DimShuffle(DimShuffle(matmul))`: only the outer DimShuffle satisfies the + # "single-client wrapping a chain-link matmul" pattern -- the inner + # DimShuffle is not a chain link itself. The lift attempt on the outer + # must bail (treat the whole stack as opaque) and produce a correct + # graph. + A = pt.matrix("A", shape=(2, 3)) + B = pt.matrix("B", shape=(3, 4)) + C = pt.tensor4("C", shape=(1, 1, 4, 5)) + D = pt.tensor4("D", shape=(1, 1, 5, 6)) + ab_stacked = pt.expand_dims(pt.expand_dims(A @ B, 0), 0) + f = function([A, B, C, D], ab_stacked @ C @ D) + + rng = np.random.default_rng(37) + a_v = rng.normal(size=(2, 3)).astype(config.floatX) + b_v = rng.normal(size=(3, 4)).astype(config.floatX) + c_v = rng.normal(size=(1, 1, 4, 5)).astype(config.floatX) + d_v = rng.normal(size=(1, 1, 5, 6)).astype(config.floatX) + ab_stacked_v = (a_v @ b_v)[None, None] + np.testing.assert_allclose( + f(a_v, b_v, c_v, d_v), ab_stacked_v @ c_v @ d_v, rtol=1e-5 + ) + + def test_lift_blocked_by_multi_client_dimshuffle(self): + # If the DimShuffle has multiple clients, lifting would duplicate work for + # the OTHER clients. The rewriter must NOT lift in this case. + A = pt.matrix("A", shape=(100, 2)) + B = pt.matrix("B", shape=(2, 100)) + C = pt.tensor3("C", shape=(1, 100, 3)) + D = pt.tensor3("D", shape=(1, 3, 100)) + ab3 = pt.expand_dims(A @ B, 0) + out1 = ab3 @ C @ D + out2 = ab3 + f = function([A, B, C, D], [out1, out2]) + rng = np.random.default_rng(19) + a_v = rng.normal(size=(100, 2)).astype(config.floatX) + b_v = rng.normal(size=(2, 100)).astype(config.floatX) + c_v = rng.normal(size=(1, 100, 3)).astype(config.floatX) + d_v = rng.normal(size=(1, 3, 100)).astype(config.floatX) + ab3_v = (a_v @ b_v)[None] + r1, r2 = f(a_v, b_v, c_v, d_v) + np.testing.assert_allclose(r1, ab3_v @ c_v @ d_v, rtol=1e-5) + np.testing.assert_allclose(r2, ab3_v, rtol=1e-5) + # The original (A @ B) must still be in the graph because ab3 is also + # consumed by out2; if we wrongly lifted the expand_dims, A @ B would be + # gone. + nodes = _matmul_nodes(f.maker.fgraph) + assert any(n.outputs[0].type.shape == (100, 100) for n in nodes), ( + f"Expected A@B (100,100) to remain because ab3 has multiple clients; " + f"got {[n.outputs[0].type.shape for n in nodes]}" + )