Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions pywt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,33 @@
wavelet packets signal decomposition and reconstruction module.
"""

from ._extensions._pywt import *
from ._functions import *
from ._multilevel import *
from ._multidim import *
from ._thresholding import *
from ._wavelet_packets import *
from ._dwt import *
from ._swt import *
from ._cwt import *
from ._mra import *
from ._extensions._pywt import Modes, ContinuousWavelet, families, Wavelet, wavelist, DiscreteContinuousWavelet
from ._functions import integrate_wavelet, central_frequency, scale2frequency, frequency2scale, qmf, orthogonal_filter_bank, intwave, centrfrq, scal2frq, orthfilt
from ._multilevel import wavedec, waverec, wavedec2, waverec2, wavedecn, waverecn, coeffs_to_array, array_to_coeffs, ravel_coeffs, unravel_coeffs, dwtn_max_level, wavedecn_size, wavedecn_shapes, fswavedecn, fswaverecn, FswavedecnResult
from ._multidim import dwt2, idwt2, dwtn, idwtn
from ._thresholding import threshold, threshold_firm
from ._wavelet_packets import BaseNode, Node, WaveletPacket, Node2D, WaveletPacket2D, NodeND, WaveletPacketND
from ._dwt import dwt, idwt, downcoef, upcoef, dwt_max_level, dwt_coeff_len, pad
from ._swt import swt, swt_max_level, iswt, swt2, iswt2, swtn, iswtn
from ._cwt import cwt
from ._mra import mra, mra2, mran, imra, imra2, imran
from .data import aero, ascent, camera, ecg, nino, demo_signal

from . import data

__all__ = [s for s in dir() if not s.startswith('_')]
try:
# In Python 2.x the name of the tempvar leaks out of the list
# comprehension. Delete it to not make it show up in the main namespace.
del s
except NameError:
pass
__all__ = ["ContinuousWavelet", "families", "Modes", "Wavelet", "wavelist",
"DiscreteContinuousWavelet", "integrate_wavelet",
"central_frequency", "scale2frequency", "frequency2scale", "qmf",
"orthogonal_filter_bank", "intwave", "centrfrq", "scal2frq",
"orthfilt", "wavedec", "waverec", "wavedec2", "waverec2",
"wavedecn", "waverecn", "coeffs_to_array", "array_to_coeffs",
"ravel_coeffs", "unravel_coeffs", "dwtn_max_level",
"wavedecn_size", "wavedecn_shapes", "fswavedecn", "fswaverecn",
"FswavedecnResult", "dwt2", "idwt2", "dwtn", "idwtn", "threshold",
"threshold_firm", "BaseNode", "Node", "WaveletPacket", "Node2D",
"WaveletPacket2D", "NodeND", "WaveletPacketND", "dwt", "idwt",
"downcoef", "upcoef", "dwt_max_level", "dwt_coeff_len", "pad",
"swt", "swt_max_level", "iswt", "swt2", "iswt2", "swtn", "iswtn",
"cwt", "mra", "mra2", "mran", "imra", "imra2", "imran", "aero",
"ascent", "camera", "ecg", "nino", "demo_signal"]

from pywt.version import version as __version__

Expand Down
4 changes: 4 additions & 0 deletions pywt/_extensions/_cwt.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from _pywt import ContinuousWavelet, DataT
from numpy.typing import NDArray

def cwt_psi_single(data: NDArray[DataT], wavelet: ContinuousWavelet, output_len: int) -> NDArray[DataT] | tuple[NDArray[DataT], NDArray[DataT]]: ...
12 changes: 12 additions & 0 deletions pywt/_extensions/_dwt.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from numpy.typing import NDArray

from ._pywt import CDataT, ModeInt, Wavelet

def dwt_max_level(data_len: int, filter_len: int) -> int: ...
def dwt_coeff_len(size_t: int, filter_len: int, mode: ModeInt) -> int: ...
def dwt_single(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
def dwt_axis(data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> tuple[NDArray[CDataT], NDArray[CDataT]]: ...
def idwt_single(cA: NDArray[CDataT], cD: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt) -> NDArray[CDataT]: ...
def idwt_axis(coefs_a: NDArray[CDataT], coefs_d: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, axis: int = 0) -> NDArray[CDataT]: ...
def upcoef(do_rec_a: bool, coeffs: NDArray[CDataT], wavelet: Wavelet, level: int, take: int) -> NDArray[CDataT]: ...
def downcoef(do_dec_a: bool, data: NDArray[CDataT], wavelet: Wavelet, mode: ModeInt, level: int) -> NDArray[CDataT]: ...
185 changes: 185 additions & 0 deletions pywt/_extensions/_pywt.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from enum import IntEnum
from typing import Any, Literal, Optional, TypeAlias, TypeVar

import numpy as np

_WaveletFamily = Literal[
"haar",
"db",
"sym",
"coif",
"bior",
"rbio",
"dmey",
"gaus",
"mexh",
"morl",
"cgau",
"shan",
"fbsp",
"cmor",
]

DataT = TypeVar("DataT", bound=np.float32 | np.float64)

CDataT = TypeVar(
"CDataT", bound=np.float32 | np.float64 | np.complex64 | np.complex128
)

_Kind: TypeAlias = Literal["all", "continuous", "discrete"]

_Symmetry = Literal[
"asymmetric",
"near symmetric",
"symmetric",
"anti-symmetric",
"unknown",
]

class MODE(IntEnum):
MODE_INVALID = -1
MODE_ZEROPAD = 0
MODE_SYMMETRIC = 1
MODE_CONSTANT_EDGE = 2
MODE_SMOOTH = 3
MODE_PERIODIC = 4
MODE_PERIODIZATION = 5
MODE_REFLECT = 6
MODE_ANTISYMMETRIC = 7
MODE_ANTIREFLECT = 8
MODE_MAX = 9

ModeInt = MODE | Literal[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

ModeName = Literal[
"zero",
"constant",
"symmetric",
"reflect",
"periodic",
"smooth",
"periodization",
"antisymmetric",
"antireflect",
]

Mode = MODE | ModeName

class _Modes:
zero: int
constant: int
symmetric: int
reflect: int
periodic: int
smooth: int
periodization: int
antisymmetric: int
antireflect: int

modes: list[ModeName]

def from_object(self, mode: Mode) -> int: ...

Modes = _Modes()

def wavelist(family: _WaveletFamily | None = None, kind: _Kind = "all") -> list[str]: ...
def families(short: bool = True) -> list[str]: ...

class Wavelet:
def __init__(self, name: str = "", filter_bank: Any = None) -> None: ...
def __len__(self) -> int: ...
@property
def name(self) -> str: ...
@property
def dec_lo(self) -> list[float]: ...
@property
def dec_hi(self) -> list[float]: ...
@property
def rec_lo(self) -> list[float]: ...
@property
def rec_hi(self) -> list[float]: ...
@property
def rec_len(self) -> int: ...
@property
def dec_len(self) -> int: ...
@property
def family_number(self) -> int: ...
@property
def family_name(self) -> str: ...
@property
def short_family_name(self) -> str: ...
@property
def orthogonal(self) -> bool: ...
@orthogonal.setter
def orthogonal(self, value: bool) -> None: ...
@property
def biorthogonal(self) -> bool: ...
@biorthogonal.setter
def biorthogonal(self, value: bool) -> None: ...
@property
def symmetry(self) -> _Symmetry: ...
@property
def vanishing_moments_psi(self) -> int | None: ...
@property
def vanishing_moments_phi(self) -> int | None: ...
@property
def filter_bank(
self,
) -> tuple[list[float], list[float], list[float], list[float]]: ...
def get_filters_coeffs(
self,
) -> tuple[list[float], list[float], list[float], list[float]]: ...
@property
def inverse_filter_bank(
self,
) -> tuple[list[float], list[float], list[float], list[float]]: ...
def get_reverse_filters_coeffs(
self,
) -> tuple[list[float], list[float], list[float], list[float]]: ...

class ContinuousWavelet:
def __init__(self, name: str = "", dtype: DataT = np.float64) -> None: ...
@property
def family_number(self) -> int: ...
@property
def family_name(self) -> str: ...
@property
def short_family_name(self) -> str: ...
@property
def orthogonal(self) -> bool: ...
@orthogonal.setter
def orthogonal(self, value: bool) -> None: ...
@property
def biorthogonal(self) -> bool: ...
@biorthogonal.setter
def biorthogonal(self, value: bool) -> None: ...
@property
def complex_cwt(self) -> bool: ...
@complex_cwt.setter
def complex_cwt(self, value: bool) -> None: ...
@property
def lower_bound(self) -> float | None: ...
@lower_bound.setter
def lower_bound(self, value: float) -> None: ...
@property
def upper_bound(self) -> float | None: ...
@upper_bound.setter
def upper_bound(self, value: float) -> None: ...
@property
def center_frequency(self) -> float | None: ...
@center_frequency.setter
def center_frequency(self, value: float) -> None: ...
@property
def bandwidth_frequency(self) -> float | None: ...
@bandwidth_frequency.setter
def bandwidth_frequency(self, value: float) -> None: ...
@property
def fbsp_order(self) -> int | None: ...
@fbsp_order.setter
def fbsp_order(self, value: int) -> None: ...
@property
def symmetry(self) -> _Symmetry: ...

def DiscreteContinuousWavelet(
name: str = "", filter_bank: Any = None
) -> Wavelet | ContinuousWavelet: ...
7 changes: 7 additions & 0 deletions pywt/_extensions/_swt.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from numpy.typing import NDArray

from ._pywt import CDataT, Wavelet

def swt_max_level(input_len: int) -> int: ...
def swt(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, trim_approx: bool = False) -> NDArray[CDataT]: ...
def swt_axis(data: NDArray[CDataT], wavelet: Wavelet, level: int, start_level: int, axis: int = 0, trim_approx: bool = False) -> NDArray[CDataT]: ...