Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions pytensor/tensor/rewriting/linalg/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pytensor.tensor.math import Dot, outer, prod
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
Expand Down Expand Up @@ -169,6 +170,69 @@ def det_of_kronecker(fgraph, node):
return [det_final]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Dot, blockwise_of(Dot)])
def dot_of_kron(fgraph, node):
r"""Decompose ``kron(A, B) @ X`` into two matmuls.

Applies the identity :math:`(A \otimes B)\, \mathrm{vec}_{\mathrm{row}}(X)
= \mathrm{vec}_{\mathrm{row}}(A X B^\top)` column-wise across the RHS:

.. math::
(A \otimes B)\, Y
= \mathrm{reshape}\bigl(
A\, \mathrm{reshape}(Y,\, (m, p, k))\, B^\top,\,
(m p,\, k)
\bigr).

Cost drops from :math:`O(k\, (m p)^2)` to :math:`O(k\, m p\, (m + p))`,
and the :math:`(m p) \times (m p)` Kronecker matrix is never formed.
"""
K, X = node.inputs

# Peel Blockwise(Dot)'s batch-broadcast wrapper (plain expand or matrix-transposed).
transposed = False
match K.owner_op_and_inputs:
case (DimShuffle(is_left_expand_dims=True), inner):
K = inner
case (DimShuffle(is_left_expanded_matrix_transpose=True), inner):
K = inner
transposed = True

match K.owner_op_and_inputs:
case (KroneckerProduct(), A, B):
pass
case _:
return None

# ``kron(A, B).mT == kron(A.mT, B.mT)`` for 2-D ``A, B``.
if transposed:
A, B = A.mT, B.mT

if A.type.ndim != 2 or B.type.ndim != 2:
return None

m = A.shape[-1]
p = B.shape[-1]

# Bring k to the front so each (m, p) slice is a batched matmul argument.
batch_shape = tuple(X.shape[i] for i in range(X.type.ndim - 2))
k = X.shape[-1]
X_3d = X.reshape((*batch_shape, m, p, k))
X_3d = pt.moveaxis(X_3d, -1, -3)

Z = A @ X_3d
Z = Z @ B.mT

Z = pt.moveaxis(Z, -3, -1)
new_out = Z.reshape((*batch_shape, m * p, k))

copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_canonicalize
@register_stabilize
@node_rewriter([KroneckerProduct])
Expand Down
194 changes: 192 additions & 2 deletions pytensor/tensor/rewriting/linalg/solvers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Container
from copy import copy
from functools import reduce

