|
2 | 2 | from warnings import warn |
3 | 3 | import numpy as np |
4 | 4 | from scipy import sparse |
| 5 | +import pywt |
5 | 6 |
|
6 | 7 | from pynumdiff.utils import utility |
7 | 8 |
|
@@ -133,3 +134,111 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0): |
133 | 134 | dxdt_hat_flattened = drbfdt @ alpha |
134 | 135 |
|
135 | 136 | return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis) |
| 137 | + |
| 138 | + |
def _scaling_function_samples(rec_lo):
    """Sample a wavelet scaling function phi and its derivative phi' at the integers.

    Sampling the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k) at integer t
    turns it into an eigenproblem for T[p, q] = sqrt2 h_{2p-q}: phi(p) is the eigenvector
    for the eigenvalue closest to 1 and phi'(p) the one for the eigenvalue closest to 1/2.

    :param rec_lo: reconstruction low-pass (refinement) filter taps, length L.
    :return: - **phi** (np.array) -- phi(p) for p = 0..L-1, scaled so sum_p phi(p) = 1
             - **dphi** (np.array) -- phi'(p) for p = 0..L-1, scaled so sum_p p*phi'(p) = -1
    """
    h = np.asarray(rec_lo, dtype=float)
    h = h / h.sum() * np.sqrt(2)  # refinement filter, normalized so the taps sum to sqrt2
    n_taps = len(h)
    p = np.arange(n_taps)  # phi is supported on the integers [0, L-1]

    # T[p, q] = sqrt2 * h[2p - q] wherever 2p - q is a valid tap index, else 0. The clip only
    # keeps the fancy index in range; out-of-range entries are masked by np.where.
    shift = 2 * p[:, None] - p[None, :]
    in_range = (shift >= 0) & (shift < n_taps)
    T = np.where(in_range, np.sqrt(2) * h[np.clip(shift, 0, n_taps - 1)], 0.0)

    evals, evecs = np.linalg.eig(T)
    phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))])
    phi /= phi.sum()  # reproduce constants: sum_p phi(p) = 1
    dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))])
    dphi /= -np.dot(p, dphi)  # reproduce ramps: sum_p p*phi'(p) = -1
    return phi, dphi


