Skip to content

Commit 90a137d

Browse files
authored
Merge pull request #202 from mariaprot/waveletdiff
Add waveletdiff: wavelet-based denoising differentiator
2 parents b38199f + 8e8d46e commit 90a137d

4 files changed

Lines changed: 122 additions & 3 deletions

File tree

pynumdiff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
from .finite_difference import finitediff, first_order, second_order, fourth_order
1616
from .smooth_finite_difference import kerneldiff, meandiff, mediandiff, gaussiandiff, friedrichsdiff, butterdiff
1717
from .polynomial_fit import splinediff, polydiff, savgoldiff
18-
from .basis_fit import spectraldiff, rbfdiff
18+
from .basis_fit import spectraldiff, rbfdiff, waveletdiff
1919
from .total_variation_regularization import iterative_velocity
2020
from .kalman_smooth import kalman_filter, rts_smooth, rtsdiff, constant_velocity, constant_acceleration, constant_jerk

pynumdiff/basis_fit.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from warnings import warn
33
import numpy as np
44
from scipy import sparse
5+
import pywt
56

67
from pynumdiff.utils import utility
78

@@ -133,3 +134,111 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
133134
dxdt_hat_flattened = drbfdt @ alpha
134135

135136
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)
137+
138+
139+
def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='periodization'):
140+
"""Smooth and differentiate noisy data in a wavelet basis.
141+
142+
Three steps: (1) decompose x with the DWT and soft-threshold the detail
143+
coefficients to denoise (Donoho-Johnstone universal threshold), reconstructing
144+
a smoothed x_hat; (2) extend x_hat antisymmetrically so the periodic derivative
145+
operator stays accurate at the edges; (3) recover the wavelet scaling
146+
coefficients of x_hat and apply the analytic derivative of the wavelet basis.
147+
148+
The derivative differentiates the basis functions themselves rather than
149+
finite-differencing the signal. PyWavelets treats the samples as finest-level
150+
scaling coefficients, so x_hat is the interpolant x(t) = sum_n a_n phi(t/dt - n)
151+
for the scaling function phi. Sampling x and its analytic derivative on the grid
152+
gives two convolutions against phi and phi' evaluated at *integers*,
153+
154+
x_hat = Phi @ a and x' = Phi_prime @ a,
155+
156+
so x' = Phi_prime @ Phi^-1 @ x_hat, exact for signals the basis can represent.
157+
The integer samples phi(p), phi'(p) are the eigenvalue-1 and eigenvalue-1/2
158+
eigenvectors of the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k)
159+
(the "connection coefficients"), normalized to reproduce constants and ramps.
160+
161+
Because the DWT requires uniform spacing, this method only accepts a scalar
162+
time step dt (not a vector of sample times). For non-uniformly sampled data,
163+
use :func:`rbfdiff` or :func:`splinediff` instead.
164+
165+
:param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
166+
:param float dt: uniform time step between samples.
167+
:param str wavelet: PyWavelets wavelet name. Must have a differentiable scaling
168+
function, so smoother wavelets give better derivatives: 'db8' (default) and
169+
'sym8' are best for noisy data; 'db4', 'sym4', and 'coif2' also work well.
170+
:param int level: decomposition depth. None (default) resolves to
171+
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
172+
:param float threshold: soft-thresholding scale factor in [0, inf).
173+
:param int axis: axis along which to differentiate (default 0).
174+
:param str mode: PyWavelets signal extension mode for the denoising transform.
175+
'periodization' keeps coefficient arrays compact. The derivative operator is
176+
periodic, so x_hat is antisymmetrically extended before it is applied (see below).
177+
:return: - **x_hat** (np.array) -- estimated (smoothed) x
178+
- **dxdt_hat** (np.array) -- estimated derivative of x
179+
"""
180+
if not np.isscalar(dt):
181+
raise ValueError("`dt` must be a scalar. The DWT requires uniformly sampled data. "
182+
"For variable step sizes, use rbfdiff or splinediff instead.")
183+
184+
# The Haar scaling function is a step, so it has no pointwise derivative and the
185+
# connection-coefficient operator below is undefined for it. Haar/db1 is the only
186+
# orthonormal wavelet with a 2-tap filter, so dec_len identifies it.
187+
if pywt.Wavelet(wavelet).dec_len == 2:
188+
raise ValueError("The Haar/db1 wavelet has a discontinuous (piecewise-constant) scaling "
189+
"function with no derivative, so it cannot be used to differentiate. Pick a smoother "
190+
"wavelet such as 'db4', 'sym4', or 'coif2'.")
191+
192+
N = x.shape[axis]
193+
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0)) # differentiation axis to front
194+
shape = x_work.shape # remember it to restore the input's dimensionality
195+
x_flat = x_work.reshape(N, -1) # rest of the dims flattened into columns
196+
Ne = 3 * N - 2 # length after the antisymmetric extension in step 2
197+
198+
# Build the wavelet-basis derivative operator (depends only on the grid and wavelet).
199+
# Sampling the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k) at integers makes
200+
# phi(p) the eigenvalue-1 and phi'(p) the eigenvalue-1/2 eigenvector of T[p,q] = sqrt2 h_{2p-q}.
201+
h = np.array(pywt.Wavelet(wavelet).rec_lo); h = h / h.sum() * np.sqrt(2) # refinement filter, integral of phi = 1
202+
L = len(h); p = np.arange(L) # phi is supported on the integers [0, L-1]
203+
shift = 2 * p[:, None] - p[None, :]
204+
T = np.where((shift >= 0) & (shift < L), np.sqrt(2) * h[np.clip(shift, 0, L - 1)], 0.0)
205+
evals, evecs = np.linalg.eig(T)
206+
phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))]); phi /= phi.sum() # sum_p phi(p) = 1
207+
dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))]); dphi /= np.dot(p, dphi)*-1 # sum_p p*phi'(p) = -1
208+
# Phi and Phi_prime hold circulant samples of phi and phi'/dt on the extended grid; both
209+
# share a common shift that cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
210+
rows, cols, phi_vals, dphi_vals = [], [], [], []
211+
m = np.arange(Ne)
212+
for offset, phi_p, dphi_p in zip(p, phi, dphi / dt):
213+
rows.extend(m); cols.extend((m - offset) % Ne); phi_vals.extend([phi_p]*Ne); dphi_vals.extend([dphi_p]*Ne)
214+
Phi = sparse.csr_matrix((phi_vals, (rows, cols)), shape=(Ne, Ne)).tocsc() # to invert
215+
Phi_prime = sparse.csr_matrix((dphi_vals, (rows, cols)), shape=(Ne, Ne)) # to apply
216+
217+
if level is None:
218+
level = min(pywt.dwt_max_level(N, wavelet), 5)
219+
220+
# 1. Denoise: DWT all columns at once, then soft-threshold the detail bands. The
221+
# noise level is estimated robustly per column from the finest details (coeffs[-1]).
222+
coeffs = pywt.wavedec(x_flat, wavelet, level=level, mode=mode, axis=0)
223+
sigma = np.maximum(np.median(np.abs(coeffs[-1]), axis=0) / 0.6745, 1e-10)
224+
thresh = threshold * sigma * np.sqrt(2 * np.log(N))
225+
coeffs = [coeffs[0]] + [pywt.threshold(c, thresh[np.newaxis, :], mode='soft') for c in coeffs[1:]]
226+
x_hat = pywt.waverec(coeffs, wavelet, mode=mode, axis=0)[:N]
227+
228+
# 2. The derivative operator is periodic, but x_hat usually isn't. Extend it
229+
# antisymmetrically (reflect through each endpoint: x[-1-k] -> 2*x[0]-x[1+k]) so the
230+
# periodic wrap is continuous in both value and slope, which keeps the derivative
231+
# accurate at the edges instead of spiking there. This is the odd-symmetry analog of
232+
# spectraldiff's even extension; a ramp extends to a ramp, so slopes survive exactly.
233+
left = 2 * x_hat[0] - x_hat[1:][::-1]
234+
right = 2 * x_hat[-1] - x_hat[:-1][::-1]
235+
x_ext = np.concatenate([left, x_hat, right], axis=0) # length 3N-2, original at [N-1:2N-1]
236+
237+
# 3. Differentiate the basis: recover the scaling coefficients a = Phi^-1 @ x_ext, then
238+
# apply the analytic basis derivative dxdt = Phi_prime @ a, and crop back to the original.
239+
a = sparse.linalg.spsolve(Phi, x_ext)
240+
dxdt_flat = (Phi_prime @ a.reshape(Ne, -1))[N - 1:2 * N - 1]
241+
242+
x_hat = np.moveaxis(x_hat.reshape(shape), 0, axis)
243+
dxdt_hat = np.moveaxis(dxdt_flat.reshape(shape), 0, axis)
244+
return x_hat, dxdt_hat