from pytensor import tensor as pt
from pytensor.assumptions import DIAGONAL, ORTHOGONAL, check_assumption
Expand All @@ -12,15 +13,18 @@
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Add
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor.basic import atleast_Nd, split
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.linalg.constructors import BlockDiagonal
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky
from pytensor.tensor.linalg.decomposition.eigen import eigh
from pytensor.tensor.linalg.decomposition.lu import lu_factor
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.products import KroneckerProduct
from pytensor.tensor.linalg.solvers.core import SolveBase
from pytensor.tensor.linalg.solvers.general import Solve, lu_solve, solve
from pytensor.tensor.linalg.solvers.linear_control import (
Expand All @@ -40,7 +44,10 @@
register_stabilize,
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg.utils import get_assume_a
from pytensor.tensor.rewriting.linalg.utils import (
get_assume_a,
is_eye_mul,
)
from pytensor.tensor.variable import TensorVariable


Expand Down Expand Up @@ -374,6 +381,189 @@ def block_diag_solve_to_block_diag_solves(fgraph, node):
return [new_out]


@register_canonicalize
@register_stabilize
@node_rewriter([blockwise_of(SolveBase)])
def solve_of_kron(fgraph, node):
r"""Decompose ``solve(kron(A, B), b)`` into per-component solves.

Inverts the identity :math:`(A \otimes B)\, \mathrm{vec}_{\mathrm{row}}(X)
= \mathrm{vec}_{\mathrm{row}}(A X B^\top)`:

.. math::
\mathrm{solve}(A \otimes B,\, b)
= \mathrm{vec}_{\mathrm{row}}\bigl(
\mathrm{solve}(B,\, \mathrm{solve}(A,\, Y)^\top)^\top
\bigr),
\quad Y = b.\mathrm{reshape}(m, p).

Applies to ``Solve``, ``SolveTriangular``, and ``CholeskySolve``. For
``Solve``, ``assume_a`` is downgraded to ``"gen"`` on the per-component
solves: :math:`A \otimes B` being symmetric / positive-definite does not
imply the same of :math:`A` and :math:`B` individually.
"""
A_kron, b = node.inputs

# Peel Blockwise's batch-broadcast wrapper (plain expand or matrix-transposed).
transposed = False
match A_kron.owner_op_and_inputs:
case (DimShuffle(is_left_expand_dims=True), inner):
A_kron = inner
case (DimShuffle(is_left_expanded_matrix_transpose=True), inner):
A_kron = inner
transposed = True

match A_kron.owner_op_and_inputs:
case (KroneckerProduct(), A, B):
pass
case _:
return None

# ``kron(A, B).mT == kron(A.mT, B.mT)`` for 2-D ``A, B``.
if transposed:
A, B = A.mT, B.mT

# KroneckerProduct broadcasts elementwise across all dims; the vec identity
# is 2-D only.
if A.type.ndim != 2 or B.type.ndim != 2:
return None

core_op = node.op.core_op
b_ndim = core_op.b_ndim

props = core_op._props_dict()
props["b_ndim"] = 2
# Fresh destroy_map: the new ops mustn't inherit overwrite flags from the
# outer Solve, which would mark unrelated tensors for in-place destruction.
props.pop("overwrite_a", None)
props.pop("overwrite_b", None)
if isinstance(core_op, Solve):
props["assume_a"] = "gen"
per_component_solve = Blockwise(type(core_op)(**props))

m = A.shape[-1]
p = B.shape[-1]

if b_ndim == 1:
# b: (..., m*p)
batch_shape = tuple(b.shape[i] for i in range(b.type.ndim - 1))
Y = b.reshape((*batch_shape, m, p))
Z = per_component_solve(A, Y)
X = per_component_solve(B, Z.mT).mT
new_out = X.reshape((*batch_shape, m * p))
else:
# b: (..., m*p, k)
batch_shape = tuple(b.shape[i] for i in range(b.type.ndim - 2))
k = b.shape[-1]
Y = b.reshape((*batch_shape, m, p, k))
# Fold (p, k) into one RHS axis so the A-solve is a single matrix solve.
Y_flat = Y.reshape((*batch_shape, m, p * k))
Z = per_component_solve(A, Y_flat).reshape((*batch_shape, m, p, k))
X = per_component_solve(B, Z)
new_out = X.reshape((*batch_shape, m * p, k))

copy_stack_trace(node.outputs[0], new_out)
return [new_out]


def _kron_plus_diag_noise_eigh_form(Ks, sigma_sq):
    r"""Build the eigen-form of :math:`\bigotimes K_i + \sigma^2 I`.

    Each factor :math:`K_i` is passed through ``eigh`` to obtain
    :math:`(d_i, Q_i)`. The pieces are then combined as

    * :math:`Q = \bigotimes Q_i` (Kronecker product of the eigenvector
      matrices), and
    * :math:`d = (\bigotimes d_i) + \sigma^2` (flattened outer products of the
      eigenvalue vectors, shifted by ``sigma_sq``).

    Returns the pair ``(Q, d)``.
    """
    eigenvalues = []
    eigenvectors = []
    for K in Ks:
        w, v = eigh(K)
        eigenvalues.append(w)
        eigenvectors.append(v)

    def kron_1d(left, right):
        # Kronecker product of two vectors: flattened outer product.
        return pt.outer(left, right).ravel()

    Q = reduce(pt.linalg.kron, eigenvectors)
    d = reduce(kron_1d, eigenvalues) + sigma_sq
    return Q, d


@register_canonicalize
@register_stabilize
@node_rewriter([KroneckerProduct])
def solve_of_kron_plus_diag_noise(fgraph, node):
    r"""Decompose ``solve(kron(*Ks) + sigma**2 * I, b)`` via per-component ``eigh``.

    With :math:`K_i = Q_i \mathrm{diag}(d_i) Q_i^\top`,

    .. math::
        \bigl(\bigotimes K_i + \sigma^2 I\bigr)^{-1} b
            = Q\, \mathrm{diag}(d)^{-1}\, Q^\top b,
        \quad Q = \bigotimes Q_i,\;
        d = \bigotimes d_i + \sigma^2.

    The rewrite is anchored on the ``KroneckerProduct`` node. It looks for an
    ``Add`` client of the form ``kron + scalar * eye`` and then replaces every
    general ``Solve`` that consumes that sum as its matrix operand, either
    directly or through a left-expand / matrix-transpose ``DimShuffle``.

    Only ``Solve`` is rewritten. The other ``SolveBase`` ops do not compute
    :math:`K^{-1} b` for a full matrix :math:`K` (triangular solves use a
    single triangle, ``CholeskySolve`` treats its input as a Cholesky factor),
    so the identity above does not describe them.

    Returns a ``{old_output: new_output}`` mapping, or ``None`` if nothing
    matched.
    """
    kron_out = node.outputs[0]

    def collect_leaves(y):
        # Flatten nested kron(kron(...), ...) into its ordered list of factors.
        if y.owner is not None and isinstance(y.owner.op, KroneckerProduct):
            return collect_leaves(y.owner.inputs[0]) + collect_leaves(
                y.owner.inputs[1]
            )
        return [y]

    Ks = collect_leaves(kron_out)
    if any(K.type.ndim != 2 for K in Ks):
        return None

    # NOTE(review): the eigh-based form presumes every K_i is symmetric;
    # nothing in this rewrite verifies that -- confirm whether an assumption
    # check should gate it.

    # Find Add(kron, scalar*eye); either operand order.
    K_var = sigma_sq = None
    for client, kron_idx in fgraph.clients[kron_out]:
        if not (
            isinstance(client.op, Elemwise)
            and isinstance(client.op.scalar_op, Add)
            and len(client.inputs) == 2
        ):
            continue
        eye_match = is_eye_mul(client.inputs[1 - kron_idx])
        if eye_match is None:
            continue
        _, raw_sigma_sq = eye_match
        # Eigh form needs ``sigma_sq`` to broadcast against the 1-D ``d`` vector.
        if not all(raw_sigma_sq.type.broadcastable):
            continue
        sigma_sq = raw_sigma_sq.squeeze()
        K_var = client.outputs[0]
        break
    if K_var is None:
        return None

    # kron + sigma**2*I is symmetric (given symmetric factors), so left-expand
    # and matrix-transpose wrappers both reach Solve consumers that are
    # equivalent to consumers of K_var.
    candidates = [K_var]
    for c, idx in fgraph.clients[K_var]:
        if (
            idx == 0
            and isinstance(c.op, DimShuffle)
            and (c.op.is_left_expand_dims or c.op.is_left_expanded_matrix_transpose)
        ):
            candidates.append(c.outputs[0])

    Q = d = None  # shared across all Solve consumers of the same noisy K
    replacements = {}
    for cand in candidates:
        for client, idx in fgraph.clients[cand]:
            # Only rewrite when the noisy K is the matrix operand (input 0).
            if idx != 0:
                continue
            # Restrict to the general Solve: triangular / Cholesky solves do
            # not compute K^{-1} b, so the eigh form would change their result.
            if not (
                isinstance(client.op, Blockwise)
                and isinstance(client.op.core_op, Solve)
            ):
                continue
            if Q is None:
                # Build the eigen-form lazily, once, on the first match.
                Q, d = _kron_plus_diag_noise_eigh_form(Ks, sigma_sq)
            _, b = client.inputs
            b_ndim = client.op.core_op.b_ndim
            if b_ndim == 1 and b.type.ndim > 1:
                # Batched vectors (..., n): promote to columns so matmul does
                # not read the trailing batch axis as matrix rows, then drop
                # the helper axis again.
                b_col = b[..., None]
                rescaled = (Q.mT @ b_col) / d[:, None]
                new_out = (Q @ rescaled)[..., 0]
            else:
                b_proj = Q.mT @ b
                rescaled = b_proj / d if b_ndim == 1 else b_proj / d[:, None]
                new_out = Q @ rescaled
            copy_stack_trace(client.outputs[0], new_out)
            replacements[client.outputs[0]] = new_out

    return replacements or None


@register_canonicalize
@register_stabilize
@node_rewriter([blockwise_of(SolveBase)])
Expand Down
Loading
Loading