From 5333fd737eb8ba04ccd570305f613cdecf5b64ba Mon Sep 17 00:00:00 2001 From: Niklas Mueller Date: Wed, 15 Apr 2026 17:13:05 +0200 Subject: [PATCH 1/2] add stub files for Cython modules in pywt._extensions --- pywt/__init__.py | 45 +++++---- pywt/_extensions/_cwt.pyi | 4 + pywt/_extensions/_dwt.pyi | 12 +++ pywt/_extensions/_pywt.pyi | 185 +++++++++++++++++++++++++++++++++++++ pywt/_extensions/_swt.pyi | 7 ++ 5 files changed, 234 insertions(+), 19 deletions(-) create mode 100644 pywt/_extensions/_cwt.pyi create mode 100644 pywt/_extensions/_dwt.pyi create mode 100644 pywt/_extensions/_pywt.pyi create mode 100644 pywt/_extensions/_swt.pyi diff --git a/pywt/__init__.py b/pywt/__init__.py index 669b06b6..5eb6d188 100644 --- a/pywt/__init__.py +++ b/pywt/__init__.py @@ -10,26 +10,33 @@ wavelet packets signal decomposition and reconstruction module. """ -from ._extensions._pywt import * -from ._functions import * -from ._multilevel import * -from ._multidim import * -from ._thresholding import * -from ._wavelet_packets import * -from ._dwt import * -from ._swt import * -from ._cwt import * -from ._mra import * +from ._extensions._pywt import Modes, ContinuousWavelet, families, Wavelet, wavelist, DiscreteContinuousWavelet +from ._functions import integrate_wavelet, central_frequency, scale2frequency, frequency2scale, qmf, orthogonal_filter_bank, intwave, centrfrq, scal2frq, orthfilt +from ._multilevel import wavedec, waverec, wavedec2, waverec2, wavedecn, waverecn, coeffs_to_array, array_to_coeffs, ravel_coeffs, unravel_coeffs, dwtn_max_level, wavedecn_size, wavedecn_shapes, fswavedecn, fswaverecn, FswavedecnResult +from ._multidim import dwt2, idwt2, dwtn, idwtn +from ._thresholding import threshold, threshold_firm +from ._wavelet_packets import BaseNode, Node, WaveletPacket, Node2D, WaveletPacket2D, NodeND, WaveletPacketND +from ._dwt import dwt, idwt, downcoef, upcoef, dwt_max_level, dwt_coeff_len, pad +from ._swt import swt, swt_max_level, iswt, swt2, iswt2, swtn, iswtn +from ._cwt import cwt +from ._mra import mra, mra2, mran, imra, imra2, imran +from .data import aero, ascent, camera, ecg, nino, demo_signal -from . import data - -__all__ = [s for s in dir() if not s.startswith('_')] -try: - # In Python 2.x the name of the tempvar leaks out of the list - # comprehension. Delete it to not make it show up in the main namespace. - del s -except NameError: - pass +__all__ = ["ContinuousWavelet", "families", "Modes", "Wavelet", "wavelist", + "DiscreteContinuousWavelet", "integrate_wavelet", + "central_frequency", "scale2frequency", "frequency2scale", "qmf", + "orthogonal_filter_bank", "intwave", "centrfrq", "scal2frq", + "orthfilt", "wavedec", "waverec", "wavedec2", "waverec2", + "wavedecn", "waverecn", "coeffs_to_array", "array_to_coeffs", + "ravel_coeffs", "unravel_coeffs", "dwtn_max_level", + "wavedecn_size", "wavedecn_shapes", "fswavedecn", "fswaverecn", + "FswavedecnResult", "dwt2", "idwt2", "dwtn", "idwtn", "threshold", + "threshold_firm", "BaseNode", "Node", "WaveletPacket", "Node2D", + "WaveletPacket2D", "NodeND", "WaveletPacketND", "dwt", "idwt", + "downcoef", "upcoef", "dwt_max_level", "dwt_coeff_len", "pad", + "swt", "swt_max_level", "iswt", "swt2", "iswt2", "swtn", "iswtn", + "cwt", "mra", "mra2", "mran", "imra", "imra2", "imran", "aero", + "ascent", "camera", "ecg", "nino", "demo_signal"] from pywt.version import version as __version__ diff --git a/pywt/_extensions/_cwt.pyi b/pywt/_extensions/_cwt.pyi new file mode 100644 index 00000000..ccc28098 --- /dev/null +++ b/pywt/_extensions/_cwt.pyi @@ -0,0 +1,4 @@ +from _pywt import ContinuousWavelet, DataT +from numpy.typing import NDArray + +def cwt_psi_single(data: NDArray[DataT], wavelet: ContinuousWavelet, output_len: int) -> NDArray[DataT] | tuple[NDArray[DataT], NDArray[DataT]]: ... diff --git a/pywt/_extensions/_dwt.pyi b/pywt/_extensions/_dwt.pyi new file mode 100644 index 00000000..35031091 --- /dev/null +++ b/pywt/_extensions/_dwt.pyi @@ -0,0 +1,12 @@ +from numpy.typing import NDArray + +from pywt import MODE, CDataT, Wavelet + +def dwt_max_level(data_len: int, filter_len: int) -> int: ... +def dwt_coeff_len(size_t: int, filter_len: int, mode: MODE) -> int: ... +def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... +def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... +def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> NDArray[CDataT]: ... +def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> NDArray[CDataT]: ... +def upcoef(do_rec_a: bool, coeffs: NDArray[CDataT], wavelet: Wavelet, level: int, take: int) -> NDArray[CDataT]: ... +def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, level: int) -> NDArray[CDataT]: ... diff --git a/pywt/_extensions/_pywt.pyi b/pywt/_extensions/_pywt.pyi new file mode 100644 index 00000000..9f4e0433 --- /dev/null +++ b/pywt/_extensions/_pywt.pyi @@ -0,0 +1,185 @@ +from enum import IntEnum +from typing import Any, Literal, Optional, TypeVar + +import numpy as np + +_WaveletFamily = Literal[ + "haar", + "db", + "sym", + "coif", + "bior", + "rbio", + "dmey", + "gaus", + "mexh", + "morl", + "cgau", + "shan", + "fbsp", + "cmor", +] + +DataT = TypeVar("DataT", np.float32, np.float64) + +CDataT = TypeVar( + "CDataT", + np.float32, + np.float64, + np.complex64, + np.complex128, +) + +_Kind = Literal["all", "continuous", "discrete"] + +_Symmetry = Literal[ + "asymmetric", + "near symmetric", + "symmetric", + "anti-symmetric", + "unknown", +] + +class MODE(IntEnum): + MODE_INVALID = -1 + MODE_ZEROPAD = 0 + MODE_SYMMETRIC = 1 + MODE_CONSTANT_EDGE = 2 + MODE_SMOOTH = 3 + MODE_PERIODIC = 4 + MODE_PERIODIZATION = 5 + MODE_REFLECT = 6 + MODE_ANTISYMMETRIC = 7 + MODE_ANTIREFLECT = 8 + MODE_MAX = 9 + +ModeName = Literal[ + "zero", + "constant", + "symmetric", + "reflect", + "periodic", + "smooth", + "periodization", + "antisymmetric", + "antireflect", +] + +Mode = MODE | ModeName + +class _Modes: + zero: int + constant: int + symmetric: int + reflect: int + periodic: int + smooth: int + periodization: int + antisymmetric: int + antireflect: int + + modes: list[ModeName] + + def from_object(self, mode: Mode) -> int: ... + +Modes = _Modes() + +def wavelist(family: _WaveletFamily | None = ..., kind: _Kind = ...) -> list[str]: ... +def families(short: bool = ...) -> list[str]: ... + +class Wavelet: + def __init__(self, name: str = ..., filter_bank: Any = ...) -> None: ... + def __len__(self) -> int: ... + @property + def name(self) -> str: ... + @property + def dec_lo(self) -> list[float]: ... + @property + def dec_hi(self) -> list[float]: ... + @property + def rec_lo(self) -> list[float]: ... + @property + def rec_hi(self) -> list[float]: ... + @property + def rec_len(self) -> int: ... + @property + def dec_len(self) -> int: ... + @property + def family_number(self) -> int: ... + @property + def family_name(self) -> str: ... + @property + def short_family_name(self) -> str: ... + @property + def orthogonal(self) -> bool: ... + @orthogonal.setter + def orthogonal(self, value: bool) -> None: ... + @property + def biorthogonal(self) -> bool: ... + @biorthogonal.setter + def biorthogonal(self, value: bool) -> None: ... + @property + def symmetry(self) -> _Symmetry: ... + @property + def vanishing_moments_psi(self) -> int | None: ... + @property + def vanishing_moments_phi(self) -> int | None: ... + @property + def filter_bank( + self, + ) -> tuple[list[float], list[float], list[float], list[float]]: ... + def get_filters_coeffs( + self, + ) -> tuple[list[float], list[float], list[float], list[float]]: ... + @property + def inverse_filter_bank( + self, + ) -> tuple[list[float], list[float], list[float], list[float]]: ... + def get_reverse_filters_coeffs( + self, + ) -> tuple[list[float], list[float], list[float], list[float]]: ... + +class ContinuousWavelet: + def __init__(self, name: str = ..., dtype: DataT = ...) -> None: ... + @property + def family_number(self) -> int: ... + @property + def family_name(self) -> str: ... + @property + def short_family_name(self) -> str: ... + @property + def orthogonal(self) -> bool: ... + @orthogonal.setter + def orthogonal(self, value: bool) -> None: ... + @property + def biorthogonal(self) -> bool: ... + @biorthogonal.setter + def biorthogonal(self, value: bool) -> None: ... + @property + def complex_cwt(self) -> bool: ... + @complex_cwt.setter + def complex_cwt(self, value: bool) -> None: ... + @property + def lower_bound(self) -> float | None: ... + @lower_bound.setter + def lower_bound(self, value: float) -> None: ... + @property + def upper_bound(self) -> float | None: ... + @upper_bound.setter + def upper_bound(self, value: float) -> None: ... + @property + def center_frequency(self) -> float | None: ... + @center_frequency.setter + def center_frequency(self, value: float) -> None: ... + @property + def bandwidth_frequency(self) -> float | None: ... + @bandwidth_frequency.setter + def bandwidth_frequency(self, value: float) -> None: ... + @property + def fbsp_order(self) -> int | None: ... + @fbsp_order.setter + def fbsp_order(self, value: int) -> None: ... + @property + def symmetry(self) -> _Symmetry: ... + +def DiscreteContinuousWavelet(name: str = ..., filter_bank: Any = ...) -> Wavelet | ContinuousWavelet : ... diff --git a/pywt/_extensions/_swt.pyi b/pywt/_extensions/_swt.pyi new file mode 100644 index 00000000..212c0344 --- /dev/null +++ b/pywt/_extensions/_swt.pyi @@ -0,0 +1,7 @@ +from numpy.typing import NDArray + +from pywt import CDataT, Wavelet + +def swt_max_level(input_len: int) -> int: ... +def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = ...) -> NDArray[CDataT]: ... +def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = ..., trim_approx: bool = ...) -> NDArray[CDataT]: ... From b14096a3533976d8a32f2ec5fddb4ca0d3dce1ee Mon Sep 17 00:00:00 2001 From: Niklas Mueller Date: Thu, 16 Apr 2026 18:13:48 +0200 Subject: [PATCH 2/2] update stub files - Added default values in stub files. - Added ModeInt type union to allow integer values where MODE enum was expected. - Used upper bounds for DataT and CDataT types instead of variable constraints. --- pywt/_extensions/_dwt.pyi | 14 +++++++------- pywt/_extensions/_pywt.pyi | 26 +++++++++++++------------- pywt/_extensions/_swt.pyi | 6 +++--- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pywt/_extensions/_dwt.pyi b/pywt/_extensions/_dwt.pyi index 35031091..3ddae78e 100644 --- a/pywt/_extensions/_dwt.pyi +++ b/pywt/_extensions/_dwt.pyi @@ -1,12 +1,12 @@ from numpy.typing import NDArray -from pywt import MODE, CDataT, Wavelet +from ._pywt import CDataT, ModeInt, Wavelet def dwt_max_level(data_len: int, filter_len: int) -> int: ... -def dwt_coeff_len(size_t: int, filter_len: int, mode: MODE) -> int: ... -def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... -def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... -def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: MODE) -> NDArray[CDataT]: ... -def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: MODE, axis: int = ...) -> NDArray[CDataT]: ... +def dwt_coeff_len(size_t: int, filter_len: int, mode: ModeInt) -> int: ... +def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... +def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ... +def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> NDArray[CDataT]: ... +def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> NDArray[CDataT]: ... def upcoef(do_rec_a: bool, coeffs: NDArray[CDataT], wavelet: Wavelet, level: int, take: int) -> NDArray[CDataT]: ... -def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: MODE, level: int) -> NDArray[CDataT]: ... +def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, level: int) -> NDArray[CDataT]: ... diff --git a/pywt/_extensions/_pywt.pyi b/pywt/_extensions/_pywt.pyi index 9f4e0433..c0886c47 100644 --- a/pywt/_extensions/_pywt.pyi +++ b/pywt/_extensions/_pywt.pyi @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Literal, Optional, TypeVar +from typing import Any, Literal, Optional, TypeAlias, TypeVar import numpy as np @@ -20,17 +20,13 @@ _WaveletFamily = Literal[ "cmor", ] -DataT = TypeVar("DataT", np.float32, np.float64) +DataT = TypeVar("DataT", bound=np.float32 | np.float64) CDataT = TypeVar( - "CDataT", - np.float32, - np.float64, - np.complex64, - np.complex128, + "CDataT", bound=np.float32 | np.float64 | np.complex64 | np.complex128 ) -_Kind = Literal["all", "continuous", "discrete"] +_Kind: TypeAlias = Literal["all", "continuous", "discrete"] _Symmetry = Literal[ "asymmetric", @@ -53,6 +49,8 @@ class MODE(IntEnum): MODE_ANTIREFLECT = 8 MODE_MAX = 9 +ModeInt = MODE | Literal[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ModeName = Literal[ "zero", "constant", @@ -84,11 +82,11 @@ class _Modes: Modes = _Modes() -def wavelist(family: _WaveletFamily | None = ..., kind: _Kind = ...) -> list[str]: ... -def families(short: bool = ...) -> list[str]: ... +def wavelist(family: _WaveletFamily | None = None, kind: _Kind = "all") -> list[str]: ... +def families(short: bool = True) -> list[str]: ... class Wavelet: - def __init__(self, name: str = ..., filter_bank: Any = ...) -> None: ... + def __init__(self, name: str = "", filter_bank: Any = None) -> None: ... def __len__(self) -> int: ... @property def name(self) -> str: ... @@ -140,7 +138,7 @@ class Wavelet: ) -> tuple[list[float], list[float], list[float], list[float]]: ... class ContinuousWavelet: - def __init__(self, name: str = ..., dtype: DataT = ...) -> None: ... + def __init__(self, name: str = "", dtype: DataT = np.float64) -> None: ... @property def family_number(self) -> int: ... @property @@ -182,4 +180,6 @@ class ContinuousWavelet: @property def symmetry(self) -> _Symmetry: ... -def DiscreteContinuousWavelet(name: str = ..., filter_bank: Any = ...) -> Wavelet | ContinuousWavelet : ... +def DiscreteContinuousWavelet( + name: str = "", filter_bank: Any = None +) -> Wavelet | ContinuousWavelet: ... diff --git a/pywt/_extensions/_swt.pyi b/pywt/_extensions/_swt.pyi index 212c0344..a845be75 100644 --- a/pywt/_extensions/_swt.pyi +++ b/pywt/_extensions/_swt.pyi @@ -1,7 +1,7 @@ from numpy.typing import NDArray -from pywt import CDataT, Wavelet +from ._pywt import CDataT, Wavelet def swt_max_level(input_len: int) -> int: ... -def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = ...) -> NDArray[CDataT]: ... -def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = ..., trim_approx: bool = ...) -> NDArray[CDataT]: ... +def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = False) -> NDArray[CDataT]: ... +def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = 0, trim_approx: bool = False) -> NDArray[CDataT]: ...