pynumdiff/tests/test_diff_methods.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..smooth_finite_difference import kerneldiff, mediandiff, meandiff, gaussiandiff, friedrichsdiff, butterdiff
66
from ..finite_difference import finitediff, first_order, second_order, fourth_order
77
from ..polynomial_fit import polydiff, savgoldiff, splinediff
8-
from ..basis_fit import spectraldiff, rbfdiff
8+
from ..basis_fit import spectraldiff, rbfdiff, waveletdiff
99
from ..total_variation_regularization import velocity, acceleration, jerk, iterative_velocity, smooth_acceleration, tvrdiff
1010
from ..kalman_smooth import rtsdiff, constant_velocity, constant_acceleration, constant_jerk, robustdiff
1111
from ..linear_model import lineardiff
@@ -51,6 +51,7 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
5151
(spline_irreg_step, {'degree':5, 's':2}),
5252
(spectraldiff, {'high_freq_cutoff':0.2}), (spectraldiff, [0.2]),
5353
(rbfdiff, {'sigma':0.5, 'lmbd':0.001}),
54+
(waveletdiff, {'wavelet':'db8', 'threshold':1.0}),
5455
(constant_velocity, {'r':1e-2, 'q':1e3}), (constant_velocity, [1e-2, 1e3]),
5556
(constant_acceleration, {'r':1e-3, 'q':1e4}), (constant_acceleration, [1e-3, 1e4]),
5657
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
@@ -173,6 +174,12 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
173174
[(-2, -2), (0, 0), (0, -1), (0, 0)],
174175
[(0, 0), (2, 2), (0, 0), (2, 2)],
175176
[(1, 1), (3, 3), (1, 1), (3, 3)]],
177+
waveletdiff: [[(-15, -15), (-13, -13), (0, -1), (1, 0)],
178+
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
179+
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
180+
[(-3, -3), (-1, -1), (0, 0), (1, 1)],
181+
[(0, -1), (2, 2), (0, 0), (2, 2)],
182+
[(0, -1), (3, 3), (0, 0), (3, 3)]],
176183
velocity: [[(-25, -25), (-18, -19), (0, -1), (1, 0)],
177184
[(-12, -12), (-11, -12), (-1, -1), (-1, -2)],
178185
[(0, -1), (1, 0), (0, -1), (1, 0)],
@@ -327,6 +334,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
327334
(finitediff, {}),
328335
(polydiff, {'degree': 2, 'window_size': 5}),
329336
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
337+
(waveletdiff, {'wavelet': 'db8', 'threshold': 1.0}),
330338
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
331339
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
332340
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
@@ -343,6 +351,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
343351
kerneldiff: [(2, 1), (3, 2)],
344352
butterdiff: [(0, -1), (1, -1)],
345353
finitediff: [(0, -1), (1, -1)],
354+
waveletdiff: [(1, 0), (2, 2)],
346355
polydiff: [(1, -1), (1, 0)],
347356
savgoldiff: [(0, -1), (1, 1)],
348357
rtsdiff: [(1, -1), (1, 0)],

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ classifiers = [
2525
dependencies = [
2626
"numpy",
2727
"scipy",
28-
"matplotlib"
28+
"matplotlib",
29+
"pywavelets"
2930
]
3031

3132
[project.urls]

0 commit comments

Comments
 (0)