From cbe3b9f3bd67a8c75ed324e1578013093df44aae Mon Sep 17 00:00:00 2001 From: Bhavya Gupta Date: Mon, 13 Apr 2026 09:34:29 -0500 Subject: [PATCH 1/5] adding heterodyne functionality to aframe --- libs/architectures/architectures/__init__.py | 1 + .../architectures/architectures/supervised.py | 38 ++++ libs/utils/utils/preprocessing.py | 183 ++++++++++++++++++ projects/export/export_heterodyne.yaml | 38 ++++ projects/export/pyproject.toml | 2 +- projects/export/uv.lock | 8 +- projects/train/configs/time_heterodyne.yaml | 148 ++++++++++++++ projects/train/pyproject.toml | 2 +- projects/train/train/cli.py | 10 + .../train/train/data/supervised/__init__.py | 5 +- .../train/data/supervised/time_domain.py | 113 ++++++++++- projects/train/uv.lock | 10 +- 12 files changed, 546 insertions(+), 12 deletions(-) create mode 100644 projects/export/export_heterodyne.yaml create mode 100644 projects/train/configs/time_heterodyne.yaml diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 7ff6db14c..350160f02 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -6,4 +6,5 @@ SupervisedSpectrogramDomainResNet, SupervisedTimeDomainResNet, SupervisedTimeSpectrogramResNet, + SupervisedHeterodyneTimeDomainResNet, ) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index 1ade07fff..2e28e564e 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -280,3 +280,41 @@ def forward(self, X, X_spec): time_domain_output = self.time_domain_resnet(X) spec_domain_output = self.spectrogram_resnet(X_spec) return time_domain_output, spec_domain_output + + +class SupervisedHeterodyneTimeDomainResNet(SupervisedArchitecture): + """ + Time Domain ResNet that processes a Heterodyned timeseries. + + Args: + num_chirp_masses (int): + Number of chirp masses used to define the input channel + dimension (in_channels = num_ifos x num_chirp_masses). + """ + def __init__( + self, + num_ifos: int, + num_chirp_masses: int, + layers: list[int], + kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + **kwargs, + ) -> None: + super().__init__() + self.time_domain_resnet = ResNet1D( + in_channels=num_ifos*num_chirp_masses, + layers=layers, + classes=1, + kernel_size=kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + def forward(self, X): + return self.time_domain_resnet(X) diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index 7bc4bb78f..2fd49ea4f 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -1,13 +1,16 @@ from collections.abc import Callable +import math import torch from torch import Tensor +from typing import Literal from ml4gw.transforms import ( SpectralDensity, Whiten, Decimator, SingleQTransform, + Heterodyne, ) from ml4gw.utils.slicing import unfold_windows @@ -589,3 +592,183 @@ def forward(self, x: Tensor) -> Tensor: # first input is timeseries and second input is spectrogram return x[1], spec + + +class HeterodyneTimeDomainPreprocessor(torch.nn.Module): + """ + Calculate the PSDs and whiten an entire batch of kernels at once + and then heterodyne the timeseries. + + Applies heterodyne transform to the desired `kernel_length` of the + strain. If `keep_last_n_seconds` is passed, returns only the final + portion of the heterodyned strain. + + Args: + kernel_length (float): Length of output kernels in seconds. + sample_rate (float): Input sampling rate in Hz. + inference_sampling_rate (float): Sampling rate of network output in Hz. + Determines the overlap between kernels. + batch_size (int): Number of kernels to extract from input. + fduration (float): Duration of the whitening filter in seconds + fftlength (float): FFT length for PSD calculation in seconds. + chirp_mass_low (float): + Lower bound of chirp mass range (in solar masses). + chirp_mass_high (float): + Upper bound of chirp mass range (in solar masses). + num_chirp_masses (int): + Number of chirp mass samples to generate. + chirp_mass_spacing (Literal["linear", "log"]): + Spacing of chirp mass grid. Use "linear" for evenly spaced + values or "log" for logarithmic spacing. + keep_last_n_seconds (float): + If > 0, only keep the last `n` seconds of the kernel_length. If 0, + keep the full kernel_length. + highpass (float, optional): Highpass frequency in Hz. Applied during + whitening. Defaults to None. + lowpass (float, optional): Lowpass frequency in Hz. Applied during + whitening. Defaults to None. + + Example: + >>> preprocessor = HeterodyneTimeDomainPreprocessor( + ... kernel_length=8, sample_rate=2048, + ... inference_sampling_rate=16, batch_size=128, + ... fduration=2, fftlength=2, chirp_mass_low=1.0, + ... chirp_mass_high=2.5, num_chirp_masses=100, + ... chirp_mass_spacing="log", keep_last_n_seconds=4.0, + ... ) + >>> x = torch.randn(2, 16384) # (channels, time) + >>> X = preprocessor(x) # shape: (batch_size, channels x num_chirp_masses, kernel_size) + """ + + def __init__( + self, + kernel_length: float, + sample_rate: float, + inference_sampling_rate: float, + batch_size: int, + fduration: float, + fftlength: float, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = 0.0, + highpass: float | None = None, + lowpass: float | None = None, + ) -> None: + super().__init__() + # Calculate stride between kernels based on inference sampling rate + self.stride_size = int(sample_rate / inference_sampling_rate) + # Convert kernel length to samples + self.kernel_size = int(kernel_length * sample_rate) + + # do length calculations in units of samples, + # then convert back to length to guard for intification + strides = (batch_size - 1) * self.stride_size + fsize = int(fduration * sample_rate) + size = strides + self.kernel_size + fsize + length = size / sample_rate + + # Initialize PSD estimator with calculated total length + self.psd_estimator = PsdEstimator( + length, + sample_rate, + fftlength=fftlength, + overlap=None, + average="median", + fast=highpass is not None, + ) + # Initialize whitening module + self.whitener = Whiten(fduration, sample_rate, highpass, lowpass) + + self.chirp_mass_grid = self._create_chirp_mass_grid( + chirp_mass_low, + chirp_mass_high, + num_chirp_masses, + chirp_mass_spacing, + ) + + self.keep_last_n_samples = int( + keep_last_n_seconds * sample_rate + ) + + self.heterodyne_transform = Heterodyne( + sample_rate=int(sample_rate), + kernel_length=int(kernel_length), + chirp_mass=self.chirp_mass_grid, + return_type="time" + ) + + def _create_chirp_mass_grid( + self, + chirp_mass_low: float, + chirp_mass_high: float, + num_chirp_masses: int, + chirp_mass_spacing: Literal["linear", "log"], + ) -> torch.Tensor: + if chirp_mass_spacing == "linear": + return torch.linspace( + chirp_mass_low, chirp_mass_high, num_chirp_masses + ) + elif chirp_mass_spacing == "log": + return torch.logspace( + math.log10(chirp_mass_low), + math.log10(chirp_mass_high), + num_chirp_masses, + ) + else: + raise ValueError( + f"Invalid chirp mass spacing: {chirp_mass_spacing}" + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Estimate PSD, whiten data, and unfold kernels. + + Args: + x (Tensor): Input data of shape (batch, channels, time) or + (channels, time). + + Returns: + Tensor: Extracted and optionally augmented kernels of shape + (batch_size, channels, kernel_size). + + If return_whitened=True, returns tuple of (kernels, whitened_data). + + Raises: + ValueError: If input is not 2 or 3 dimensional. + """ + # Determine number of channels for later reshaping + if x.ndim == 3: + num_channels = x.size(1) + elif x.ndim == 2: + num_channels = x.size(0) + else: + raise ValueError( + "Expected input to be either 2 or 3 dimensional, " + "but found shape {}".format(x.shape) + ) + + # Estimate PSD and prepare data + x, psd = self.psd_estimator(x.double()) + # Apply whitening using estimated PSD + whitened = self.whitener(x, psd) + + # unfold x and then put it into the expected shape. + # Note that if x has both signal and background + # batch elements, they will be interleaved along + # the batch dimension after unfolding + x = unfold_windows(whitened, self.kernel_size, self.stride_size) + # Reshape to (batch_size, channels, kernel_size) + x = x.reshape(-1, num_channels, self.kernel_size) + # Heterodyne the whitened timeseries + x = self.heterodyne_transform(x) + # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) to + # (batch_size, channels x num_chirp_mass, kernel_size) + _B, _C, _M, _T = x.shape + x = x.reshape(_B, _C*_M, _T) + # Returning the desired length of heterodyned strain in the time dimension + if self.keep_last_n_samples > 0: + return x[..., -self.keep_last_n_samples:] + else: + return x diff --git a/projects/export/export_heterodyne.yaml b/projects/export/export_heterodyne.yaml new file mode 100644 index 000000000..f51b11810 --- /dev/null +++ b/projects/export/export_heterodyne.yaml @@ -0,0 +1,38 @@ +# commented args (i.e. comments with colon at end +# to illustrate a value needs to be filled out) +# represent values filled out +# by the export task at run time. To build a functional +# standalone config, add these in yourself. + +logfile: +weights: +batch_file: +repository_directory: +num_ifos: 2 +kernel_length: 8.0 +inference_sampling_rate: 4 +sample_rate: 2048 +batch_size: 128 +fduration: 2 +psd_length: 64 +preprocessor: + class_path: utils.preprocessing.HeterodyneTimeDomainPreprocessor + init_args: + kernel_length: 8.0 + sample_rate: 2048 + inference_sampling_rate: 4 + batch_size: 128 + fduration: 2 + fftlength: 2 + chirp_mass_low: 1.0 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: 4.0 + highpass: 32.0 +streams_per_gpu: 12 +# num_outputs: +aframe_instances: 1 +platform: TENSORRT +clean: true +verbose: true diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index 7be57c0ae..609e4e4fe 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ - "ml4gw>=0.7.7", + "ml4gw>=0.7.13", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]>=0.2.1", diff --git a/projects/export/uv.lock b/projects/export/uv.lock index 53c7bf108..d1b58c1a0 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.7.13" }, { name = "ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, { name = "nvidia-cudnn-cu11", specifier = "==8.9.6.50" }, { name = "tensorrt", specifier = "==8.5.2.2" }, @@ -972,7 +972,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.10" +version = "0.7.13" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -981,9 +981,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/4b/d04a579fe0009f29f336af329d231dfd0da32c1f03560cbfdcd1d3081b86/ml4gw-0.7.10.tar.gz", hash = "sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/af/bab88ca4f54386735a64502a990f1bc4edb6a0353aa2b910efd9aa244919/ml4gw-0.7.13.tar.gz", hash = "sha256:4e6264fcdb9cbf5ed6a83910a231a946770687bbc6c576f8e6dc811af7102f24", size = 121893, upload-time = "2026-04-07T00:22:25.566Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/a8/7e/1499d3835e9fb9adeeb8b5b5107dcdee79c37d6a61cd9db912fe59100efa/ml4gw-0.7.13-py3-none-any.whl", hash = "sha256:53132028674bb079759f97b664ba8d1ab1e53ecb21a50660b9015ae7c518bde0", size = 132944, upload-time = "2026-04-07T00:22:24.325Z" }, ] [[package]] diff --git a/projects/train/configs/time_heterodyne.yaml b/projects/train/configs/time_heterodyne.yaml new file mode 100644 index 000000000..add60628b --- /dev/null +++ b/projects/train/configs/time_heterodyne.yaml @@ -0,0 +1,148 @@ +# commented args represent values filled out +# by train task at run time. To build a functional +# standalone config, add these in. + +# To start training from a checkpoint, uncomment the below argument +# and specify a path to the desired checkpoint +# ckpt_path: "" +model: + class_path: train.model.SupervisedAframe + init_args: + # architecture + arch: + class_path: architectures.supervised.SupervisedHeterodyneTimeDomainResNet + init_args: + layers: [3, 4, 6, 3] + kernel_size: 25 + norm_layer: + class_path: ml4gw.nn.norm.GroupNorm1DGetter + init_args: + groups: 16 + metric: + class_path: train.metrics.TimeSlideAUROC + init_args: + max_fpr: 1e-3 + pool_length: 8 + + # optimization params + weight_decay: 3e-5 + learning_rate: 4e-4 + pct_lr_ramp: 0.115 + # early stop +data: + class_path: train.data.supervised.HeterodyneTimeDomainSupervisedAframeDataset + init_args: + # loading args + background_dir: + waveforms_dir: + ifos: [H1,L1] + sample_rate: 2048 + batches_per_epoch: 100 + num_files_per_batch: 10 + chunk_size: 1000 + chunks_per_epoch: 20 + # preprocessing args + batch_size: 100 + kernel_length: 8 + psd_length: 20 + fduration: 2 + # augmentation args + waveform_prob: 0.277 + swap_prob: 0.014 + mute_prob: 0.055 + left_pad: 6.95 + right_pad: 0.05 + # heterodyne preprocessing args + chirp_mass_low: 1 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: 4 + # highpass: + # lowpass: + fftlength: 2 + snr_sampler: + class_path: ml4gw.distributions.PowerLaw + init_args: + minimum: 8 + maximum: 100 + index: -3 + # curriculum learning for snr sampler + # snr_sampler: + # class_path: train.augmentations.SnrSampler + # init_args: + # max_min_snr: 30 + # min_min_snr: 8 + # max_snr: 100 + # alpha: -3 + # decay_steps: 600 + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + training_waveform_path: + val_waveform_file: + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 6.28318 + validate_args: false + # validation args + valid_stride: 0.5 + num_valid_views: 5 + valid_livetime: 57600 + +trainer: + # by default, use a local CSV logger. + # note that you can use multiple loggers! + # Options in train task for appending + # a wandb logger for remote logging. + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: + name: + version: + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: "valid_auroc" + mode: "max" + patience: 500 + # custom model checkpoint for saving and + # tracing best model at end of traiing + # that will be used for downstream export + - class_path: train.callbacks.ModelCheckpoint + init_args: + monitor: "valid_auroc" + mode: "max" + save_top_k: 10 + save_last: true + auto_insert_metric_name: false + - class_path: train.callbacks.SaveAugmentedBatch + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "step" + # uncomment below if you want to profile + # profiler: + # class_path: lightning.pytorch.profilers.PyTorchProfiler + # dict_kwargs: + # profile_memory: true + # devices: + # strategy: set to ddp if len(devices) > 1 + #precision: 16-mixed + accelerator: auto + max_epochs: 1500 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_progress_bar: true + benchmark: true diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index 8adbbc07e..97556ca43 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.7.7", + "ml4gw>=0.7.13", "aframe", "ledger", "priors", diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index 3953dbc81..a86207ab0 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -57,6 +57,16 @@ def add_arguments_to_parser(self, parser): except Exception: pass + # TODO: This is a workaround for linking num_chirp_masses between + # the model and the architecture for in_channels. + try: + parser.link_arguments( + "data.init_args.num_chirp_masses", + "model.init_args.arch.init_args.num_chirp_masses", + ) + except Exception: + pass + parser.link_arguments( "data.init_args.sample_rate", "data.init_args.waveform_sampler.init_args.sample_rate", diff --git a/projects/train/train/data/supervised/__init__.py b/projects/train/train/data/supervised/__init__.py index a277db1f1..ff3ee5da0 100644 --- a/projects/train/train/data/supervised/__init__.py +++ b/projects/train/train/data/supervised/__init__.py @@ -5,4 +5,7 @@ ) from .multimodal import MultiModalSupervisedAframeDataset from .supervised import SupervisedAframeDataset -from .time_domain import TimeDomainSupervisedAframeDataset +from .time_domain import ( + TimeDomainSupervisedAframeDataset, + HeterodyneTimeDomainSupervisedAframeDataset, +) diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index a26806614..4f5e9c443 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -1,7 +1,9 @@ +import math import torch +from typing import Literal from train.data.supervised.supervised import SupervisedAframeDataset - +from ml4gw.transforms import Heterodyne class TimeDomainSupervisedAframeDataset(SupervisedAframeDataset): def build_val_batches(self, background, signals): @@ -20,3 +22,112 @@ def inject(self, X, waveforms=None): X, y, psds = super().inject(X, waveforms) X = self.whitener(X, psds) return X, y + + +class HeterodyneTimeDomainSupervisedAframeDataset(SupervisedAframeDataset): + """ + A derived class from BaseAframeDataset and SupervisedAframeDataset, it + applies heterodyning to strain data and returns heterodyned timeseries + for loading data to train Aframe models. If `keep_last_n_seconds` is passed, + returns only the final portion of the heterodyned strain. + + Args: + chirp_mass_low (float): + Lower bound of chirp mass range (in solar masses). + chirp_mass_high (float): + Upper bound of chirp mass range (in solar masses). + num_chirp_masses (int): + Number of chirp mass samples to generate. + chirp_mass_spacing (Literal["linear", "log"]): + Spacing of chirp mass grid. Use "linear" for evenly spaced + values or "log" for logarithmic spacing. + keep_last_n_seconds (float): + If > 0, only keep the last `n` seconds of the kernel_length. If 0, + keep the full kernel_length. + """ + + def __init__( + self, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = 0.0, + *args, + **kwargs): + super().__init__(*args, **kwargs) + + self.chirp_mass_grid = self._create_chirp_mass_grid( + chirp_mass_low, + chirp_mass_high, + num_chirp_masses, + chirp_mass_spacing, + ) + + self.keep_last_n_samples = int( + keep_last_n_seconds * self.hparams.sample_rate + ) + + def build_transforms(self, *args, **kwargs): + super().build_transforms(*args, **kwargs) + self.heterodyne_transform = Heterodyne( + sample_rate=int(self.hparams.sample_rate), + kernel_length=int(self.hparams.kernel_length), + chirp_mass=self.chirp_mass_grid, + return_type="time" + ) + + def _create_chirp_mass_grid( + self, + chirp_mass_low: float, + chirp_mass_high: float, + num_chirp_masses: int, + chirp_mass_spacing: Literal["linear", "log"], + ) -> torch.Tensor: + if chirp_mass_spacing == "linear": + return torch.linspace( + chirp_mass_low, chirp_mass_high, num_chirp_masses + ) + elif chirp_mass_spacing == "log": + return torch.logspace( + math.log10(chirp_mass_low), + math.log10(chirp_mass_high), + num_chirp_masses, + ) + else: + raise ValueError( + f"Invalid chirp mass spacing: {chirp_mass_spacing}" + ) + + def build_val_batches(self, background, signals): + X_bg, X_inj, psds = super().build_val_batches(background, signals) + X_bg = self.whitener(X_bg, psds) + X_bg = self.heterodyne_transform(X_bg) + _B_bg, _C_bg, _M_bg, _T_bg = X_bg.shape + X_bg = X_bg.view(_B_bg, _C_bg*_M_bg, _T_bg) + # whiten each view of injections + X_fg = [] + for inj in X_inj: + inj = self.whitener(inj, psds) + inj = self.heterodyne_transform(inj) + X_fg.append(inj) + X_fg = torch.stack(X_fg) + _V_fg, _B_fg, _C_fg, _M_fg, _T_fg = X_fg.shape + X_fg = X_fg.view(_V_fg, _B_fg, _C_fg*_M_fg, _T_fg) + + if self.keep_last_n_samples > 0: + return X_bg[..., -self.keep_last_n_samples:], X_fg[..., -self.keep_last_n_samples:] + else: + return X_bg, X_fg + + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) + X = self.whitener(X, psds) + X = self.heterodyne_transform(X) + _B, _C, _M, _T = X.shape + X = X.view(_B, _C*_M, _T) + + if self.keep_last_n_samples > 0: + return X[..., -self.keep_last_n_samples:], y + else: + return X, y diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 906633b6b..46a97b18a 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -459,6 +459,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2319,7 +2321,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.7.13" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2328,9 +2330,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/af/bab88ca4f54386735a64502a990f1bc4edb6a0353aa2b910efd9aa244919/ml4gw-0.7.13.tar.gz", hash = "sha256:4e6264fcdb9cbf5ed6a83910a231a946770687bbc6c576f8e6dc811af7102f24", size = 121893, upload-time = "2026-04-07T00:22:25.566Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/a8/7e/1499d3835e9fb9adeeb8b5b5107dcdee79c37d6a61cd9db912fe59100efa/ml4gw-0.7.13-py3-none-any.whl", hash = "sha256:53132028674bb079759f97b664ba8d1ab1e53ecb21a50660b9015ae7c518bde0", size = 132944, upload-time = "2026-04-07T00:22:24.325Z" }, ] [[package]] @@ -4370,7 +4372,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.7.13" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, From d515af6cfd0f4ca45c2605f159b3304f99f803bb Mon Sep 17 00:00:00 2001 From: Bhavya Gupta Date: Mon, 13 Apr 2026 09:45:26 -0500 Subject: [PATCH 2/5] make pre-commit check changes --- .../architectures/architectures/supervised.py | 6 ++- libs/utils/utils/preprocessing.py | 42 ++++++++-------- projects/train/train/cli.py | 4 +- .../train/data/supervised/time_domain.py | 50 ++++++++++--------- 4 files changed, 54 insertions(+), 48 deletions(-) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index 2e28e564e..991871830 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -287,10 +287,11 @@ class SupervisedHeterodyneTimeDomainResNet(SupervisedArchitecture): Time Domain ResNet that processes a Heterodyned timeseries. Args: - num_chirp_masses (int): + num_chirp_masses (int): Number of chirp masses used to define the input channel dimension (in_channels = num_ifos x num_chirp_masses). """ + def __init__( self, num_ifos: int, @@ -306,7 +307,7 @@ def __init__( ) -> None: super().__init__() self.time_domain_resnet = ResNet1D( - in_channels=num_ifos*num_chirp_masses, + in_channels=num_ifos * num_chirp_masses, layers=layers, classes=1, kernel_size=kernel_size, @@ -316,5 +317,6 @@ def __init__( stride_type=stride_type, norm_layer=norm_layer, ) + def forward(self, X): return self.time_domain_resnet(X) diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index 2fd49ea4f..ddd7a2ee3 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -611,16 +611,16 @@ class HeterodyneTimeDomainPreprocessor(torch.nn.Module): batch_size (int): Number of kernels to extract from input. fduration (float): Duration of the whitening filter in seconds fftlength (float): FFT length for PSD calculation in seconds. - chirp_mass_low (float): + chirp_mass_low (float): Lower bound of chirp mass range (in solar masses). - chirp_mass_high (float): + chirp_mass_high (float): Upper bound of chirp mass range (in solar masses). - num_chirp_masses (int): + num_chirp_masses (int): Number of chirp mass samples to generate. - chirp_mass_spacing (Literal["linear", "log"]): + chirp_mass_spacing (Literal["linear", "log"]): Spacing of chirp mass grid. Use "linear" for evenly spaced values or "log" for logarithmic spacing. - keep_last_n_seconds (float): + keep_last_n_seconds (float): If > 0, only keep the last `n` seconds of the kernel_length. If 0, keep the full kernel_length. highpass (float, optional): Highpass frequency in Hz. Applied during @@ -637,7 +637,8 @@ class HeterodyneTimeDomainPreprocessor(torch.nn.Module): ... chirp_mass_spacing="log", keep_last_n_seconds=4.0, ... ) >>> x = torch.randn(2, 16384) # (channels, time) - >>> X = preprocessor(x) # shape: (batch_size, channels x num_chirp_masses, kernel_size) + >>> X = preprocessor(x) + >>> # X: (batch_size, channels x num_chirp_masses, kernel_size) """ def __init__( @@ -648,10 +649,10 @@ def __init__( batch_size: int, fduration: float, fftlength: float, - chirp_mass_low: float = 1.0, - chirp_mass_high: float = 2.5, - num_chirp_masses: int = 100, - chirp_mass_spacing: Literal["linear", "log"] = "log", + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", keep_last_n_seconds: float = 0.0, highpass: float | None = None, lowpass: float | None = None, @@ -688,17 +689,15 @@ def __init__( chirp_mass_spacing, ) - self.keep_last_n_samples = int( - keep_last_n_seconds * sample_rate - ) + self.keep_last_n_samples = int(keep_last_n_seconds * sample_rate) self.heterodyne_transform = Heterodyne( - sample_rate=int(sample_rate), + sample_rate=int(sample_rate), kernel_length=int(kernel_length), chirp_mass=self.chirp_mass_grid, - return_type="time" + return_type="time", ) - + def _create_chirp_mass_grid( self, chirp_mass_low: float, @@ -763,12 +762,13 @@ def forward(self, x: Tensor) -> Tensor: x = x.reshape(-1, num_channels, self.kernel_size) # Heterodyne the whitened timeseries x = self.heterodyne_transform(x) - # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) to - # (batch_size, channels x num_chirp_mass, kernel_size) + # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) + # to (batch_size, channels x num_chirp_mass, kernel_size) _B, _C, _M, _T = x.shape - x = x.reshape(_B, _C*_M, _T) - # Returning the desired length of heterodyned strain in the time dimension + x = x.reshape(_B, _C * _M, _T) + # Returning the desired length of heterodyned strain in the + # time dimension if self.keep_last_n_samples > 0: - return x[..., -self.keep_last_n_samples:] + return x[..., -self.keep_last_n_samples :] else: return x diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index a86207ab0..e0dfd1b2e 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -61,8 +61,8 @@ def add_arguments_to_parser(self, parser): # the model and the architecture for in_channels. try: parser.link_arguments( - "data.init_args.num_chirp_masses", - "model.init_args.arch.init_args.num_chirp_masses", + "data.init_args.num_chirp_masses", + "model.init_args.arch.init_args.num_chirp_masses", ) except Exception: pass diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index 4f5e9c443..10a876605 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -5,6 +5,7 @@ from train.data.supervised.supervised import SupervisedAframeDataset from ml4gw.transforms import Heterodyne + class TimeDomainSupervisedAframeDataset(SupervisedAframeDataset): def build_val_batches(self, background, signals): X_bg, X_inj, psds = super().build_val_batches(background, signals) @@ -28,33 +29,34 @@ class HeterodyneTimeDomainSupervisedAframeDataset(SupervisedAframeDataset): """ A derived class from BaseAframeDataset and SupervisedAframeDataset, it applies heterodyning to strain data and returns heterodyned timeseries - for loading data to train Aframe models. If `keep_last_n_seconds` is passed, - returns only the final portion of the heterodyned strain. + for loading data to train Aframe models. If `keep_last_n_seconds` is + passed, returns only the final portion of the heterodyned strain. Args: - chirp_mass_low (float): + chirp_mass_low (float): Lower bound of chirp mass range (in solar masses). - chirp_mass_high (float): + chirp_mass_high (float): Upper bound of chirp mass range (in solar masses). - num_chirp_masses (int): + num_chirp_masses (int): Number of chirp mass samples to generate. - chirp_mass_spacing (Literal["linear", "log"]): + chirp_mass_spacing (Literal["linear", "log"]): Spacing of chirp mass grid. Use "linear" for evenly spaced values or "log" for logarithmic spacing. - keep_last_n_seconds (float): + keep_last_n_seconds (float): If > 0, only keep the last `n` seconds of the kernel_length. If 0, keep the full kernel_length. """ def __init__( - self, - chirp_mass_low: float = 1.0, - chirp_mass_high: float = 2.5, - num_chirp_masses: int = 100, - chirp_mass_spacing: Literal["linear", "log"] = "log", - keep_last_n_seconds: float = 0.0, - *args, - **kwargs): + self, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = 0.0, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.chirp_mass_grid = self._create_chirp_mass_grid( @@ -71,10 +73,10 @@ def __init__( def build_transforms(self, *args, **kwargs): super().build_transforms(*args, **kwargs) self.heterodyne_transform = Heterodyne( - sample_rate=int(self.hparams.sample_rate), + sample_rate=int(self.hparams.sample_rate), kernel_length=int(self.hparams.kernel_length), chirp_mass=self.chirp_mass_grid, - return_type="time" + return_type="time", ) def _create_chirp_mass_grid( @@ -98,13 +100,13 @@ def _create_chirp_mass_grid( raise ValueError( f"Invalid chirp mass spacing: {chirp_mass_spacing}" ) - + def build_val_batches(self, background, signals): X_bg, X_inj, psds = super().build_val_batches(background, signals) X_bg = self.whitener(X_bg, psds) X_bg = self.heterodyne_transform(X_bg) _B_bg, _C_bg, _M_bg, _T_bg = X_bg.shape - X_bg = X_bg.view(_B_bg, _C_bg*_M_bg, _T_bg) + X_bg = X_bg.view(_B_bg, _C_bg * _M_bg, _T_bg) # whiten each view of injections X_fg = [] for inj in X_inj: @@ -113,10 +115,12 @@ def build_val_batches(self, background, signals): X_fg.append(inj) X_fg = torch.stack(X_fg) _V_fg, _B_fg, _C_fg, _M_fg, _T_fg = X_fg.shape - X_fg = X_fg.view(_V_fg, _B_fg, _C_fg*_M_fg, _T_fg) + X_fg = X_fg.view(_V_fg, _B_fg, _C_fg * _M_fg, _T_fg) if self.keep_last_n_samples > 0: - return X_bg[..., -self.keep_last_n_samples:], X_fg[..., -self.keep_last_n_samples:] + return X_bg[..., -self.keep_last_n_samples :], X_fg[ + ..., -self.keep_last_n_samples : + ] else: return X_bg, X_fg @@ -125,9 +129,9 @@ def inject(self, X, waveforms=None): X = self.whitener(X, psds) X = self.heterodyne_transform(X) _B, _C, _M, _T = X.shape - X = X.view(_B, _C*_M, _T) + X = X.view(_B, _C * _M, _T) if self.keep_last_n_samples > 0: - return X[..., -self.keep_last_n_samples:], y + return X[..., -self.keep_last_n_samples :], y else: return X, y From cef39ceff5744dc3f6e58dc3e82b0d929f297c1e Mon Sep 17 00:00:00 2001 From: Bhavya Gupta Date: Fri, 17 Apr 2026 13:38:46 -0500 Subject: [PATCH 3/5] update changes from feedback and add preprocessing tests --- libs/utils/pyproject.toml | 2 +- libs/utils/tests/test_preprocessing.py | 55 ++++++ libs/utils/utils/augmentation.py | 135 +++++++++++++ libs/utils/utils/preprocessing.py | 183 ------------------ libs/utils/uv.lock | 8 +- projects/export/export_heterodyne.yaml | 21 +- projects/export/pyproject.toml | 2 +- projects/train/configs/time_heterodyne.yaml | 2 +- projects/train/pyproject.toml | 2 +- .../train/data/supervised/time_domain.py | 23 ++- uv.lock | 10 +- 11 files changed, 229 insertions(+), 214 deletions(-) create mode 100644 libs/utils/utils/augmentation.py diff --git a/libs/utils/pyproject.toml b/libs/utils/pyproject.toml index dcfdde2c2..70dbe8069 100644 --- a/libs/utils/pyproject.toml +++ b/libs/utils/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "h5py~=3.6", "numpy>=1.26.4,<2", "s3fs>=2024,<2025", - "ml4gw>=0.7.10", + "ml4gw>=0.8.0", "astropy>=6.0.1", ] diff --git a/libs/utils/tests/test_preprocessing.py b/libs/utils/tests/test_preprocessing.py index 1a85872f2..9174f1e3e 100644 --- a/libs/utils/tests/test_preprocessing.py +++ b/libs/utils/tests/test_preprocessing.py @@ -5,6 +5,8 @@ MultiModalPreprocessor, TimeSpectrogramPreprocessor, ) +from utils.augmentation import HeterodyneAugmentor +from ml4gw.transforms import Heterodyne import torch import pytest @@ -247,6 +249,59 @@ def simple_augmentor(x): kernels_aug = whitener_aug(x) assert torch.allclose(kernels_aug, kernels * 2) + def test_with_heterodyne_augmentor(self): + heterodyne_augmentor = HeterodyneAugmentor( + sample_rate=self.sample_rate, + kernel_length=self.kernel_length, + chirp_mass_low=1.0, + chirp_mass_high=2.5, + num_chirp_masses=10, + chirp_mass_spacing="log", + keep_last_n_seconds=None, + ) + + whitener = BatchWhitener( + kernel_length=self.kernel_length, + sample_rate=self.sample_rate, + inference_sampling_rate=self.inference_sampling_rate, + batch_size=self.batch_size, + fduration=self.fduration, + fftlength=self.fftlength, + ) + + whitener_aug = BatchWhitener( + kernel_length=self.kernel_length, + sample_rate=self.sample_rate, + inference_sampling_rate=self.inference_sampling_rate, + batch_size=self.batch_size, + fduration=self.fduration, + fftlength=self.fftlength, + augmentor=heterodyne_augmentor, + ) + + channels = 2 + total_samples = int( + ( + (self.batch_size - 1) * whitener.stride_size + + whitener.kernel_size + ) + * 2 + ) + x = torch.randn(channels, total_samples) + + heterodyne = Heterodyne( + sample_rate=self.sample_rate, + kernel_length=self.kernel_length, + chirp_mass=heterodyne_augmentor.chirp_mass_grid, + return_type="time", + ) + + kernels_heterodyned = heterodyne(whitener(x)) + _B, _C, _M, _T = kernels_heterodyned.shape + kernels = kernels_heterodyned.reshape(_B, _C * _M, _T) + kernels_aug = whitener_aug(x) + assert torch.allclose(kernels_aug, kernels) + class TestMultiModalPreprocessor: """Test suite for MultiModalPreprocessor module.""" diff --git a/libs/utils/utils/augmentation.py b/libs/utils/utils/augmentation.py new file mode 100644 index 000000000..7680b7a0e --- /dev/null +++ b/libs/utils/utils/augmentation.py @@ -0,0 +1,135 @@ +import math +import torch +from torch import Tensor +from typing import Literal + +from ml4gw.transforms import Heterodyne + + +class HeterodyneAugmentor(torch.nn.Module): + """ + Apply a heterodyne transform over a grid of chirp masses to a batch + of time-series data. + + Args: + sample_rate (float): Input sampling rate in Hz. + kernel_length (float): Length of output kernels in seconds. + chirp_mass_low (float): + Lower bound of chirp mass range (in solar masses). + chirp_mass_high (float): + Upper bound of chirp mass range (in solar masses). + num_chirp_masses (int): + Number of chirp mass samples to generate. + chirp_mass_spacing (Literal["linear", "log"]): + Spacing of chirp mass grid. Use "linear" for evenly spaced + values or "log" for logarithmic spacing. + keep_last_n_seconds (float): + If provided, only the last `n` seconds of the kernel_length are + returned. Otherwise, the full kernel_length is returned. + Shape: + Input: (batch_size, channels, time) + Output: (batch_size, channels * num_chirp_masses, time_out) + + where `time_out = time` unless `keep_last_n_seconds` is set. + + Note: + The output shape of the `BatchWhitener` is + (batch_size, channels, kernel_size). + The `HeterodyneAugmentor` changes the output shape to + (batch_size, channels * num_chirp_masses, time_out) + where time_out is determined by the `keep_last_n_seconds` parameter. + + Example: + >>> augmentor = HeterodyneAugmentor( + ... sample_rate=2048, kernel_length=8, chirp_mass_low=1.0, + ... chirp_mass_high=2.5, num_chirp_masses=100, + ... chirp_mass_spacing="log", keep_last_n_seconds=4.0, + ... ) + >>> x = torch.randn(8, 2, 16384) # (batch, channels, time) + >>> y = augmentor(x) + >>> # y: (8, 2 * 100, 8192) since we keep the last 4 seconds at 2048 Hz + """ + + def __init__( + self, + sample_rate: float, + kernel_length: float, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = None, + ): + super().__init__() + self.sample_rate = sample_rate + self.kernel_length = kernel_length + self.keep_last_n_seconds = keep_last_n_seconds + self.num_chirp_masses = num_chirp_masses + self.keep_last_n_seconds = keep_last_n_seconds + + self.chirp_mass_grid = self._create_chirp_mass_grid( + chirp_mass_low, + chirp_mass_high, + num_chirp_masses, + chirp_mass_spacing, + ) + + if self.keep_last_n_seconds is not None: + self.keep_last_n_samples = int( + self.keep_last_n_seconds * sample_rate + ) + + self.heterodyne_transform = Heterodyne( + sample_rate=sample_rate, + kernel_length=kernel_length, + chirp_mass=self.chirp_mass_grid, + return_type="time", + ) + + def _create_chirp_mass_grid( + self, + chirp_mass_low: float, + chirp_mass_high: float, + num_chirp_masses: int, + chirp_mass_spacing: Literal["linear", "log"], + ) -> torch.Tensor: + if chirp_mass_spacing == "linear": + return torch.linspace( + chirp_mass_low, chirp_mass_high, num_chirp_masses + ) + elif chirp_mass_spacing == "log": + return torch.logspace( + math.log10(chirp_mass_low), + math.log10(chirp_mass_high), + num_chirp_masses, + ) + else: + raise ValueError( + f"Invalid chirp mass spacing: {chirp_mass_spacing}" + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input data of shape (batch, channels, time). + + Returns: + Tensor: Output data of shape + (batch, channels * num_chirp_masses, time_out), + where `time_out` is either the same as input time + or determined by `keep_last_n_seconds`. + """ + _B, _C, _T = x.shape + x_heterodyned = torch.empty((_B, _C * self.num_chirp_masses, _T)) + # Heterodyne the whitened timeseries + x = self.heterodyne_transform(x) + # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) + # to (batch_size, channels x num_chirp_mass, kernel_size) + x = x.reshape(_B, _C * self.num_chirp_masses, _T) + x_heterodyned[:, :, :] = x + # Returning the desired length of heterodyned strain in the + # time dimension + if self.keep_last_n_seconds is not None: + return x_heterodyned[..., -self.keep_last_n_samples :] + else: + return x_heterodyned diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index ddd7a2ee3..7bc4bb78f 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -1,16 +1,13 @@ from collections.abc import Callable -import math import torch from torch import Tensor -from typing import Literal from ml4gw.transforms import ( SpectralDensity, Whiten, Decimator, SingleQTransform, - Heterodyne, ) from ml4gw.utils.slicing import unfold_windows @@ -592,183 +589,3 @@ def forward(self, x: Tensor) -> Tensor: # first input is timeseries and second input is spectrogram return x[1], spec - - -class HeterodyneTimeDomainPreprocessor(torch.nn.Module): - """ - Calculate the PSDs and whiten an entire batch of kernels at once - and then heterodyne the timeseries. - - Applies heterodyne transform to the desired `kernel_length` of the - strain. If `keep_last_n_seconds` is passed, returns only the final - portion of the heterodyned strain. - - Args: - kernel_length (float): Length of output kernels in seconds. - sample_rate (float): Input sampling rate in Hz. - inference_sampling_rate (float): Sampling rate of network output in Hz. - Determines the overlap between kernels. - batch_size (int): Number of kernels to extract from input. - fduration (float): Duration of the whitening filter in seconds - fftlength (float): FFT length for PSD calculation in seconds. - chirp_mass_low (float): - Lower bound of chirp mass range (in solar masses). - chirp_mass_high (float): - Upper bound of chirp mass range (in solar masses). - num_chirp_masses (int): - Number of chirp mass samples to generate. - chirp_mass_spacing (Literal["linear", "log"]): - Spacing of chirp mass grid. Use "linear" for evenly spaced - values or "log" for logarithmic spacing. - keep_last_n_seconds (float): - If > 0, only keep the last `n` seconds of the kernel_length. If 0, - keep the full kernel_length. - highpass (float, optional): Highpass frequency in Hz. Applied during - whitening. Defaults to None. - lowpass (float, optional): Lowpass frequency in Hz. Applied during - whitening. Defaults to None. - - Example: - >>> preprocessor = HeterodyneTimeDomainPreprocessor( - ... kernel_length=8, sample_rate=2048, - ... inference_sampling_rate=16, batch_size=128, - ... fduration=2, fftlength=2, chirp_mass_low=1.0, - ... chirp_mass_high=2.5, num_chirp_masses=100, - ... chirp_mass_spacing="log", keep_last_n_seconds=4.0, - ... ) - >>> x = torch.randn(2, 16384) # (channels, time) - >>> X = preprocessor(x) - >>> # X: (batch_size, channels x num_chirp_masses, kernel_size) - """ - - def __init__( - self, - kernel_length: float, - sample_rate: float, - inference_sampling_rate: float, - batch_size: int, - fduration: float, - fftlength: float, - chirp_mass_low: float = 1.0, - chirp_mass_high: float = 2.5, - num_chirp_masses: int = 100, - chirp_mass_spacing: Literal["linear", "log"] = "log", - keep_last_n_seconds: float = 0.0, - highpass: float | None = None, - lowpass: float | None = None, - ) -> None: - super().__init__() - # Calculate stride between kernels based on inference sampling rate - self.stride_size = int(sample_rate / inference_sampling_rate) - # Convert kernel length to samples - self.kernel_size = int(kernel_length * sample_rate) - - # do length calculations in units of samples, - # then convert back to length to guard for intification - strides = (batch_size - 1) * self.stride_size - fsize = int(fduration * sample_rate) - size = strides + self.kernel_size + fsize - length = size / sample_rate - - # Initialize PSD estimator with calculated total length - self.psd_estimator = PsdEstimator( - length, - sample_rate, - fftlength=fftlength, - overlap=None, - average="median", - fast=highpass is not None, - ) - # Initialize whitening module - self.whitener = Whiten(fduration, sample_rate, highpass, lowpass) - - self.chirp_mass_grid = self._create_chirp_mass_grid( - chirp_mass_low, - chirp_mass_high, - num_chirp_masses, - chirp_mass_spacing, - ) - - self.keep_last_n_samples = int(keep_last_n_seconds * sample_rate) - - self.heterodyne_transform = Heterodyne( - sample_rate=int(sample_rate), - kernel_length=int(kernel_length), - chirp_mass=self.chirp_mass_grid, - return_type="time", - ) - - def _create_chirp_mass_grid( - self, - chirp_mass_low: float, - chirp_mass_high: float, - num_chirp_masses: int, - chirp_mass_spacing: Literal["linear", "log"], - ) -> torch.Tensor: - if chirp_mass_spacing == "linear": - return torch.linspace( - chirp_mass_low, chirp_mass_high, num_chirp_masses - ) - elif chirp_mass_spacing == "log": - return torch.logspace( - math.log10(chirp_mass_low), - math.log10(chirp_mass_high), - num_chirp_masses, - ) - else: - raise ValueError( - f"Invalid chirp mass spacing: {chirp_mass_spacing}" - ) - - def forward(self, x: Tensor) -> Tensor: - """ - Estimate PSD, whiten data, and unfold kernels. - - Args: - x (Tensor): Input data of shape (batch, channels, time) or - (channels, time). - - Returns: - Tensor: Extracted and optionally augmented kernels of shape - (batch_size, channels, kernel_size). - - If return_whitened=True, returns tuple of (kernels, whitened_data). - - Raises: - ValueError: If input is not 2 or 3 dimensional. - """ - # Determine number of channels for later reshaping - if x.ndim == 3: - num_channels = x.size(1) - elif x.ndim == 2: - num_channels = x.size(0) - else: - raise ValueError( - "Expected input to be either 2 or 3 dimensional, " - "but found shape {}".format(x.shape) - ) - - # Estimate PSD and prepare data - x, psd = self.psd_estimator(x.double()) - # Apply whitening using estimated PSD - whitened = self.whitener(x, psd) - - # unfold x and then put it into the expected shape. - # Note that if x has both signal and background - # batch elements, they will be interleaved along - # the batch dimension after unfolding - x = unfold_windows(whitened, self.kernel_size, self.stride_size) - # Reshape to (batch_size, channels, kernel_size) - x = x.reshape(-1, num_channels, self.kernel_size) - # Heterodyne the whitened timeseries - x = self.heterodyne_transform(x) - # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) - # to (batch_size, channels x num_chirp_mass, kernel_size) - _B, _C, _M, _T = x.shape - x = x.reshape(_B, _C * _M, _T) - # Returning the desired length of heterodyned strain in the - # time dimension - if self.keep_last_n_samples > 0: - return x[..., -self.keep_last_n_samples :] - else: - return x diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index 5d5bc6939..1cfd3e70d 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -456,7 +456,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.10" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -465,9 +465,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/4b/d04a579fe0009f29f336af329d231dfd0da32c1f03560cbfdcd1d3081b86/ml4gw-0.7.10.tar.gz", hash = "sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1124,7 +1124,7 @@ dev = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/export/export_heterodyne.yaml b/projects/export/export_heterodyne.yaml index f51b11810..1e3414634 100644 --- a/projects/export/export_heterodyne.yaml +++ b/projects/export/export_heterodyne.yaml @@ -9,26 +9,31 @@ weights: batch_file: repository_directory: num_ifos: 2 -kernel_length: 8.0 +kernel_length: 4.0 inference_sampling_rate: 4 sample_rate: 2048 batch_size: 128 fduration: 2 psd_length: 64 preprocessor: - class_path: utils.preprocessing.HeterodyneTimeDomainPreprocessor + class_path: utils.preprocessing.BatchWhitener init_args: - kernel_length: 8.0 + kernel_length: 4.0 sample_rate: 2048 inference_sampling_rate: 4 batch_size: 128 fduration: 2 fftlength: 2 - chirp_mass_low: 1.0 - chirp_mass_high: 2.5 - num_chirp_masses: 100 - chirp_mass_spacing: "log" - keep_last_n_seconds: 4.0 + augmentor: + class_path: utils.augmentation.HeterodyneAugmentor + init_args: + sample_rate: 2048 + kernel_length: 4.0 + chirp_mass_low: 1.0 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: null highpass: 32.0 streams_per_gpu: 12 # num_outputs: diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index 609e4e4fe..b95355d18 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ - "ml4gw>=0.7.13", + "ml4gw>=0.8.0", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]>=0.2.1", diff --git a/projects/train/configs/time_heterodyne.yaml b/projects/train/configs/time_heterodyne.yaml index add60628b..d2928f781 100644 --- a/projects/train/configs/time_heterodyne.yaml +++ b/projects/train/configs/time_heterodyne.yaml @@ -57,7 +57,7 @@ data: chirp_mass_high: 2.5 num_chirp_masses: 100 chirp_mass_spacing: "log" - keep_last_n_seconds: 4 + keep_last_n_seconds: null # highpass: # lowpass: fftlength: 2 diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index 97556ca43..d81f1834a 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.7.13", + "ml4gw>=0.8.0", "aframe", "ledger", "priors", diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index 10a876605..c31bbbafe 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -43,8 +43,8 @@ class HeterodyneTimeDomainSupervisedAframeDataset(SupervisedAframeDataset): Spacing of chirp mass grid. Use "linear" for evenly spaced values or "log" for logarithmic spacing. keep_last_n_seconds (float): - If > 0, only keep the last `n` seconds of the kernel_length. If 0, - keep the full kernel_length. + If provided, only the last `n` seconds of the kernel_length are + returned. Otherwise, the full kernel_length is returned. """ def __init__( @@ -53,7 +53,7 @@ def __init__( chirp_mass_high: float = 2.5, num_chirp_masses: int = 100, chirp_mass_spacing: Literal["linear", "log"] = "log", - keep_last_n_seconds: float = 0.0, + keep_last_n_seconds: float = None, *args, **kwargs, ): @@ -66,15 +66,18 @@ def __init__( chirp_mass_spacing, ) - self.keep_last_n_samples = int( - keep_last_n_seconds * self.hparams.sample_rate - ) + self.keep_last_n_seconds = keep_last_n_seconds + + if self.keep_last_n_seconds is not None: + self.keep_last_n_samples = int( + self.keep_last_n_seconds * self.hparams.sample_rate + ) def build_transforms(self, *args, **kwargs): super().build_transforms(*args, **kwargs) self.heterodyne_transform = Heterodyne( - sample_rate=int(self.hparams.sample_rate), - kernel_length=int(self.hparams.kernel_length), + sample_rate=self.hparams.sample_rate, + kernel_length=self.hparams.kernel_length, chirp_mass=self.chirp_mass_grid, return_type="time", ) @@ -117,7 +120,7 @@ def build_val_batches(self, background, signals): _V_fg, _B_fg, _C_fg, _M_fg, _T_fg = X_fg.shape X_fg = X_fg.view(_V_fg, _B_fg, _C_fg * _M_fg, _T_fg) - if self.keep_last_n_samples > 0: + if self.keep_last_n_seconds is not None: return X_bg[..., -self.keep_last_n_samples :], X_fg[ ..., -self.keep_last_n_samples : ] @@ -131,7 +134,7 @@ def inject(self, X, waveforms=None): _B, _C, _M, _T = X.shape X = X.view(_B, _C * _M, _T) - if self.keep_last_n_samples > 0: + if self.keep_last_n_seconds is not None: return X[..., -self.keep_last_n_samples :], y else: return X, y diff --git a/uv.lock b/uv.lock index ec36dea15..b3700fae0 100644 --- a/uv.lock +++ b/uv.lock @@ -1299,7 +1299,7 @@ name = "importlib-metadata" version = "8.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp" }, + { name = "zipp", marker = "python_full_version < '3.12'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/33/08/c1395a292bb23fd03bdf572a1357c5a733d3eecbab877641ceacab23db6e/importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580", size = 55767, upload-time = "2025-01-20T22:21:30.429Z" } wheels = [ @@ -1911,7 +1911,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1920,9 +1920,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -3688,7 +3688,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From 177bf7401e4fed77b34b4ee54209dcd8da0a2c9d Mon Sep 17 00:00:00 2001 From: Bhavya Gupta Date: Fri, 17 Apr 2026 14:26:12 -0500 Subject: [PATCH 4/5] fix pyproject.toml --- projects/export/uv.lock | 10 +++++----- projects/train/uv.lock | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/projects/export/uv.lock b/projects/export/uv.lock index d1b58c1a0..a57e9deec 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.7.13" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, { name = "nvidia-cudnn-cu11", specifier = "==8.9.6.50" }, { name = "tensorrt", specifier = "==8.5.2.2" }, @@ -972,7 +972,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.13" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -981,9 +981,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/af/bab88ca4f54386735a64502a990f1bc4edb6a0353aa2b910efd9aa244919/ml4gw-0.7.13.tar.gz", hash = "sha256:4e6264fcdb9cbf5ed6a83910a231a946770687bbc6c576f8e6dc811af7102f24", size = 121893, upload-time = "2026-04-07T00:22:25.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/7e/1499d3835e9fb9adeeb8b5b5107dcdee79c37d6a61cd9db912fe59100efa/ml4gw-0.7.13-py3-none-any.whl", hash = "sha256:53132028674bb079759f97b664ba8d1ab1e53ecb21a50660b9015ae7c518bde0", size = 132944, upload-time = "2026-04-07T00:22:24.325Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1980,7 +1980,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 46a97b18a..c8d33f285 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -2321,7 +2321,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.13" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2330,9 +2330,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/af/bab88ca4f54386735a64502a990f1bc4edb6a0353aa2b910efd9aa244919/ml4gw-0.7.13.tar.gz", hash = "sha256:4e6264fcdb9cbf5ed6a83910a231a946770687bbc6c576f8e6dc811af7102f24", size = 121893, upload-time = "2026-04-07T00:22:25.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/7e/1499d3835e9fb9adeeb8b5b5107dcdee79c37d6a61cd9db912fe59100efa/ml4gw-0.7.13-py3-none-any.whl", hash = "sha256:53132028674bb079759f97b664ba8d1ab1e53ecb21a50660b9015ae7c518bde0", size = 132944, upload-time = "2026-04-07T00:22:24.325Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -4372,7 +4372,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.13" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -4511,7 +4511,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] From 2475846d4ab8faeca7443a71818e7ab26a0a50f4 Mon Sep 17 00:00:00 2001 From: Bhavya Gupta Date: Fri, 17 Apr 2026 14:34:05 -0500 Subject: [PATCH 5/5] fix uv.lock files --- projects/data/uv.lock | 10 ++++++---- projects/online/uv.lock | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/projects/data/uv.lock b/projects/data/uv.lock index d456db315..4008a1c04 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -333,6 +333,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -1494,7 +1496,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1503,9 +1505,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -2706,7 +2708,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/uv.lock b/projects/online/uv.lock index f97583ee8..4eaab91bb 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -478,6 +478,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2276,7 +2278,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2285,9 +2287,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -4432,7 +4434,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ]