diff --git a/pytensor/tensor/rewriting/linalg/products.py b/pytensor/tensor/rewriting/linalg/products.py index ef6c906184..796e89b54a 100644 --- a/pytensor/tensor/rewriting/linalg/products.py +++ b/pytensor/tensor/rewriting/linalg/products.py @@ -22,6 +22,7 @@ from pytensor.tensor.math import Dot, outer, prod from pytensor.tensor.rewriting.basic import ( register_canonicalize, + register_specialize, register_stabilize, ) from pytensor.tensor.rewriting.blockwise import blockwise_of @@ -169,6 +170,69 @@ def det_of_kronecker(fgraph, node): return [det_final] +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Dot, blockwise_of(Dot)]) +def dot_of_kron(fgraph, node): + r"""Decompose ``kron(A, B) @ X`` into two matmuls. + + Applies the identity :math:`(A \otimes B)\, \mathrm{vec}_{\mathrm{row}}(X) + = \mathrm{vec}_{\mathrm{row}}(A X B^\top)` column-wise across the RHS: + + .. math:: + (A \otimes B)\, Y + = \mathrm{reshape}\bigl( + A\, \mathrm{reshape}(Y,\, (m, p, k))\, B^\top,\, + (m p,\, k) + \bigr). + + Cost drops from :math:`O(k\, (m p)^2)` to :math:`O(k\, m p\, (m + p))`, + and the :math:`(m p) \times (m p)` Kronecker matrix is never formed. + """ + K, X = node.inputs + + # Peel Blockwise(Dot)'s batch-broadcast wrapper (plain expand or matrix-transposed). + transposed = False + match K.owner_op_and_inputs: + case (DimShuffle(is_left_expand_dims=True), inner): + K = inner + case (DimShuffle(is_left_expanded_matrix_transpose=True), inner): + K = inner + transposed = True + + match K.owner_op_and_inputs: + case (KroneckerProduct(), A, B): + pass + case _: + return None + + # ``kron(A, B).mT == kron(A.mT, B.mT)`` for 2-D ``A, B``. + if transposed: + A, B = A.mT, B.mT + + if A.type.ndim != 2 or B.type.ndim != 2: + return None + + m = A.shape[-1] + p = B.shape[-1] + + # Bring k to the front so each (m, p) slice is a batched matmul argument. + batch_shape = tuple(X.shape[i] for i in range(X.type.ndim - 2)) + k = X.shape[-1] + X_3d = X.reshape((*batch_shape, m, p, k)) + X_3d = pt.moveaxis(X_3d, -1, -3) + + Z = A @ X_3d + Z = Z @ B.mT + + Z = pt.moveaxis(Z, -3, -1) + new_out = Z.reshape((*batch_shape, m * p, k)) + + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + @register_canonicalize @register_stabilize @node_rewriter([KroneckerProduct]) diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py index 33a907b1da..e51c0cd0bb 100644 --- a/pytensor/tensor/rewriting/linalg/solvers.py +++ b/pytensor/tensor/rewriting/linalg/solvers.py @@ -1,5 +1,6 @@ from collections.abc import Container from copy import copy +from functools import reduce from pytensor import tensor as pt from pytensor.assumptions import DIAGONAL, ORTHOGONAL, check_assumption @@ -12,15 +13,18 @@ node_rewriter, ) from pytensor.graph.rewriting.unify import OpPattern +from pytensor.scalar.basic import Add from pytensor.scan.op import Scan from pytensor.scan.rewriting import scan_seqopt1 from pytensor.tensor.basic import atleast_Nd, split from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.linalg.constructors import BlockDiagonal from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky +from pytensor.tensor.linalg.decomposition.eigen import eigh from pytensor.tensor.linalg.decomposition.lu import lu_factor from pytensor.tensor.linalg.inverse import MatrixInverse +from pytensor.tensor.linalg.products import KroneckerProduct from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, lu_solve, solve from pytensor.tensor.linalg.solvers.linear_control import ( @@ -40,7 +44,10 @@ register_stabilize, ) from pytensor.tensor.rewriting.blockwise import blockwise_of -from pytensor.tensor.rewriting.linalg.utils import get_assume_a +from pytensor.tensor.rewriting.linalg.utils import ( + get_assume_a, + is_eye_mul, +) from pytensor.tensor.variable import TensorVariable @@ -374,6 +381,189 @@ def block_diag_solve_to_block_diag_solves(fgraph, node): return [new_out] +@register_canonicalize +@register_stabilize +@node_rewriter([blockwise_of(SolveBase)]) +def solve_of_kron(fgraph, node): + r"""Decompose ``solve(kron(A, B), b)`` into per-component solves. + + Inverts the identity :math:`(A \otimes B)\, \mathrm{vec}_{\mathrm{row}}(X) + = \mathrm{vec}_{\mathrm{row}}(A X B^\top)`: + + .. math:: + \mathrm{solve}(A \otimes B,\, b) + = \mathrm{vec}_{\mathrm{row}}\bigl( + \mathrm{solve}(B,\, \mathrm{solve}(A,\, Y)^\top)^\top + \bigr), + \quad Y = b.\mathrm{reshape}(m, p). + + Applies to ``Solve``, ``SolveTriangular``, and ``CholeskySolve``. For + ``Solve``, ``assume_a`` is downgraded to ``"gen"`` on the per-component + solves: :math:`A \otimes B` being symmetric / positive-definite does not + imply the same of :math:`A` and :math:`B` individually. + """ + A_kron, b = node.inputs + + # Peel Blockwise's batch-broadcast wrapper (plain expand or matrix-transposed). + transposed = False + match A_kron.owner_op_and_inputs: + case (DimShuffle(is_left_expand_dims=True), inner): + A_kron = inner + case (DimShuffle(is_left_expanded_matrix_transpose=True), inner): + A_kron = inner + transposed = True + + match A_kron.owner_op_and_inputs: + case (KroneckerProduct(), A, B): + pass + case _: + return None + + # ``kron(A, B).mT == kron(A.mT, B.mT)`` for 2-D ``A, B``. + if transposed: + A, B = A.mT, B.mT + + # KroneckerProduct broadcasts elementwise across all dims; the vec identity + # is 2-D only. + if A.type.ndim != 2 or B.type.ndim != 2: + return None + + core_op = node.op.core_op + b_ndim = core_op.b_ndim + + props = core_op._props_dict() + props["b_ndim"] = 2 + # Fresh destroy_map: the new ops mustn't inherit overwrite flags from the + # outer Solve, which would mark unrelated tensors for in-place destruction. + props.pop("overwrite_a", None) + props.pop("overwrite_b", None) + if isinstance(core_op, Solve): + props["assume_a"] = "gen" + per_component_solve = Blockwise(type(core_op)(**props)) + + m = A.shape[-1] + p = B.shape[-1] + + if b_ndim == 1: + # b: (..., m*p) + batch_shape = tuple(b.shape[i] for i in range(b.type.ndim - 1)) + Y = b.reshape((*batch_shape, m, p)) + Z = per_component_solve(A, Y) + X = per_component_solve(B, Z.mT).mT + new_out = X.reshape((*batch_shape, m * p)) + else: + # b: (..., m*p, k) + batch_shape = tuple(b.shape[i] for i in range(b.type.ndim - 2)) + k = b.shape[-1] + Y = b.reshape((*batch_shape, m, p, k)) + # Fold (p, k) into one RHS axis so the A-solve is a single matrix solve. + Y_flat = Y.reshape((*batch_shape, m, p * k)) + Z = per_component_solve(A, Y_flat).reshape((*batch_shape, m, p, k)) + X = per_component_solve(B, Z) + new_out = X.reshape((*batch_shape, m * p, k)) + + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +def _kron_plus_diag_noise_eigh_form(Ks, sigma_sq): + r"""Eigendecompose :math:`\bigotimes K_i + \sigma^2 I`. + + Returns ``(Q, d)`` where :math:`Q = \bigotimes Q_i` is orthogonal, + :math:`d = (\bigotimes d_i) + \sigma^2` is the vector of eigenvalues, + and :math:`(Q_i, d_i)` is the eigendecomposition of :math:`K_i`. + """ + eigs = [eigh(K) for K in Ks] + ds = [w for (w, _) in eigs] + Qs = [v for (_, v) in eigs] + + Q = reduce(pt.linalg.kron, Qs) + d = reduce(lambda a, b: pt.outer(a, b).ravel(), ds) + sigma_sq + return Q, d + + +@register_canonicalize +@register_stabilize +@node_rewriter([KroneckerProduct]) +def solve_of_kron_plus_diag_noise(fgraph, node): + r"""Decompose ``solve(kron(*Ks) + sigma**2 * I, b)`` via per-component ``eigh``. + + With :math:`K_i = Q_i \mathrm{diag}(d_i) Q_i^\top`, + + .. math:: + \bigl(\bigotimes K_i + \sigma^2 I\bigr)^{-1} b + = Q\, \mathrm{diag}(d)^{-1}\, Q^\top b, + \quad Q = \bigotimes Q_i,\; + d = \bigotimes d_i + \sigma^2. + """ + kron_out = node.outputs[0] + + def collect_leaves(y): + if y.owner is not None and isinstance(y.owner.op, KroneckerProduct): + return collect_leaves(y.owner.inputs[0]) + collect_leaves(y.owner.inputs[1]) + return [y] + + Ks = collect_leaves(kron_out) + if any(K.type.ndim != 2 for K in Ks): + return None + + # Find Add(kron, scalar*eye); either operand order. + K_var = sigma_sq = None + for client, kron_idx in fgraph.clients[kron_out]: + if not ( + isinstance(client.op, Elemwise) + and isinstance(client.op.scalar_op, Add) + and len(client.inputs) == 2 + ): + continue + eye_match = is_eye_mul(client.inputs[1 - kron_idx]) + if eye_match is None: + continue + _, raw_sigma_sq = eye_match + # Eigh form needs ``sigma_sq`` to broadcast against the 1-D ``d`` vector. + if not all(raw_sigma_sq.type.broadcastable): + continue + sigma_sq = raw_sigma_sq.squeeze() + K_var = client.outputs[0] + break + if K_var is None: + return None + + # kron + sigma**2*I is symmetric, so left-expand and matrix-transpose wrappers + # both reach Solve consumers that are equivalent to consumers of K_var. + candidates = [K_var] + for c, idx in fgraph.clients[K_var]: + if ( + idx == 0 + and isinstance(c.op, DimShuffle) + and (c.op.is_left_expand_dims or c.op.is_left_expanded_matrix_transpose) + ): + candidates.append(c.outputs[0]) + + Q = d = None # shared across all Solve consumers of the same noisy K + replacements = {} + for cand in candidates: + for client, idx in fgraph.clients[cand]: + if idx != 0: + continue + if not ( + isinstance(client.op, Blockwise) + and isinstance(client.op.core_op, SolveBase) + ): + continue + if Q is None: + Q, d = _kron_plus_diag_noise_eigh_form(Ks, sigma_sq) + _, b = client.inputs + b_ndim = client.op.core_op.b_ndim + b_proj = Q.mT @ b + rescaled = b_proj / d if b_ndim == 1 else b_proj / d[..., :, None] + new_out = Q @ rescaled + copy_stack_trace(client.outputs[0], new_out) + replacements[client.outputs[0]] = new_out + + return replacements or None + + @register_canonicalize @register_stabilize @node_rewriter([blockwise_of(SolveBase)]) diff --git a/pytensor/tensor/rewriting/linalg/summary.py b/pytensor/tensor/rewriting/linalg/summary.py index d702a064cb..0b3b43db3f 100644 --- a/pytensor/tensor/rewriting/linalg/summary.py +++ b/pytensor/tensor/rewriting/linalg/summary.py @@ -1,3 +1,5 @@ +from functools import reduce + import numpy as np from pytensor import tensor as pt @@ -10,22 +12,27 @@ copy_stack_trace, node_rewriter, ) -from pytensor.scalar.basic import Abs, Exp, Log, Sign, Sqr +from pytensor.scalar.basic import Abs, Add, Exp, Log, Sign, Sqr from pytensor.tensor.basic import ones from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.linalg.decomposition.cholesky import Cholesky +from pytensor.tensor.linalg.decomposition.eigen import eigh from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor from pytensor.tensor.linalg.decomposition.qr import QR from pytensor.tensor.linalg.decomposition.svd import SVD -from pytensor.tensor.linalg.summary import SLogDet, det +from pytensor.tensor.linalg.products import KroneckerProduct +from pytensor.tensor.linalg.summary import Det, SLogDet, det from pytensor.tensor.math import Prod, log, prod from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, ) -from pytensor.tensor.rewriting.linalg.utils import matrix_diagonal_product +from pytensor.tensor.rewriting.linalg.utils import ( + is_eye_mul, + matrix_diagonal_product, +) @register_stabilize @@ -192,6 +199,80 @@ def det_of_triangular(fgraph, node): return [det_val] +@register_canonicalize +@register_stabilize +@node_rewriter([KroneckerProduct]) +def det_of_kron_plus_diag_noise(fgraph, node): + r"""Rewrite ``det(kron(*Ks) + sigma**2 * I)`` as ``prod(d)`` via per-component ``eigh``. + + The eigenvalues of the noisy Kron are :math:`d = (\bigotimes d_i) + + \sigma^2` where :math:`d_i` are the eigenvalues of :math:`K_i`, so + :math:`\det K = \prod_j d_j`. + """ + kron_out = node.outputs[0] + + def collect_leaves(y): + if y.owner is not None and isinstance(y.owner.op, KroneckerProduct): + return collect_leaves(y.owner.inputs[0]) + collect_leaves(y.owner.inputs[1]) + return [y] + + Ks = collect_leaves(kron_out) + if any(K.type.ndim != 2 for K in Ks): + return None + + # Find Add(kron, scalar*eye); either operand order. + K_var = sigma_sq = None + for client, kron_idx in fgraph.clients[kron_out]: + if not ( + isinstance(client.op, Elemwise) + and isinstance(client.op.scalar_op, Add) + and len(client.inputs) == 2 + ): + continue + eye_match = is_eye_mul(client.inputs[1 - kron_idx]) + if eye_match is None: + continue + _, raw_sigma_sq = eye_match + # Eigh form needs ``sigma_sq`` to broadcast against the 1-D ``d`` vector. + if not all(raw_sigma_sq.type.broadcastable): + continue + sigma_sq = raw_sigma_sq.squeeze() + K_var = client.outputs[0] + break + if K_var is None: + return None + + # kron + sigma**2*I is symmetric, so left-expand and matrix-transpose wrappers + # both reach Det consumers that are equivalent to consumers of K_var. + candidates = [K_var] + for c, idx in fgraph.clients[K_var]: + if ( + idx == 0 + and isinstance(c.op, DimShuffle) + and (c.op.is_left_expand_dims or c.op.is_left_expanded_matrix_transpose) + ): + candidates.append(c.outputs[0]) + + d = None # shared across all Det consumers of the same noisy K + replacements = {} + for cand in candidates: + for client, idx in fgraph.clients[cand]: + if idx != 0: + continue + if not ( + isinstance(client.op, Blockwise) and isinstance(client.op.core_op, Det) + ): + continue + if d is None: + ds = [eigh(K)[0] for K in Ks] + d = reduce(lambda a, b: pt.outer(a, b).ravel(), ds) + sigma_sq + out = d.prod().astype(client.outputs[0].type.dtype) + copy_stack_trace(client.outputs[0], out) + replacements[client.outputs[0]] = out + + return replacements or None + + @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): diff --git a/tests/tensor/rewriting/linalg/test_products.py b/tests/tensor/rewriting/linalg/test_products.py index 3b1067e069..a4289a0b65 100644 --- a/tests/tensor/rewriting/linalg/test_products.py +++ b/tests/tensor/rewriting/linalg/test_products.py @@ -262,6 +262,55 @@ def test_expm_of_diag(make_diag): assert_equal_computations([rewritten], [expected]) +@pytest.mark.parametrize("transposed", [False, True], ids=["plain", "transposed"]) +@pytest.mark.parametrize( + "rhs_kind", + ["vector", "matrix", "batched_matrix"], +) +def test_dot_of_kron(rhs_kind, transposed): + """``kron(A, B) @ X`` decomposes into two matmuls via the vec identity.""" + rng = np.random.default_rng(0) + m, p, k, batch = 4, 3, 5, 2 + + A = pt.matrix("A", shape=(m, m)) + B = pt.matrix("B", shape=(p, p)) + A_v = rng.normal(size=(m, m)) + B_v = rng.normal(size=(p, p)) + + if rhs_kind == "vector": + X = pt.vector("X", shape=(m * p,)) + X_v = rng.normal(size=(m * p,)) + elif rhs_kind == "matrix": + X = pt.matrix("X", shape=(m * p, k)) + X_v = rng.normal(size=(m * p, k)) + else: # batched_matrix + X = pt.tensor("X", shape=(batch, m * p, k)) + X_v = rng.normal(size=(batch, m * p, k)) + + K = pt.linalg.kron(A, B) + K_v = np.kron(A_v, B_v) + if transposed: + K = K.T + K_v = K_v.T + out = K @ X + if rhs_kind == "batched_matrix": + expected = np.stack([K_v @ X_v[i] for i in range(batch)]) + else: + expected = K_v @ X_v + + f = function([A, B, X], out, mode="FAST_RUN") + + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert not any(isinstance(op, KroneckerProduct) for op in ops) + + assert_allclose( + f(A_v, B_v, X_v), + expected, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + def test_kron_of_diagonal_to_diagonal(): da = pt.tensor("da", shape=(3, 3)) db = pt.tensor("db", shape=(4, 4)) diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py index 3dd70ae394..8f113453d5 100644 --- a/tests/tensor/rewriting/linalg/test_solvers.py +++ b/tests/tensor/rewriting/linalg/test_solvers.py @@ -17,6 +17,7 @@ from pytensor.tensor.linalg.constructors import BlockDiagonal from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky from pytensor.tensor.linalg.decomposition.lu import LUFactor +from pytensor.tensor.linalg.products import KroneckerProduct from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, solve from pytensor.tensor.linalg.solvers.linear_control import ( @@ -33,6 +34,7 @@ scan_split_non_sequence_decomposition_and_solve, ) from pytensor.tensor.type import matrix, tensor +from tests import unittest_tools as utt from tests.unittest_tools import assert_equal_computations @@ -563,6 +565,289 @@ def test_block_diag_solve_pushdown_both_sides_block_diag(): assert_allclose(f(A1_v, A2_v, B1_v, B2_v), expected, atol=1e-10) +class TestSolveOfKron: + """Tests for the ``solve(kron(A, B), b)`` vec-identity rewrite.""" + + @staticmethod + def _assert_no_kron(f): + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert not any(isinstance(op, KroneckerProduct) for op in ops), ( + f"KroneckerProduct still in graph: {[type(op).__name__ for op in ops]}" + ) + + @pytest.mark.parametrize( + "assume_a, expected_op, transposed", + [ + ("gen", Solve, False), + ("gen", Solve, True), + ("lower triangular", SolveTriangular, False), + ("pos", CholeskySolve, False), + ], + ids=["gen-plain", "gen-transposed", "lower-tri", "pos"], + ) + @pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") + def test_solve_of_kron(self, b_ndim, assume_a, expected_op, transposed): + rng = np.random.default_rng(0) + m, p, k = 4, 3, 5 + + # Symbolically impose the structure ``assume_a`` requires so finite-difference + # perturbations stay within the contract (mirrors test_cholesky's ``r.dot(r.T)`` + # pattern for verify_grad against assume_a-dependent rewrites). + def shape_component(A, B): + if assume_a == "lower triangular": + return pt.tril(A), pt.tril(B) + if assume_a == "pos": + A = A @ A.mT + 1e-3 * pt.eye(m, dtype=A.type.dtype) + B = B @ B.mT + 1e-3 * pt.eye(p, dtype=B.type.dtype) + return assume(A, positive_definite=True), assume( + B, positive_definite=True + ) + return A, B + + def shape_component_np(A_v, B_v): + if assume_a == "lower triangular": + return np.tril(A_v), np.tril(B_v) + if assume_a == "pos": + return A_v @ A_v.T + 1e-3 * np.eye(m), B_v @ B_v.T + 1e-3 * np.eye(p) + return A_v, B_v + + def build_out(A, B, b): + A, B = shape_component(A, B) + K = pt.linalg.kron(A, B) + if transposed: + K = K.T + return solve(K, b, assume_a=assume_a, b_ndim=b_ndim) + + A_pt = pt.matrix("A", shape=(m, m)) + B_pt = pt.matrix("B", shape=(p, p)) + if b_ndim == 1: + b_pt = pt.vector("b", shape=(m * p,)) + b_v = rng.normal(size=(m * p,)) + else: + b_pt = pt.matrix("b", shape=(m * p, k)) + b_v = rng.normal(size=(m * p, k)) + + out = build_out(A_pt, B_pt, b_pt) + f = function([A_pt, B_pt, b_pt], out, mode="FAST_RUN") + self._assert_no_kron(f) + + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert any(isinstance(op, expected_op) for op in ops) + if expected_op is not Solve: + # A specialized op must not be lifted back to a general Solve. + assert not any(isinstance(op, Solve) for op in ops) + + A_v = rng.normal(size=(m, m)) + B_v = rng.normal(size=(p, p)) + A_shaped, B_shaped = shape_component_np(A_v, B_v) + K_v = np.kron(A_shaped, B_shaped) + if transposed: + K_v = K_v.T + expected = np.linalg.solve(K_v, b_v) + assert_allclose(f(A_v, B_v, b_v), expected, atol=1e-9) + + utt.verify_grad(build_out, [A_v, B_v, b_v], rng=rng) + + def test_n_way_kron(self): + """``kron(kron(A, B), C)`` recursively decomposes to three component solves.""" + rng = np.random.default_rng(3) + m, p, q = 3, 2, 4 + A = pt.matrix("A", shape=(m, m)) + B = pt.matrix("B", shape=(p, p)) + C = pt.matrix("C", shape=(q, q)) + b = pt.vector("b", shape=(m * p * q,)) + + K = pt.linalg.kron(pt.linalg.kron(A, B), C) + out = solve(K, b) + f = function([A, B, C, b], out, mode="FAST_RUN") + self._assert_no_kron(f) + + A_v = rng.normal(size=(m, m)) + B_v = rng.normal(size=(p, p)) + C_v = rng.normal(size=(q, q)) + b_v = rng.normal(size=(m * p * q,)) + expected = np.linalg.solve(np.kron(np.kron(A_v, B_v), C_v), b_v) + assert_allclose(f(A_v, B_v, C_v, b_v), expected, atol=1e-9) + + def test_batched_b(self): + """``b`` carries a leading batch axis the component solves broadcast over.""" + rng = np.random.default_rng(4) + m, p, batch = 3, 4, 5 + A = pt.matrix("A", shape=(m, m)) + B = pt.matrix("B", shape=(p, p)) + b = pt.tensor("b", shape=(batch, m * p)) + + out = solve(pt.linalg.kron(A, B), b, b_ndim=1) + f = function([A, B, b], out, mode="FAST_RUN") + self._assert_no_kron(f) + + A_v = rng.normal(size=(m, m)) + B_v = rng.normal(size=(p, p)) + b_v = rng.normal(size=(batch, m * p)) + K_v = np.kron(A_v, B_v) + expected = np.stack([np.linalg.solve(K_v, b_v[i]) for i in range(batch)]) + assert_allclose(f(A_v, B_v, b_v), expected, atol=1e-9) + + def test_pos_assume_a_downgrades_to_gen(self): + """``Solve(assume_a='pos')`` of a kron must not propagate ``pos`` to components.""" + m, p = 3, 2 + A = pt.matrix("A", shape=(m, m)) + B = pt.matrix("B", shape=(p, p)) + b = pt.vector("b", shape=(m * p,)) + + out = solve(pt.linalg.kron(A, B), b, assume_a="pos") + f = function([A, B, b], out, mode="FAST_RUN") + self._assert_no_kron(f) + + # No component Solve should carry assume_a='pos' (would be unsound). + for n in f.maker.fgraph.toposort(): + core = getattr(n.op, "core_op", n.op) + if isinstance(core, Solve): + assert core.assume_a != "pos" + + +class TestSolveOfKronPlusDiagNoise: + """Tests for the noisy-Kron eigh-decomposition solve rewrite.""" + + @staticmethod + def _assert_no_kron_no_full_solve(f): + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert not any(isinstance(op, KroneckerProduct) for op in ops), ( + f"KroneckerProduct still in graph: {[type(op).__name__ for op in ops]}" + ) + # Full Solve / CholeskySolve on the N x N matrix would defeat the point. + assert not any(isinstance(op, Solve | CholeskySolve) for op in ops), ( + f"Full-N solve survived: {[type(op).__name__ for op in ops]}" + ) + + @staticmethod + def _random_pd(rng, n): + a = rng.normal(size=(n, n)) + return a @ a.T + np.eye(n) + + @pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") + def test_solve(self, b_ndim): + rng = np.random.default_rng(0) + m, p, k = 4, 3, 5 + + # Symbolically enforce PD on each component so verify_grad's FD + # perturbations stay inside the rewrite's contract. + def build_out(K1_raw, K2_raw, sigma, b): + K1 = K1_raw @ K1_raw.mT + pt.eye(m) + K2 = K2_raw @ K2_raw.mT + pt.eye(p) + K = pt.linalg.kron(K1, K2) + sigma**2 * pt.eye(m * p) + return solve(K, b, assume_a="pos", b_ndim=b_ndim) + + K1_pt = pt.matrix("K1", shape=(m, m)) + K2_pt = pt.matrix("K2", shape=(p, p)) + sigma_pt = pt.scalar("sigma") + if b_ndim == 1: + b_pt = pt.vector("b", shape=(m * p,)) + b_v = rng.normal(size=(m * p,)) + else: + b_pt = pt.matrix("b", shape=(m * p, k)) + b_v = rng.normal(size=(m * p, k)) + + out = build_out(K1_pt, K2_pt, sigma_pt, b_pt) + f = function([K1_pt, K2_pt, sigma_pt, b_pt], out, mode="FAST_RUN") + self._assert_no_kron_no_full_solve(f) + + K1_v = rng.normal(size=(m, m)) + K2_v = rng.normal(size=(p, p)) + sigma_v = 0.7 + K_v = np.kron( + K1_v @ K1_v.T + np.eye(m), K2_v @ K2_v.T + np.eye(p) + ) + sigma_v**2 * np.eye(m * p) + expected = np.linalg.solve(K_v, b_v) + assert_allclose(f(K1_v, K2_v, sigma_v, b_v), expected, atol=1e-9) + + utt.verify_grad(build_out, [K1_v, K2_v, sigma_v, b_v], rng=rng) + + def test_commutative_add(self): + """``sigma**2*I + kron`` matches the same as ``kron + sigma**2*I``.""" + m, p = 4, 3 + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + sigma = pt.scalar("sigma") + b = pt.vector("b", shape=(m * p,)) + + K = sigma**2 * pt.eye(m * p) + pt.linalg.kron(K1, K2) + out = solve(K, b, assume_a="pos") + f = function([K1, K2, sigma, b], out, mode="FAST_RUN") + self._assert_no_kron_no_full_solve(f) + + def test_n_way(self): + """``kron(K1, kron(K2, K3)) + sigma**2*I`` flattens recursively.""" + from pytensor.tensor.linalg.decomposition.eigen import Eigh + + rng = np.random.default_rng(2) + m, p, q = 3, 2, 4 + N = m * p * q + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + K3 = pt.matrix("K3", shape=(q, q)) + sigma = pt.scalar("sigma") + b = pt.vector("b", shape=(N,)) + + K = pt.linalg.kron(pt.linalg.kron(K1, K2), K3) + sigma**2 * pt.eye(N) + out = solve(K, b, assume_a="pos") + f = function([K1, K2, K3, sigma, b], out, mode="FAST_RUN") + self._assert_no_kron_no_full_solve(f) + + # One eigh per component. + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert sum(1 for op in ops if isinstance(op, Eigh)) == 3 + + K1_v = self._random_pd(rng, m) + K2_v = self._random_pd(rng, p) + K3_v = self._random_pd(rng, q) + sigma_v = 0.5 + b_v = rng.normal(size=(N,)) + K_v = np.kron(np.kron(K1_v, K2_v), K3_v) + sigma_v**2 * np.eye(N) + expected = np.linalg.solve(K_v, b_v) + assert_allclose(f(K1_v, K2_v, K3_v, sigma_v, b_v), expected, atol=1e-9) + + def test_non_uniform_noise_does_not_fire(self): + """Non-scalar noise must leave the kron in place.""" + m, p = 4, 3 + N = m * p + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + noise = pt.vector("noise", shape=(N,)) + b = pt.vector("b", shape=(N,)) + + K = pt.linalg.kron(K1, K2) + noise[:, None] * pt.eye(N) + out = solve(K, b, assume_a="pos") + f = function([K1, K2, noise, b], out, mode="FAST_RUN") + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert any(isinstance(op, KroneckerProduct) for op in ops) + + def test_negative_sigma_sq_still_correct(self): + """The rewrite is algebraic: a negative coefficient that keeps K PD works.""" + rng = np.random.default_rng(3) + m, p = 4, 3 + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + alpha = pt.scalar("alpha") + b = pt.vector("b", shape=(m * p,)) + + # K = kron + alpha * I, where alpha is negative but small enough to keep K PD. + K = pt.linalg.kron(K1, K2) + alpha * pt.eye(m * p) + out = solve(K, b, assume_a="pos") + f = function([K1, K2, alpha, b], out, mode="FAST_RUN") + self._assert_no_kron_no_full_solve(f) + + K1_v = self._random_pd(rng, m) + m * np.eye(m) + K2_v = self._random_pd(rng, p) + p * np.eye(p) + alpha_v = -0.1 + b_v = rng.normal(size=(m * p,)) + K_v = np.kron(K1_v, K2_v) + alpha_v * np.eye(m * p) + # Sanity check: K is still PD. + assert np.linalg.eigvalsh(K_v).min() > 0 + expected = np.linalg.solve(K_v, b_v) + assert_allclose(f(K1_v, K2_v, alpha_v, b_v), expected, atol=1e-9) + + class TestDiagonalSolveToDivision: @pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") @pytest.mark.parametrize( diff --git a/tests/tensor/rewriting/linalg/test_summary.py b/tests/tensor/rewriting/linalg/test_summary.py index 8f61da491c..14a9483f52 100644 --- a/tests/tensor/rewriting/linalg/test_summary.py +++ b/tests/tensor/rewriting/linalg/test_summary.py @@ -9,6 +9,7 @@ from pytensor.graph import rewrite_graph from pytensor.tensor.linalg.decomposition import lu, qr, svd from pytensor.tensor.linalg.decomposition.cholesky import cholesky +from pytensor.tensor.linalg.products import KroneckerProduct from pytensor.tensor.linalg.summary import Det, SLogDet, det from pytensor.tensor.type import matrix from tests.unittest_tools import assert_equal_computations @@ -459,3 +460,68 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn): expected = expected_fn(x) rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) assert_equal_computations([rewritten], [expected]) + + +class TestDetOfKronPlusDiagNoise: + """Tests for the noisy-Kron eigh-decomposition det rewrite.""" + + @staticmethod + def _assert_no_kron_no_full_det(f): + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert not any(isinstance(op, KroneckerProduct) for op in ops) + assert not any(isinstance(op, Det | SLogDet) for op in ops) + + @staticmethod + def _random_pd(rng, n): + a = rng.normal(size=(n, n)) + return a @ a.T + np.eye(n) + + def test_slogdet_logp(self): + """``slogdet`` chains through to ``sum(log|d|)`` (the logp use case).""" + m, p = 4, 3 + N = m * p + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + sigma = pt.scalar("sigma") + K = pt.linalg.kron(K1, K2) + sigma**2 * pt.eye(N) + _, log_det = pt.linalg.slogdet(K) + f = function([K1, K2, sigma], log_det, mode="FAST_RUN") + self._assert_no_kron_no_full_det(f) + + rng = np.random.default_rng(0) + K1_v = self._random_pd(rng, m) + K2_v = self._random_pd(rng, p) + sigma_v = 0.7 + K_v = np.kron(K1_v, K2_v) + sigma_v**2 * np.eye(N) + expected = np.linalg.slogdet(K_v)[1] + got = f(K1_v, K2_v, sigma_v) + assert_allclose( + got, + expected, + atol=1e-3 if config.floatX == "float32" else 1e-9, + rtol=1e-3 if config.floatX == "float32" else 1e-9, + ) + + def test_n_way(self): + m, p, q = 3, 2, 4 + N = m * p * q + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + K3 = pt.matrix("K3", shape=(q, q)) + sigma = pt.scalar("sigma") + K = pt.linalg.kron(pt.linalg.kron(K1, K2), K3) + sigma**2 * pt.eye(N) + _, log_det = pt.linalg.slogdet(K) + f = function([K1, K2, K3, sigma], log_det, mode="FAST_RUN") + self._assert_no_kron_no_full_det(f) + + def test_non_uniform_noise_does_not_fire(self): + m, p = 4, 3 + N = m * p + K1 = pt.matrix("K1", shape=(m, m)) + K2 = pt.matrix("K2", shape=(p, p)) + noise = pt.vector("noise", shape=(N,)) + K = pt.linalg.kron(K1, K2) + noise[:, None] * pt.eye(N) + log_det = pt.linalg.slogdet(K)[1] + f = function([K1, K2, noise], log_det, mode="FAST_RUN") + ops = [getattr(n.op, "core_op", n.op) for n in f.maker.fgraph.toposort()] + assert any(isinstance(op, KroneckerProduct) for op in ops)