def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='periodization'):
    """Smooth and differentiate noisy data in a wavelet basis.

    Three steps: (1) decompose x with the DWT and soft-threshold the detail
    coefficients to denoise (Donoho-Johnstone universal threshold), reconstructing
    a smoothed x_hat; (2) extend x_hat antisymmetrically so the periodic derivative
    operator stays accurate at the edges; (3) recover the wavelet scaling
    coefficients of x_hat and apply the analytic derivative of the wavelet basis.

    The derivative differentiates the basis functions themselves rather than
    finite-differencing the signal. PyWavelets treats the samples as finest-level
    scaling coefficients, so x_hat is the interpolant x(t) = sum_n a_n phi(t/dt - n)
    for the scaling function phi. Sampling x and its analytic derivative on the grid
    gives two convolutions against phi and phi' evaluated at *integers*,

        x_hat = Phi @ a    and    x' = Phi_prime @ a,

    so x' = Phi_prime @ Phi^-1 @ x_hat, exact for signals the basis can represent.
    The integer samples phi(p), phi'(p) are the eigenvalue-1 and eigenvalue-1/2
    eigenvectors of the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k)
    (the "connection coefficients"), normalized to reproduce constants and ramps.

    Because the DWT requires uniform spacing, this method only accepts a scalar
    time step dt (not a vector of sample times). For non-uniformly sampled data,
    use :func:`rbfdiff` or :func:`splinediff` instead.

    :param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
    :param float dt: uniform time step between samples.
    :param str wavelet: PyWavelets wavelet name (an already-constructed
        :code:`pywt.Wavelet` is also accepted). Must have a differentiable scaling
        function, so smoother wavelets give better derivatives: 'db8' (default) and
        'sym8' are best for noisy data; 'db4', 'sym4', and 'coif2' also work well.
    :param int level: decomposition depth. None (default) resolves to
        min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
    :param float threshold: soft-thresholding scale factor in [0, inf). 0 disables
        denoising, so x_hat is just the DWT round trip of x.
    :param int axis: axis along which to differentiate (default 0).
    :param str mode: PyWavelets signal extension mode for the denoising transform.
        'periodization' keeps coefficient arrays compact. The derivative operator is
        periodic, so x_hat is antisymmetrically extended before it is applied (step 2).
    :return: - **x_hat** (np.array) -- estimated (smoothed) x
             - **dxdt_hat** (np.array) -- estimated derivative of x
    :raises ValueError: if :code:`dt` is not a scalar, if :code:`threshold` is negative,
        or if the wavelet's scaling function is the (non-differentiable) Haar box.
    """
    if not np.isscalar(dt):
        raise ValueError("`dt` must be a scalar. The DWT requires uniformly sampled data. "
            "For variable step sizes, use rbfdiff or splinediff instead.")
    # A negative scale would make soft-thresholding amplify the detail bands instead of
    # shrinking them, so reject it up front rather than silently returning noisier data.
    if threshold < 0:
        raise ValueError("`threshold` must be non-negative.")

    wav = wavelet if isinstance(wavelet, pywt.Wavelet) else pywt.Wavelet(wavelet)

    # The Haar scaling function is a step, so it has no pointwise derivative and the
    # connection-coefficient operator below is undefined for it. Haar/db1 is the only
    # orthonormal wavelet with a 2-tap filter, so dec_len identifies it. The operator is
    # built from rec_lo, so also reject filters that are the Haar pair padded with zeros
    # (expected for the bior1.x family): two non-zero taps cannot give a smooth phi.
    if wav.dec_len == 2 or np.count_nonzero(wav.rec_lo) <= 2:
        raise ValueError("The Haar/db1 wavelet has a discontinuous (piecewise-constant) scaling "
            "function with no derivative, so it cannot be used to differentiate. Pick a smoother "
            "wavelet such as 'db4', 'sym4', or 'coif2'.")

    x = np.asarray(x)
    N = x.shape[axis]
    x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))  # differentiation axis to front
    shape = x_work.shape  # remember it to restore the input's dimensionality
    x_flat = x_work.reshape(N, -1)  # rest of the dims flattened into columns
    Ne = 3 * N - 2  # length after the antisymmetric extension in step 2

    # Build the wavelet-basis derivative operator (depends only on the grid and wavelet).
    # Phi and Phi_prime hold circulant samples of phi and phi'/dt on the extended grid; both
    # share a common shift that cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
    phi, dphi = _scaling_function_samples(wav.rec_lo)
    offsets = np.arange(len(phi))
    m = np.arange(Ne)
    rows = np.tile(m, len(offsets))  # one run of 0..Ne-1 per filter tap
    cols = ((m[None, :] - offsets[:, None]) % Ne).ravel()  # wrap-around column for each tap
    # Entries that land on the same (row, col) -- possible when Ne < len(phi) -- are summed
    # by the sparse constructors.
    Phi = sparse.csc_matrix((np.repeat(phi, Ne), (rows, cols)), shape=(Ne, Ne))  # to invert
    Phi_prime = sparse.csr_matrix((np.repeat(dphi / dt, Ne), (rows, cols)), shape=(Ne, Ne))  # to apply

    if level is None:
        level = min(pywt.dwt_max_level(N, wav), 5)

    # 1. Denoise: DWT all columns at once, then soft-threshold the detail bands. The
    # noise level is estimated robustly per column from the finest details (coeffs[-1]).
    coeffs = pywt.wavedec(x_flat, wav, level=level, mode=mode, axis=0)
    if threshold > 0:
        # Soft-thresholding by 0 is the identity, so it is skipped entirely; that also
        # sidesteps a 0/0 inside the shrinkage formula on exactly-zero coefficients.
        sigma = np.maximum(np.median(np.abs(coeffs[-1]), axis=0) / 0.6745, 1e-10)
        thresh = threshold * sigma * np.sqrt(2 * np.log(N))
        coeffs = [coeffs[0]] + [pywt.threshold(c, thresh[np.newaxis, :], mode='soft') for c in coeffs[1:]]
    x_hat = pywt.waverec(coeffs, wav, mode=mode, axis=0)[:N]  # crop any padding from the inverse DWT

    # 2. The derivative operator is periodic, but x_hat usually isn't. Extend it
    # antisymmetrically (reflect through each endpoint: x[-1-k] -> 2*x[0]-x[1+k]) so the
    # periodic wrap is continuous in both value and slope, which keeps the derivative
    # accurate at the edges instead of spiking there. A ramp extends to a ramp, so slopes
    # survive exactly.
    left = 2 * x_hat[0] - x_hat[1:][::-1]
    right = 2 * x_hat[-1] - x_hat[:-1][::-1]
    x_ext = np.concatenate([left, x_hat, right], axis=0)  # length 3N-2, original at [N-1:2N-1]

    # 3. Differentiate the basis: recover the scaling coefficients a = Phi^-1 @ x_ext, then
    # apply the analytic basis derivative dxdt = Phi_prime @ a, and crop back to the original.
    # spsolve flattens a single-column right-hand side, hence the reshape.
    a = sparse.linalg.spsolve(Phi, x_ext)
    dxdt_flat = (Phi_prime @ a.reshape(Ne, -1))[N - 1:2 * N - 1]

    x_hat = np.moveaxis(x_hat.reshape(shape), 0, axis)
    dxdt_hat = np.moveaxis(dxdt_flat.reshape(shape), 0, axis)
    return x_hat, dxdt_hat
0 commit comments