diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..b4ad71a85 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,6 @@ # Releases + ## 0.9.7.dev0 This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. @@ -12,8 +13,10 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) - Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Add cost functions between linear operators following + [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920) (PR #792) #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 0f7025ec1..2452dafca 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -34,6 +34,7 @@ API and modules plot regpath sliced + sgot smooth stochastic unbalanced diff --git a/examples/plot_sgot.py b/examples/plot_sgot.py new file mode 100644 index 000000000..609279941 --- /dev/null +++ b/examples/plot_sgot.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +========================================= +SGOT example for a rotated linear system +========================================= + +This notebook presents a synthetic example of Spectral Grassmannian-Wasserstein +Optimal Transport (SGOT) on linear dynamical systems.
+ +We consider a signal formed by the sum of two damped oscillatory modes evolving +along a rotated direction in the plane. The signal is then associated with an +underlying continuous linear dynamical system, and we study how its spectral +representation varies under rotation. The SGOT cost and metric are used to +compare the reference and rotated systems. + +[1] T. Germain; R. Flamary; V. R. Kostic; K. Lounici, A Spectral-Grassmann Wasserstein Metric for Operator Representations of Dynamical Systems, arXiv preprint arXiv:2509.24920, 2025. + +""" + +import numpy as np +import matplotlib.pyplot as plt + +from ot.sgot import sgot_metric, sgot_cost_matrix + +from scipy.linalg import eig + + +# sampling parameters and time grid +fs = 50 +max_t = 5 +time = np.linspace(0, max_t, fs * max_t) +dt = 1 / fs + + +# %% +# Example: rotating a linear dynamical system in 3D +# ------------------------------------------------- +# +# 1. Build a simple observed signal +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We begin by assuming that the observed signal is made of two oscillatory +# components: +# +# .. math:: +# +# x(t)=e^{-\tau_1 t}\cos(2\pi\omega_1 t)\,\vec e(\theta) +# \;+\; +# e^{-\tau_2 t}\cos(2\pi\omega_2 t)\,\vec e(\theta), +# +# where :math:`\vec e(\theta)\in\mathbb{R}^2` is a fixed real vector. Thus, +# :math:`x(t)` evolves along the one-dimensional subspace spanned by +# :math:`\vec e(\theta)`, while its time dependence exhibits oscillatory and +# dissipative behaviour. 
+ +tau_0 = np.array([0.08, 0.18]) +freq_0 = np.array([1.0, 2.0]) +theta_0 = np.pi / 4 + + +def generate_data(time, tau, freq, theta): + t_ = np.sin(2 * np.pi * freq[None, :] * time[:, None]) * np.exp( + -tau[None, :] * time[:, None] + ) + t_ = t_.sum(axis=1) + traj_0 = np.zeros((t_.shape[0], 2)) + traj_0[:, 0] = t_ + rotation_matrix = np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ) + traj_0 = traj_0 @ rotation_matrix.T + return traj_0 + + +traj_0 = generate_data(time, tau_0, freq_0, theta_0) + + +# plot the observed signal components and their sum +plt.figure(figsize=(10, 4)) +plt.plot(time, traj_0, label="base trajectory", linewidth=2) +plt.xlabel("time") +plt.ylabel("amplitude") +plt.legend() +plt.title(r"Observed scalar signal along $\vec{e}(\theta)$") +plt.show() + + +# %% +# 2. Interpret the signal as coming from a continuous linear dynamical system +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We assume that :math:`x(t)` is generated by an underlying continuous linear +# dynamical system. Since the observed signal is a superposition of two +# sinusoidal modes, the corresponding linear dynamics are naturally described +# by a fourth-order model. We therefore introduce the state vector +# +# .. math:: +# +# z(t)= +# \begin{pmatrix} +# x_1(t)\\ +# x_2(t)\\ +# \vdots\\ +# x_1^{(3)}(t)\\ +# x_2^{(3)}(t) +# \end{pmatrix} +# \in\mathbb{R}^8. +# +# This allows us to rewrite the dynamics as a first-order linear system: +# +# .. math:: +# +# \dot{z}(t)=Az(t), +# +# where :math:`A\in\mathbb{R}^{8\times 8}`. Its solution is then given by +# +# .. math:: +# +# z(t)=e^{tA}z_0. + +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(projection="3d") + +ax.plot(time, traj_0[:, 0], traj_0[:, 1]) +ax.set_xlabel("time") +ax.set_ylabel("x₁(t)") +ax.set_title("Observed trajectory in time") + +ax.text2D(1.08, 0.5, "x₂(t)", transform=ax.transAxes, rotation=90, va="center") + +plt.show() + + +# %% +# 3. 
Sampling and preprocessing discrete trajectories of the dynamical system +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We now have a bridge between the continuous system and the operator we later +# aim to infer from sampled data. Since in practice we do not observe the full +# continuous trajectory, we work instead with discrete samples of the signal. +# We take snapshots at uniform time intervals :math:`\Delta t`, and write the +# sampled signal as +# +# .. math:: +# +# S= +# \begin{pmatrix} +# x_1(0) & x_2(0)\\ +# x_1(\Delta t) & x_2(\Delta t)\\ +# \vdots\\ +# x_1(N\Delta t) & x_2(N\Delta t)\\ +# \end{pmatrix} +# +# The goal is now to use these observations to recover the operator governing +# the evolution. To do this, we augment the signal :math:`s` using a sliding +# window of length :math:`w`. For each :math:`k`, define +# +# .. math:: +# +# z_k = +# \begin{pmatrix} +# s_k\\ +# s_{k+1}\\ +# \vdots\\ +# s_{k+w-1} +# \end{pmatrix} +# +# We then form the data matrices +# +# .. math:: +# +# X= +# \begin{pmatrix} +# z_1\\ +# z_2\\ +# \vdots\\ +# z_{N-w} +# \end{pmatrix}, +# \qquad +# Y= +# \begin{pmatrix} +# z_2\\ +# z_3\\ +# \vdots\\ +# z_{N-w+1} +# \end{pmatrix}, +# +# so that :math:`X` contains the present windowed states and :math:`Y` the +# corresponding shifted future states. + + +# build a 4-dimensional state using delay embedding +def augment(traj, window_length=2): + Z = np.lib.stride_tricks.sliding_window_view(traj, (window_length, 1)) + Z = Z.reshape(Z.shape[0], -1) + return Z + + +# create the embedded state matrix Z +Z = augment(traj_0[:, [0]], 4) +Z.shape + +# inspect one embedded state vector +Z[0] + +# create X and Y for the SGOT metric +X = Z[:-1] +Y = Z[1:] + +# inspect shapes of X and Y +print("X shape:", X.shape) +print("Y shape:", Y.shape) + + +# %% +# 4. Estimate the discrete-time operator +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We now identify the operator that maps :math:`X` to :math:`Y`. 
From +# +# .. math:: +# +# \dot{z}=Az, +# +# we have +# +# .. math:: +# +# z(t+\Delta t)=e^{\Delta tA}z(t). +# +# Setting +# +# .. math:: +# +# B=e^{\Delta tA}, +# +# the corresponding discrete-time evolution is governed by :math:`B`, and we +# seek the best linear map satisfying +# +# .. math:: +# +# Y\approx X B^T. +# +# Equivalently, we solve the optimisation problem +# +# .. math:: +# +# \min_B \|Y-XB\|^2. +# +# We want to recover the best rank-:math:`r` operator, whose estimator is +# defined as follows: +# +# .. math:: +# +# B = C_{xx}^{-\frac{1}{2}}[C_{xx}^{-\frac{1}{2}}C_{xy}]_r +# \quad \text{s.t} \quad C_{xx} = X^T X \quad \text{and} \quad C_{xy} = X^TY. +# +# Here :math:`[\cdot]_r` denotes the best rank-:math:`r` estimator obtained via +# SVD decomposition. [2] +# +# [2] Kostic, V., Novelli, P., Maurer, A., Ciliberto, C., Rosasco, L. and +# Pontil, M., 2022. Learning dynamical systems via Koopman operator regression +# in reproducing kernel Hilbert spaces. Advances in Neural Information +# Processing Systems, 35, pp.4017-4031. 
+ + +def estimator(X, Y, rank=4): + # X: (n_samples, n_features) + # Y: (n_samples, n_features) + + # estimate operator + cxx = X.T @ X + U, S, Vt = np.linalg.svd(cxx) + S_inv = np.divide(1, S, out=np.zeros_like(S), where=S != 0) + cxx_inv_half = Vt.T @ np.diag(np.sqrt(S_inv)) @ U.T + cxy = X.T @ Y + T = cxx_inv_half @ cxy + U, S, Vt = np.linalg.svd(T) + S[rank:] = 0 + T_rank = U @ np.diag(S) @ Vt + T = cxx_inv_half @ T_rank + + # estimate spectral decomposition + val, vl, vr = eig(T, left=True, right=True) + sort_idx = np.argsort(np.abs(val))[::-1] + val = val[sort_idx][:rank] + vl = vl[:, sort_idx][:, :rank] + vr = vr[:, sort_idx][:, :rank] + + return T, {"eig_val": val, "eig_vec_left": vl, "eig_vec_right": vr} + + +B_0, B_0_spec = estimator(X, Y, rank=4) +Y_pred = X @ B_0 + +plt.figure(figsize=(10, 4)) +plt.plot(Y[:, 0], label="true") +plt.plot(Y_pred[:, 0], "--", label="predicted") +plt.xlabel("sample index") +plt.ylabel("first state coordinate") +plt.title("True Signal vs Predicted Signal") +plt.legend() +plt.show() + + +# %% +# The predicted signal is nearly indistinguishable from the true signal, +# indicating that the estimated operator accurately captures the observed +# dynamics. + +# %% +# 6. Recover continuous-time spectral information from the discrete operator +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To recover the continuous generator :math:`A`, we study the spectral +# structure of :math:`B`. We diagonalise :math:`B` as +# +# .. math:: +# +# B=PDP^{-1}, +# +# where +# +# .. math:: +# +# D=\operatorname{diag}(\mu_1,\dots,\mu_n). +# +# The continuous-time eigenvalues are of the form +# +# .. math:: +# +# \lambda_k=-\tau_k+2\pi i\,\omega_k, +# \qquad k\in\{1,2\}, +# +# and the corresponding eigenvalues of :math:`B` are +# +# .. math:: +# +# \mu_k=e^{\Delta t\lambda_k} +# =e^{\Delta t(-\tau_k+2\pi i\omega_k)}. +# +# Since :math:`B=e^{\Delta tA}`, we recover :math:`A` by taking the logarithm: +# +# .. 
math:: +# +# A=P\,\frac{\log(D)}{\Delta t}\,P^{-1}. + +D_0 = np.log(B_0_spec["eig_val"]) * fs +L_0 = B_0_spec["eig_vec_left"] +R_0 = B_0_spec["eig_vec_right"] + +recovered_freqs = D_0.imag / (2 * np.pi) +mask = recovered_freqs > 0 +recovered_freqs = recovered_freqs[mask] +decay = -D_0.real[mask] +print(f"First mode: frequency: {recovered_freqs[0]:.2f} Hz -- decay: {decay[0]:.2f}") +print(f"Second mode: frequency: {recovered_freqs[1]:.2f} Hz -- decay: {decay[1]:.2f}") + + +# %% +# Applying a rotation in the notebook +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The rotation is introduced through the parameter `theta`. In the +# data-generation step, the trajectory is rotated in the observation plane by +# the 2D rotation matrix +# +# .. math:: +# +# R(\theta)= +# \begin{pmatrix} +# \cos\theta & -\sin\theta \\ +# \sin\theta & \cos\theta +# \end{pmatrix}, +# +# via `traj_0 = traj_0 @ R(theta).T`. +# +# At the operator level, the same transformation is represented by conjugation +# of the reference operator, +# +# .. math:: +# +# A_{\mathrm{rot}} = P(\theta)\,A_{\mathrm{ref}}\,P(\theta)^\top, +# +# where :math:`P(\theta)` is the block rotation acting on both state +# coordinates and their derivatives. + +# [X_1,X_2,X_1_,X_2_] +A_ref = np.array([[0, 0, 1, 0], [0, 0, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 0]]) + +b = 0.1 +c = 1 +A_shift = np.array([[0, 0, 1, 0], [0, 0, 0, 0], [-c, 0, -b, 0], [0, 0, 0, 0]]) + + +def rotation_matrix(theta): + c, s = np.cos(theta), np.sin(theta) + return np.array([[c, -s, 0, 0], [s, c, 0, 0], [0, 0, c, -s], [0, 0, s, c]]) + + +P = rotation_matrix(np.pi / 4) +A_rot = P @ A_ref @ P.T + +A_ref_decomp = np.linalg.eig(A_ref) +A_rot_decomp = np.linalg.eig(A_rot) + + +# %% +# Introduction to SGOT for linear operators +# ----------------------------------------- +# +# To compare two linear operators through their spectral structure, we use the +# SGOT framework introduced in Theorem 1 of [1]. 
For a non-defective +# finite-rank operator :math:`T \in S_r(\mathcal H)`, the theorem associates a +# discrete spectral measure +# +# .. math:: +# +# \mu(T) \triangleq \sum_{j\in[\ell]} +# \frac{m_j}{m_{\mathrm{tot}}}\,\delta_{(\lambda_j,\mathcal V_j)}, +# +# where :math:`\lambda_j` are the eigenvalues of :math:`T`, :math:`m_j` their +# algebraic multiplicities, and :math:`\mathcal V_j` the corresponding +# eigenspaces. Thus, each spectral component of the operator is represented by +# an atom of the form +# +# .. math:: +# +# (\lambda_j,\mathcal V_j), +# +# combining one eigenvalue with its associated invariant subspace. +# +# Theorem 1 then defines a ground cost between two such atoms by combining a +# spectral discrepancy and a geometric discrepancy: +# +# .. math:: +# +# d_\eta\big((\lambda,\mathcal V),(\lambda',\mathcal V')\big) +# \triangleq +# \eta\,|\lambda-\lambda'| + (1-\eta)\, d_{\mathcal G}(\mathcal V,\mathcal V'), +# +# where :math:`d_{\mathcal G}` denotes the Grassmann distance between +# eigenspaces and :math:`\eta\in(0,1)` balances the contribution of eigenvalues +# and eigenspaces. +# +# The SGOT distance between two operators :math:`T` and :math:`T'` is then the +# Wasserstein distance between their associated spectral measures: +# +# .. math:: +# +# d_S(T,T') = W_{d_\eta,p}\big(\mu(T),\mu(T')\big). +# +# In this way, SGOT compares linear operators by optimally matching their +# spectral atoms, taking into account both the location of eigenvalues and the +# relative geometry of their eigenspaces. 
+ +# %% +# SGOT distance versus rotation angle +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +thetas = np.linspace(0, np.pi, 10) +lst = [] +for i, theta in enumerate(thetas): + traj = generate_data(time, tau_0, freq_0, theta) + Z = augment(traj[:, [0]], 4) + X = Z[:-1] + Y = Z[1:] + B, B_spec = estimator(X, Y, rank=4) + D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] + D = np.log(D) * fs + lst.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.01)) + +plt.figure(figsize=(8, 5)) +plt.plot(thetas, lst) +plt.xlabel("theta") +plt.ylabel("SGOT distance") +plt.title("SGOT distance vs rotation angle") +plt.show() + +# %% +# Comparison across Grassmann metrics +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +thetas = np.linspace(0, np.pi / 2, 10) +lst = [] +for i, theta in enumerate(thetas): + traj = generate_data(time, tau_0, freq_0, theta) + Z = augment(traj[:, [0]], 4) + X = Z[:-1] + Y = Z[1:] + B, B_spec = estimator(X, Y, rank=4) + D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] + D = np.log(D) * fs + lst1 = [] + for name in ["chordal", "martin", "geodesic", "procrustes"]: + lst1.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.9, grassman_metric=name)) + lst.append(lst1) +lst2 = np.array(lst) +plt.figure(figsize=(8, 5)) +for i, name in enumerate(["chordal", "martin", "geodesic", "procrustes"]): + plt.plot(thetas, lst2[:, i], label=name) + +plt.xlabel("theta") +plt.ylabel("SGOT distance") +plt.title("SGOT distance vs rotation angle") +plt.legend() +plt.show() + +# %% +# SGOT distance versus eta +# ~~~~~~~~~~~~~~~~~~~~~~~~ +etas = np.linspace(0.0, 1.0, 21) +methods = ["chordal", "martin", "geodesic", "procrustes"] +scores_eta = [] +theta = theta_0 + +for eta in etas: + freq_1 = np.array([freq_0[0], recovered_freqs[1]]) + traj = generate_data(time, tau_0, freq_1, theta) + Z = augment(traj[:, [0]], 4) + X = Z[:-1] + Y = Z[1:] + + B, B_spec = estimator(X, Y, rank=4) + D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], 
B_spec["eig_vec_left"] + D = np.log(D) * fs + + row = [] + for name in methods: + row.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=eta, grassman_metric=name)) + scores_eta.append(row) + +scores_eta = np.array(scores_eta) +plt.figure(figsize=(8, 5)) +for i, name in enumerate(methods): + plt.plot(etas, scores_eta[:, i], label=name) + +plt.xlabel("eta") +plt.ylabel("SGOT distance") +plt.title("SGOT distance vs eta") +plt.legend() +plt.show() + +# %% +# SGOT distance versus decay +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +decays = np.linspace(0.1, 3.0, 20) # adjust range as needed +methods = ["chordal", "martin", "geodesic", "procrustes"] +scores_decay = [] +theta = theta_0 + +for tau in decays: + freq_1 = np.array([freq_0[0], recovered_freqs[1]]) + tau_1 = np.array([tau, tau]) # or whatever structure your generator expects + + traj = generate_data(time, tau_1, freq_1, theta) + Z = augment(traj[:, [0]], 4) + X = Z[:-1] + Y = Z[1:] + + B, B_spec = estimator(X, Y, rank=4) + D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] + D = np.log(D) * fs + + row = [] + for name in methods: + row.append( + sgot_metric( + D_0, + R_0, + L_0, + D, + R, + L, + eta=0.9, # keep eta fixed here + grassman_metric=name, + ) + ) + scores_decay.append(row) + +scores_decay = np.array(scores_decay) +plt.figure(figsize=(8, 5)) +for i, name in enumerate(methods): + plt.plot(decays, scores_decay[:, i], label=name) + +plt.xlabel("decay") +plt.ylabel("SGOT distance") +plt.title("SGOT distance vs decay") +plt.legend() +plt.show() diff --git a/ot/backend.py b/ot/backend.py index d7fed4e2f..9c6a25150 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -622,6 +622,46 @@ def clip(self, a, a_min=None, a_max=None): """ raise NotImplementedError() + def real(self, a): + """ + Return the real part of the tensor element-wise. 
+ + This function follows the api from :any:`numpy.real` + + See: https://numpy.org/doc/stable/reference/generated/numpy.real.html + """ + raise NotImplementedError() + + def imag(self, a): + """ + Return the imaginary part of the tensor element-wise. + + This function follows the api from :any:`numpy.imag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html + """ + raise NotImplementedError() + + def conj(self, a): + """ + Return the complex conjugate, element-wise. + + This function follows the api from :any:`numpy.conj` + + See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html + """ + raise NotImplementedError() + + def arccos(self, a): + """ + Trigonometric inverse cosine, element-wise. + + This function follows the api from :any:`numpy.arccos` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html + """ + raise NotImplementedError() + def repeat(self, a, repeats, axis=None): r""" Repeats elements of a tensor. @@ -1193,7 +1233,7 @@ def _from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return a.astype(type_as.dtype) + return np.asarray(a, dtype=type_as.dtype) def set_gradients(self, val, inputs, grads): # No gradients for numpy @@ -1313,6 +1353,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return np.clip(a, a_min, a_max) + def real(self, a): + return np.real(a) + + def imag(self, a): + return np.imag(a) + + def conj(self, a): + return np.conj(a) + + def arccos(self, a): + return np.arccos(a) + def repeat(self, a, repeats, axis=None): return np.repeat(a, repeats, axis) @@ -1604,7 +1656,7 @@ def _from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) + return self._change_device(jnp.asarray(a, dtype=type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1730,6 +1782,18 @@ def outer(self, a, 
b): def clip(self, a, a_min=None, a_max=None): return jnp.clip(a, a_min, a_max) + def real(self, a): + return jnp.real(a) + + def imag(self, a): + return jnp.imag(a) + + def conj(self, a): + return jnp.conj(a) + + def arccos(self, a): + return jnp.arccos(a) + def repeat(self, a, repeats, axis=None): return jnp.repeat(a, repeats, axis) @@ -1803,7 +1867,9 @@ def randperm(self, size, type_as=None): if not isinstance(size, int): raise ValueError("size must be an integer") if type_as is not None: - return jax.random.permutation(subkey, size).astype(type_as.dtype) + return jnp.asarray( + jax.random.permutation(subkey, size), dtype=type_as.dtype + ) else: return jax.random.permutation(subkey, size) @@ -2056,7 +2122,7 @@ def backward(ctx, g): def _to_numpy(self, a): if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray): return np.array(a) - return a.cpu().detach().numpy() + return a.cpu().detach().resolve_conj().numpy() def _from_numpy(self, a, type_as=None): if ( @@ -2227,6 +2293,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return torch.clamp(a, a_min, a_max) + def real(self, a): + return torch.real(a) + + def imag(self, a): + return torch.imag(a) + + def conj(self, a): + return torch.conj(a) + + def arccos(self, a): + return torch.acos(a) + def repeat(self, a, repeats, axis=None): return torch.repeat_interleave(a, repeats, dim=axis) @@ -2728,6 +2806,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return cp.clip(a, a_min, a_max) + def real(self, a): + return cp.real(a) + + def imag(self, a): + return cp.imag(a) + + def conj(self, a): + return cp.conj(a) + + def arccos(self, a): + return cp.arccos(a) + def repeat(self, a, repeats, axis=None): return cp.repeat(a, repeats, axis) @@ -2819,7 +2909,7 @@ def randperm(self, size, type_as=None): return self.rng_.permutation(size) else: with cp.cuda.Device(type_as.device): - return self.rng_.permutation(size).astype(type_as.dtype) + return 
cp.asarray(self.rng_.permutation(size), dtype=type_as.dtype) def coo_matrix(self, data, rows, cols, shape=None, type_as=None): data = self.from_numpy(data) @@ -3162,6 +3252,18 @@ def outer(self, a, b): def clip(self, a, a_min=None, a_max=None): return tnp.clip(a, a_min, a_max) + def real(self, a): + return tnp.real(a) + + def imag(self, a): + return tnp.imag(a) + + def conj(self, a): + return tnp.conj(a) + + def arccos(self, a): + return tnp.arccos(a) + def repeat(self, a, repeats, axis=None): return tnp.repeat(a, repeats, axis) diff --git a/ot/sgot.py b/ot/sgot.py new file mode 100644 index 000000000..5248432be --- /dev/null +++ b/ot/sgot.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- +""" +Spectral-Grassmann optimal transport for linear operators. + +This module implements the Spectral-Grassmann Wasserstein framework for +comparing dynamical systems via their learned operator representations. + +It provides tools to extract spectral "atoms" (eigenvalues and associated +eigenspaces) from linear operators and to compute an optimal transport metric +that combines a spectral term on eigenvalues with a Grassmannian term on +eigenspaces. 
+""" + +# Author: Sienna O'Shea +# Thibaut Germain +# License: MIT License + +import ot +from ot.backend import get_backend + +##################################################################################################################################### +##################################################################################################################################### +### NORMALISATION AND OPERATOR ATOMS ### +##################################################################################################################################### +##################################################################################################################################### + + +def eigenvalue_cost_matrix(Ds, Dt, q=1, eigen_scaling=None, nx=None): + """Compute pairwise eigenvalue distances for source and target domains. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Source eigenvalues. + Dt: array-like, shape (n_t,) + Target eigenvalues. + eigen_scaling: None or array-like of length 2, optional + Scaling (real_scale, imag_scale) applied to eigenvalues before computing + distances. If None, defaults to (1.0, 1.0). Accepts tuple/list or + array/tensor with two entries. + + Returns + ---------- + C: np.ndarray, shape (n_s, n_t) + Eigenvalue cost matrix. 
+ """ + if nx is None: + nx = get_backend(Ds, Dt) + + if eigen_scaling is None: + real_scale, imag_scale = 1.0, 1.0 + else: + if isinstance(eigen_scaling, (tuple, list)): + real_scale, imag_scale = eigen_scaling + else: + real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] + + Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale + Dtn = nx.real(Dt) * real_scale + 1j * nx.imag(Dt) * imag_scale + C_real = nx.real(Dsn[:, None] - nx.real(Dtn)[None, :]) + C_real = C_real**2 + C_imag = nx.imag(Dsn)[:, None] - nx.imag(Dtn)[None, :] + C_imag = C_imag**2 + prod = C_real + C_imag + return prod ** (q / 2) + + +def _normalize_columns(A, nx, eps=1e-12): + """Normalize the columns of an array with a backend-aware norm. + + Parameters + ---------- + A: array-like, shape (d, n) + Input array whose columns are normalized. + nx: module + Backend (NumPy-compatible) used for math operations. + eps: float, optional + Minimum norm value to avoid division by zero, default 1e-12. + + Returns + ---------- + A_norm: array-like, shape (d, n) + Column-normalized array. + """ + nrm = nx.norm(A, axis=0, keepdims=True) + nrm = nx.real(nrm) # norm is real; avoid complex dtype for maximum (e.g. torch) + nrm = nx.maximum(nrm, eps) + return A / nrm + + +def _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): + """Compute the normalized inner-product delta matrix for eigenspaces. + + Parameters + ---------- + Rs: array-like, shape (L, n_s) + Source right eigenvectors. + Ls: array-like, shape (L, n_s) + Source left eigenvectors. + Rt: array-like, shape (L, n_t) + Target right eigenvectors. + Lt: array-like, shape (L, n_t) + Target left eigenvectors. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum norm value used in normalization, default 1e-12. + + Returns + ---------- + delta: array-like, shape (n_s, n_t) + Delta matrix with entries in [0, 1]. 
+ """ + if nx is None: + nx = get_backend(Rs, Ls, Rt, Lt) + + Rsn = _normalize_columns(Rs, nx=nx, eps=eps) + Lsn = _normalize_columns(Ls, nx=nx, eps=eps) + Rtn = _normalize_columns(Rt, nx=nx, eps=eps) + Ltn = _normalize_columns(Lt, nx=nx, eps=eps) + + Cr = nx.dot(nx.conj(Rsn).T, Rtn) + Cl = nx.dot(nx.conj(Lsn).T, Ltn) + + delta = nx.abs(Cr * Cl) + delta = nx.clip(delta, 0.0, 1.0) + return delta + + +def _grassmann_distance_squared( + delta, grassman_metric="chordal", q=1, nx=None, eps=1e-12 +): + """Compute Grassmannian distances from delta similarities. + + Parameters + ---------- + delta: array-like + Similarity values in [0, 1]. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + q: int or float, optional + Exponent applied to the Grassmann distance, in the same spirit as the + eigenvalue cost exponent. Default is 1. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum value used for numerical stability in the Martin metric. + + Returns + ------- + dist_q: array-like + Grassmannian distances raised to the power q. 
+ """ + if nx is None: + nx = get_backend(delta) + + if nx.any(delta < 0) or nx.any(delta > 1.0): + raise ValueError( + "delta must be in [0, 1]; found values outside this range " + f"(min={nx.min(delta)}, max={nx.max(delta)})" + ) + + delta = nx.clip(delta, 0.0, 1.0) + + if grassman_metric == "geodesic": + dist2 = nx.arccos(delta) ** 2 + elif grassman_metric == "chordal": + dist2 = 1.0 - delta**2 + elif grassman_metric == "procrustes": + dist2 = 2.0 * (1.0 - delta) + elif grassman_metric == "martin": + delta2 = nx.maximum(delta**2, eps) + dist2 = -nx.log(delta2) + else: + raise ValueError(f"Unknown grassman_metric: {grassman_metric}") + + return nx.real(dist2) ** (q / 2.0) + + +##################################################################################################################################### +##################################################################################################################################### +### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### +##################################################################################################################################### +##################################################################################################################################### +def sgot_cost_matrix( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=0.5, + p=2, + q=1, + grassman_metric="chordal", + eigen_scaling=None, + nx=None, + eps=1e-12, +): + r"""Compute the SGOT cost matrix between two spectral decompositions. + + This returns the discrete ground cost matrix used in the SGOT optimal transport + objective. Each spectral atom is :math:`z_i=(\lambda_i, V_i)` where + :math:`\lambda_i \in \mathbb{C}` is an eigenvalue and :math:`V_i` is the + associated (bi-orthogonal) eigenspace point. + + .. math:: + C_2(i,j) \;=\; \eta\,C_\lambda(i,j) \;+\; (1-\eta)\,C_G(i,j), + + with spectral term + + .. 
math:: + C_\lambda(i,j) \;=\; \big|\lambda_i - \lambda'_j\big|^{q}, + + and Grassmann term computed from a similarity score :math:`\delta_{ij}\in[0,1]` + built from left/right eigenvectors + + .. math:: + \delta_{ij} \;=\; \left|\langle r_i, r'_j\rangle\,\langle \ell_i, \ell'_j\rangle\right|. + + Depending on ``grassman_metric``, the Grassmann contribution is: + + - ``"chordal"``: + .. math:: + C_G(i,j) \;=\; 1 - \delta_{ij}^2 + - ``"geodesic"``: + .. math:: + C_G(i,j) \;=\; \arccos(\delta_{ij})^2 + - ``"procrustes"``: + .. math:: + C_G(i,j) \;=\; 2(1-\delta_{ij}) + - ``"martin"``: + .. math:: + C_G(i,j) \;=\; -\log\!\left(\max(\delta_{ij}^2,\varepsilon)\right) + + Finally, we return a matrix suited for a :math:`p`-Wasserstein objective by + treating :math:`C_2 \approx d^2` and outputting + + .. math:: + C(i,j) \;=\; \big(\operatorname{Re}(C_2(i,j))\big)^{p/2}. + + Parameters + ---------- + Ds: array-like, shape (n_s,) + Eigenvalues of operator T1. + Rs: array-like, shape (L, n_s) + Right eigenvectors of operator T1. + Ls: array-like, shape (L, n_s) + Left eigenvectors of operator T1. + Dt: array-like, shape (n_t,) + Eigenvalues of operator T2. + Rt: array-like, shape (L, n_t) + Right eigenvectors of operator T2. + Lt: array-like, shape (L, n_t) + Left eigenvectors of operator T2. + eta: float, optional + Weighting between spectral and Grassmann terms, default 0.5. + p: int, optional + Exponent defining the OT ground cost. The returned cost is :math:`d^p` with + :math:`d^2 \approx C_2`. Default is 2. + q: int, optional + Exponent applied to the eigenvalue distance in the spectral term. + Default is 1. + grassman_metric: str, optional + Metric type: "geodesic", "chordal", "procrustes", or "martin". + eigen_scaling: None or array-like of length 2, optional + Scaling ``(real_scale, imag_scale)`` applied to eigenvalues before computing + :math:`C_\lambda`. 
If provided, eigenvalues are transformed as + :math:`\lambda \mapsto \alpha\operatorname{Re}(\lambda) + i\,\beta\operatorname{Im}(\lambda)`. + If None, defaults to ``(1.0, 1.0)``. Accepts tuple/list or array/tensor with + two entries. + nx: module, optional + Backend (NumPy-compatible). If None, inferred from inputs. + eps: float, optional + Minimum value used for numerical stability in Grassmann distances and + Martin metric. Default is 1e-12. + + Returns + ---------- + C: array-like, shape (n_s, n_t) + SGOT cost matrix :math:`C = d^p`. + + References + ---------- + Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). + """ + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + + _validate_sgot_inputs(Ds, Rs, Ls, Dt, Rt, Lt) + + C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) + delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) + C_grass = _grassmann_distance_squared( + delta, + grassman_metric=grassman_metric, + q=q, + nx=nx, + eps=eps, + ) + + C2 = eta * C_lambda + (1.0 - eta) * C_grass + C = C2 ** (p / 2.0) + + return C + + +def _validate_sgot_inputs(Ds, Rs, Ls, Dt, Rt, Lt): + """Validate shapes of spectral atoms for SGOT cost/metric.""" + Ds_shape = getattr(Ds, "shape", None) + Dt_shape = getattr(Dt, "shape", None) + Ds_ndim = getattr(Ds, "ndim", None) + Dt_ndim = getattr(Dt, "ndim", None) + + if Ds_ndim != 1 or Dt_ndim != 1: + raise ValueError( + "SGOT expects Ds and Dt to be 1D (n,), " + f"got Ds shape {Ds_shape} and Dt shape {Dt_shape}" + ) + + if Rs.shape != Ls.shape or Rt.shape != Lt.shape: + raise ValueError( + "Right/left eigenvector shapes must match; got " + f"(Rs,Ls)=({Rs.shape},{Ls.shape}), (Rt,Lt)=({Rt.shape},{Lt.shape})" + ) + + if Rs.shape[1] != Ds.shape[0] or Rt.shape[1] != Dt.shape[0]: + raise ValueError( + "Eigenvector columns must match eigenvalues: " + f"Rs {Rs.shape[1]} vs Ds {Ds.shape[0]}, " + f"Rt {Rt.shape[1]} vs Dt {Dt.shape[0]}" + ) + + +def sgot_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + 
Lt, + eta=0.5, + p=2, + q=1, + r=2, + grassman_metric="chordal", + eigen_scaling=None, + Ws=None, + Wt=None, + nx=None, + eps=1e-12, +): + r"""Compute the SGOT metric between two spectral decompositions. + + This function computes a discrete optimal transport problem between two measures + over spectral atoms :math:`z_i=(\lambda_i, V_i)` and :math:`z'_j=(\lambda'_j, V'_j)`. + Using the ground cost matrix :math:`C = d^p` returned by :func:`sgot_cost_matrix`, + we solve: + + .. math:: + P^\star \in \arg\min_{P\in\Pi(W_s, W_t)} \langle C, P\rangle, + + where :math:`C(i,j) = d(i,j)^p` and :math:`d(i,j)` is the SGOT ground distance + combining spectral and Grassmann terms with exponent :math:`q`: + + .. math:: + d(i,j)^2 + \;=\; \eta\,\big|\lambda_i - \lambda'_j\big|^{q} + \;+\; (1-\eta)\,d_G(i,j)^{q}, + + and :math:`d_G(i,j)` is the Grassmann distance associated with the chosen + ``grassman_metric``. + + From the optimal plan :math:`P^\star`, we first form the :math:`p`-Wasserstein + objective: + + .. math:: + \mathrm{obj} + \;=\; + \left(\sum_{i,j} C(i,j)\,P^\star_{ij}\right)^{1/p}, + + and then apply an outer root :math:`r`: + + .. math:: + \mathrm{SGOT} + \;=\; + \mathrm{obj}^{1/r}. + + In summary: + + - :math:`q` controls how strongly spectral and Grassmann distances are curved + (via :math:`|\lambda_i - \lambda'_j|^{q}` and :math:`d_G(i,j)^{q}`), + - :math:`p` is the exponent used in the OT ground cost and the inner + Wasserstein root, + - :math:`r` is an additional outer root applied to the Wasserstein objective. 
+ """ + if nx is None: + nx = get_backend(Ds, Rs, Ls, Dt, Rt, Lt) + + _validate_sgot_inputs(Ds, Rs, Ls, Dt, Rt, Lt) + + C = sgot_cost_matrix( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + grassman_metric=grassman_metric, + eigen_scaling=eigen_scaling, + nx=nx, + eps=eps, + ) + + if Ws is None: + Ws = nx.ones((C.shape[0],), type_as=C) / float(C.shape[0]) + if Wt is None: + Wt = nx.ones((C.shape[1],), type_as=C) / float(C.shape[1]) + + Ws = Ws / nx.sum(Ws) + Wt = Wt / nx.sum(Wt) + + obj = ot.emd2(Ws, Wt, nx.real(C)) + obj = obj ** (1.0 / p) + return obj ** (1.0 / float(r)) diff --git a/test/test_backend.py b/test/test_backend.py index cd6a85762..fe6af9c67 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -339,6 +339,9 @@ def test_func_backends(nx): sp_col = np.array([0, 3, 1, 2, 2]) sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64) + M_complex = M + 1j * rnd.randn(10, 3) + v_acos = np.clip(v, -0.99, 0.99) + lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: @@ -723,6 +726,24 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("atan2") + M_complex_b = nx.from_numpy(M_complex) + A = nx.real(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("real") + + A = nx.imag(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("imag") + + A = nx.conj(M_complex_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("conj") + + v_acos_b = nx.from_numpy(v_acos) + A = nx.arccos(v_acos_b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("arccos") + A = nx.transpose(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") diff --git a/test/test_sgot.py b/test/test_sgot.py new file mode 100644 index 000000000..b9280d116 --- /dev/null +++ b/test/test_sgot.py @@ -0,0 +1,264 @@ +"""Tests for ot.sgot module""" + +# Author: Sienna O'Shea +# Thibaut Germain +# License: MIT License + +import numpy as np +import pytest + +from ot.sgot import ( + eigenvalue_cost_matrix, + _delta_matrix_1d, + 
_grassmann_distance_squared, + sgot_cost_matrix, + sgot_metric, +) + + +def random_atoms(d=8, r=4, seed=42): + """Deterministic complex atoms for given d, r.""" + + def _rand_complex(shape, seed_): + rng = np.random.RandomState(seed_) + real = rng.randn(*shape) + imag = rng.randn(*shape) + return real + 1j * imag + + Ds = _rand_complex((r,), seed + 0) + Rs = _rand_complex((d, r), seed + 1) + Ls = _rand_complex((d, r), seed + 2) + Dt = _rand_complex((r,), seed + 3) + Rt = _rand_complex((d, r), seed + 4) + Lt = _rand_complex((d, r), seed + 5) + + return Ds, Rs, Ls, Dt, Rt, Lt + + +# --------------------------------------------------------------------- +# DATA / SAMPLING TESTS +# --------------------------------------------------------------------- + + +def test_random_d_r(nx): + """Sample d and r uniformly and run sgot_cost_matrix (and sgot_metric when available) with those shapes.""" + rng = np.random.RandomState(0) + d_min, d_max = 4, 12 + r_min, r_max = 2, 6 + for _ in range(5): + d = int(rng.randint(d_min, d_max + 1)) + r = int(rng.randint(r_min, r_max + 1)) + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms(d=d, r=r) + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(C_np.shape, (r, r)) + assert np.all(np.isfinite(C_np)) and np.all(C_np >= 0) + try: + dist = sgot_metric(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) and dist_np >= 0 + except TypeError: + pytest.skip("sgot_metric() unavailable (emd_c signature mismatch)") + + +# --------------------------------------------------------------------- +# DELTA MATRIX TESTS +# --------------------------------------------------------------------- + + +def test_eigenvalue_cost_matrix_simple(): + Ds = np.array([0.0, 1.0]) + Dt = np.array([0.0, 2.0]) + C = eigenvalue_cost_matrix(Ds, Dt, q=2) + expected = np.array([[0.0, 4.0], [1.0, 1.0]]) + 
np.testing.assert_allclose(C, expected) + + +def test_delta_matrix_1d_identity(): + r = 4 + I = np.eye(r, dtype=complex) + delta = _delta_matrix_1d(I, I, I, I) + np.testing.assert_allclose(delta, np.eye(r), atol=1e-6) + + +def test_delta_matrix_1d_swap_invariance(): + d, r = 6, 3 + _, R, _, _, _, _ = random_atoms(d=d, r=r) + L = R.copy() + delta1 = _delta_matrix_1d(R, L, R, L) + delta2 = _delta_matrix_1d(L, R, L, R) + np.testing.assert_allclose(delta1, delta2, atol=1e-6) + + +# --------------------------------------------------------------------- +# GRASSMANN DISTANCE TESTS +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_grassmann_zero_distance(grassman_metric, nx): + delta = nx.from_numpy(np.ones((3, 3))) + dist2 = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) + dist2_np = nx.to_numpy(dist2) + np.testing.assert_allclose(dist2_np, 0.0, atol=1e-6) + + +def test_grassmann_distance_invalid_name(): + delta = np.ones((2, 2)) + with pytest.raises(ValueError): + _grassmann_distance_squared(delta, grassman_metric="cordal") + + +# --------------------------------------------------------------------- +# COST TESTS +# --------------------------------------------------------------------- + + +def test_cost_self_zero(nx): + """(D_S R_S L_S D_S): diagonal of sgot_cost_matrix matrix (same atom to same atom) should be near zero.""" + Ds, Rs, Ls, _, _, _ = random_atoms() + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2) + C_np = nx.to_numpy(C) + np.testing.assert_allclose(np.diag(C_np), np.zeros(C_np.shape[0]), atol=1e-6) + np.testing.assert_allclose(C_np, C_np.T, atol=1e-6) + + +def test_grassmann_cost_reference(nx): + """Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose).""" + Ds, Rs, Ls, Dt, 
Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + eta, p, q = 0.5, 2, 1 + C1 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + C2 = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, eta=eta, p=p, q=q) + np.testing.assert_allclose(nx.to_numpy(C1), nx.to_numpy(C2), atol=1e-6) + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_grassmann_cost_basic_properties(grassman_metric, nx): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + C = sgot_cost_matrix( + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b, grassman_metric=grassman_metric + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +def test_sgot_cost_input_validation(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + + with pytest.raises(ValueError): + sgot_cost_matrix(Ds.reshape(-1, 1), Rs, Ls, Dt, Rt, Lt) + + with pytest.raises(ValueError): + sgot_cost_matrix(Ds, Rs[:, :-1], Ls, Dt, Rt, Lt) + + +# --------------------------------------------------------------------- +# METRIC TESTS +# --------------------------------------------------------------------- + + +def test_sgot_metric_self_zero(nx): + Ds, Rs, Ls, _, _, _ = random_atoms() + Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls) + dist = sgot_metric(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2, nx=nx) + dist_np = nx.to_numpy(dist) + assert np.isfinite(dist_np) + assert abs(float(dist_np)) < 2e-2 + + +def test_sgot_metric_symmetry(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + d1 = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt) + d2 = sgot_metric(Dt, Rt, Lt, Ds, Rs, Ls) + np.testing.assert_allclose(d1, d2, atol=1e-6) + + +def test_sgot_metric_with_weights(): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + r = Ds.shape[0] + + rng = np.random.RandomState(1) + Ws = 
rng.rand(r) + Ws = Ws / np.sum(Ws) + + Wt = rng.rand(r) + Wt = Wt / np.sum(Wt) + + dist = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt, Ws=Ws, Wt=Wt) + assert np.isfinite(dist) + + +# --------------------------------------------------------------------- +# HYPERPARAMETER SWEEP TEST +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "eta, p, q, grassman_metric", + [ + (0.5, 1, 1, "geodesic"), + (0.5, 2, 1, "chordal"), + (0.3, 2, 2, "procrustes"), + (0.7, 1, 2, "martin"), + ], +) +def test_hyperparameter_sweep_cost(nx, eta, p, q, grassman_metric): + """Sweep over a set of fixed HPs and run cost().""" + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt) + + C = sgot_cost_matrix( + Ds_b, + Rs_b, + Ls_b, + Dt_b, + Rt_b, + Lt_b, + eta=eta, + p=p, + q=q, + grassman_metric=grassman_metric, + ) + C_np = nx.to_numpy(C) + assert C_np.shape == (Ds.shape[0], Dt.shape[0]) + assert np.all(np.isfinite(C_np)) + assert np.all(C_np >= 0) + + +@pytest.mark.parametrize( + "grassman_metric", ["geodesic", "chordal", "procrustes", "martin"] +) +def test_hyperparameter_sweep(grassman_metric): + Ds, Rs, Ls, Dt, Rt, Lt = random_atoms() + rng = np.random.RandomState(3) + eta = rng.uniform(0.0, 1.0) + p = rng.choice([1, 2]) + q = rng.choice([1, 2]) + r = rng.choice([1, 2]) + + dist = sgot_metric( + Ds, + Rs, + Ls, + Dt, + Rt, + Lt, + eta=eta, + p=p, + q=q, + r=r, + grassman_metric=grassman_metric, + ) + + assert np.isfinite(dist) + assert dist >= 0