diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 7ff6db14c..350160f02 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -6,4 +6,5 @@ SupervisedSpectrogramDomainResNet, SupervisedTimeDomainResNet, SupervisedTimeSpectrogramResNet, + SupervisedHeterodyneTimeDomainResNet, ) diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py index 08fc950e1..a0ae7e539 100644 --- a/libs/architectures/architectures/supervised.py +++ b/libs/architectures/architectures/supervised.py @@ -281,3 +281,43 @@ def forward(self, X, X_spec): time_domain_output = self.time_domain_resnet(X) spec_domain_output = self.spectrogram_resnet(X_spec) return time_domain_output, spec_domain_output + + +class SupervisedHeterodyneTimeDomainResNet(SupervisedArchitecture): + """ + Time Domain ResNet that processes a Heterodyned timeseries. + + Args: + num_chirp_masses (int): + Number of chirp masses used to define the input channel + dimension (in_channels = num_ifos x num_chirp_masses). + """ + + def __init__( + self, + num_ifos: int, + num_chirp_masses: int, + layers: list[int], + kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + **kwargs, + ) -> None: + super().__init__() + self.time_domain_resnet = ResNet1D( + in_channels=num_ifos * num_chirp_masses, + layers=layers, + classes=1, + kernel_size=kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + def forward(self, X): + return self.time_domain_resnet(X) diff --git a/libs/utils/pyproject.toml b/libs/utils/pyproject.toml index 9c0a76ce0..bcb4f2cc5 100644 --- a/libs/utils/pyproject.toml +++ b/libs/utils/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "h5py~=3.6", "numpy>=2", "s3fs>=2024,<2025", - "ml4gw>=0.7.10", + "ml4gw>=0.8.0", "astropy>=6.0.1", ] diff --git a/libs/utils/tests/test_preprocessing.py b/libs/utils/tests/test_preprocessing.py index ae1654dda..fb29633a6 100644 --- a/libs/utils/tests/test_preprocessing.py +++ b/libs/utils/tests/test_preprocessing.py @@ -7,6 +7,10 @@ PsdEstimator, TimeSpectrogramPreprocessor, ) +from utils.augmentation import HeterodyneAugmentor +from ml4gw.transforms import Heterodyne +import torch +import pytest class TestBackgroundSnapshotter: @@ -247,6 +251,59 @@ def simple_augmentor(x): kernels_aug = whitener_aug(x) assert torch.allclose(kernels_aug, kernels * 2) + def test_with_heterodyne_augmentor(self): + heterodyne_augmentor = HeterodyneAugmentor( + sample_rate=self.sample_rate, + kernel_length=self.kernel_length, + chirp_mass_low=1.0, + chirp_mass_high=2.5, + num_chirp_masses=10, + chirp_mass_spacing="log", + keep_last_n_seconds=None, + ) + + whitener = BatchWhitener( + kernel_length=self.kernel_length, + sample_rate=self.sample_rate, + inference_sampling_rate=self.inference_sampling_rate, + batch_size=self.batch_size, + fduration=self.fduration, + fftlength=self.fftlength, + ) + + whitener_aug = BatchWhitener( + kernel_length=self.kernel_length, + sample_rate=self.sample_rate, + inference_sampling_rate=self.inference_sampling_rate, + batch_size=self.batch_size, + fduration=self.fduration, + fftlength=self.fftlength, + augmentor=heterodyne_augmentor, + ) + + channels = 2 + total_samples = int( + ( + (self.batch_size - 1) * whitener.stride_size + + whitener.kernel_size + ) + * 2 + ) + x = torch.randn(channels, total_samples) + + heterodyne = Heterodyne( + sample_rate=self.sample_rate, + kernel_length=self.kernel_length, + chirp_mass=heterodyne_augmentor.chirp_mass_grid, + return_type="time", + ) + + kernels_heterodyned = heterodyne(whitener(x)) + _B, _C, _M, _T = kernels_heterodyned.shape + kernels = kernels_heterodyned.reshape(_B, _C * _M, _T) + kernels_aug = whitener_aug(x) + assert torch.allclose(kernels_aug, kernels) + class TestMultiModalPreprocessor: """Test suite for MultiModalPreprocessor module.""" diff --git a/libs/utils/utils/augmentation.py b/libs/utils/utils/augmentation.py new file mode 100644 index 000000000..7680b7a0e --- /dev/null +++ b/libs/utils/utils/augmentation.py @@ -0,0 +1,135 @@ +import math +import torch +from torch import Tensor +from typing import Literal + +from ml4gw.transforms import Heterodyne + + +class HeterodyneAugmentor(torch.nn.Module): + """ + Apply a heterodyne transform over a grid of chirp masses to a batch + of time-series data. + + Args: + sample_rate (float): Input sampling rate in Hz. + kernel_length (float): Length of output kernels in seconds. + chirp_mass_low (float): + Lower bound of chirp mass range (in solar masses). + chirp_mass_high (float): + Upper bound of chirp mass range (in solar masses). + num_chirp_masses (int): + Number of chirp mass samples to generate. + chirp_mass_spacing (Literal["linear", "log"]): + Spacing of chirp mass grid. Use "linear" for evenly spaced + values or "log" for logarithmic spacing. + keep_last_n_seconds (float): + If provided, only the last `n` seconds of the kernel_length are + returned. Otherwise, the full kernel_length is returned. + Shape: + Input: (batch_size, channels, time) + Output: (batch_size, channels * num_chirp_masses, time_out) + + where `time_out = time` unless `keep_last_n_seconds` is set. + + Note: + The output shape of the `BatchWhitener` is + (batch_size, channels, kernel_size). + The `HeterodyneAugmentor` changes the output shape to + (batch_size, channels * num_chirp_masses, time_out) + where time_out is determined by the `keep_last_n_seconds` parameter. + + Example: + >>> augmentor = HeterodyneAugmentor( + ... sample_rate=2048, kernel_length=8, chirp_mass_low=1.0, + ... chirp_mass_high=2.5, num_chirp_masses=100, + ... chirp_mass_spacing="log", keep_last_n_seconds=4.0, + ... ) + >>> x = torch.randn(8, 2, 16384) # (batch, channels, time) + >>> y = augmentor(x) + >>> # y: (8, 2 * 100, 8192) since we keep the last 4 seconds at 2048 Hz + """ + + def __init__( + self, + sample_rate: float, + kernel_length: float, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = None, + ): + super().__init__() + self.sample_rate = sample_rate + self.kernel_length = kernel_length + self.keep_last_n_seconds = keep_last_n_seconds + self.num_chirp_masses = num_chirp_masses + self.keep_last_n_seconds = keep_last_n_seconds + + self.chirp_mass_grid = self._create_chirp_mass_grid( + chirp_mass_low, + chirp_mass_high, + num_chirp_masses, + chirp_mass_spacing, + ) + + if self.keep_last_n_seconds is not None: + self.keep_last_n_samples = int( + self.keep_last_n_seconds * sample_rate + ) + + self.heterodyne_transform = Heterodyne( + sample_rate=sample_rate, + kernel_length=kernel_length, + chirp_mass=self.chirp_mass_grid, + return_type="time", + ) + + def _create_chirp_mass_grid( + self, + chirp_mass_low: float, + chirp_mass_high: float, + num_chirp_masses: int, + chirp_mass_spacing: Literal["linear", "log"], + ) -> torch.Tensor: + if chirp_mass_spacing == "linear": + return torch.linspace( + chirp_mass_low, chirp_mass_high, num_chirp_masses + ) + elif chirp_mass_spacing == "log": + return torch.logspace( + math.log10(chirp_mass_low), + math.log10(chirp_mass_high), + num_chirp_masses, + ) + else: + raise ValueError( + f"Invalid chirp mass spacing: {chirp_mass_spacing}" + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input data of shape (batch, channels, time). + + Returns: + Tensor: Output data of shape + (batch, channels * num_chirp_masses, time_out), + where `time_out` is either the same as input time + or determined by `keep_last_n_seconds`. + """ + _B, _C, _T = x.shape + x_heterodyned = torch.empty((_B, _C * self.num_chirp_masses, _T)) + # Heterodyne the whitened timeseries + x = self.heterodyne_transform(x) + # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size) + # to (batch_size, channels x num_chirp_mass, kernel_size) + x = x.reshape(_B, _C * self.num_chirp_masses, _T) + x_heterodyned[:, :, :] = x + # Returning the desired length of heterodyned strain in the + # time dimension + if self.keep_last_n_seconds is not None: + return x_heterodyned[..., -self.keep_last_n_samples :] + else: + return x_heterodyned diff --git a/libs/utils/uv.lock b/libs/utils/uv.lock index 43a5e8406..a6882d09b 100644 --- a/libs/utils/uv.lock +++ b/libs/utils/uv.lock @@ -418,7 +418,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.10" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -427,9 +427,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/4b/d04a579fe0009f29f336af329d231dfd0da32c1f03560cbfdcd1d3081b86/ml4gw-0.7.10.tar.gz", hash = "sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1095,7 +1095,7 @@ dev = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index b1e64cc12..ca4d5a28c 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -1448,7 +1448,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1457,9 +1457,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -2753,7 +2753,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/export/export_heterodyne.yaml b/projects/export/export_heterodyne.yaml new file mode 100644 index 000000000..1e3414634 --- /dev/null +++ b/projects/export/export_heterodyne.yaml @@ -0,0 +1,43 @@ +# commented args (i.e. comments with colon at end +# to illustrate a value needs to be filled out) +# represent values filled out +# by the export task at run time. To build a functional +# standalone config, add these in yourself. + +logfile: +weights: +batch_file: +repository_directory: +num_ifos: 2 +kernel_length: 4.0 +inference_sampling_rate: 4 +sample_rate: 2048 +batch_size: 128 +fduration: 2 +psd_length: 64 +preprocessor: + class_path: utils.preprocessing.BatchWhitener + init_args: + kernel_length: 4.0 + sample_rate: 2048 + inference_sampling_rate: 4 + batch_size: 128 + fduration: 2 + fftlength: 2 + augmentor: + class_path: utils.augmentation.HeterodyneAugmentor + init_args: + sample_rate: 2048 + kernel_length: 4.0 + chirp_mass_low: 1.0 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: null + highpass: 32.0 +streams_per_gpu: 12 +# num_outputs: +aframe_instances: 1 +platform: TENSORRT +clean: true +verbose: true diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index 408c4266c..48bfe77d9 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.11,<3.14" license = "MIT" dependencies = [ - "ml4gw>=0.7.7", + "ml4gw>=0.8.0", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]", diff --git a/projects/export/uv.lock b/projects/export/uv.lock index d408303d0..db182e34f 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -357,7 +357,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "ml4gw-hermes", extras = ["torch"], git = "https://github.com/ML4GW/hermes?branch=dev" }, { name = "nvidia-cudnn-cu12" }, { name = "tensorrt", specifier = ">=10.0" }, @@ -732,7 +732,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.10" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -741,9 +741,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/4b/d04a579fe0009f29f336af329d231dfd0da32c1f03560cbfdcd1d3081b86/ml4gw-0.7.10.tar.gz", hash = "sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1692,7 +1692,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/online/main.py b/projects/online/online/main.py index a21a14d75..e8fe56ba4 100644 --- a/projects/online/online/main.py +++ b/projects/online/online/main.py @@ -1,6 +1,7 @@ import atexit import logging import signal +from math import floor import traceback from collections.abc import Iterable from pathlib import Path @@ -55,6 +56,7 @@ def load_model(model: Architecture, weights: Path): arch_weights = { k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items() + if k.startswith("model.") } model.load_state_dict(arch_weights) model.eval() @@ -64,8 +66,9 @@ def load_model(model: Architecture, weights: Path): def load_amplfi(model: FlowArchitecture, weights: Path, num_params: int): model, checkpoint = load_model(model, weights) scaler_weights = { - k.removeprefix("scalar."): v + k.removeprefix("scaler."): v for k, v in checkpoint["state_dict"].items() + if k.startswith("scaler.") } scaler = ChannelWiseScaler(num_params) scaler.load_state_dict(scaler_weights) @@ -242,15 +245,15 @@ def search( # but don't search for events if X is not None: logging.debug( - f"Frame {t0} is not analysis ready. Using dummy values " - "for inference and ignoring any triggers" + f"Frame {floor(t0)} is not analysis ready. Using dummy " + "values for inference and ignoring any triggers" ) pass # or if it's because frames were dropped within the stream # in which case we should reset our states else: logging.warning( - f"Missing frame files after timestep {t0}, " + f"Missing frame files after timestep {floor(t0)}, " "resetting states" ) @@ -264,7 +267,7 @@ def search( elif not in_spec: # the frame is analysis ready, but previous frames # weren't, so reset our running states - logging.info(f"Frame {t0} is ready again, resetting states") + logging.info(f"Frame {floor(t0)} is ready again, resetting states") state = snapshotter.reset() input_buffer.reset() output_buffer.reset() diff --git a/projects/online/online/monitor/main.py b/projects/online/online/monitor/main.py index 914e0e660..22104af6c 100644 --- a/projects/online/online/monitor/main.py +++ b/projects/online/online/monitor/main.py @@ -48,7 +48,7 @@ def main( detected_events = [ event for event in detected_event_dir.iterdir() - if float(event.name.split("_")[1]) + if float(event.name.split("_")[1]) > summary_page.start_time ] # The event page will be created/updated only if the event directory diff --git a/projects/online/online/subprocesses/events.py b/projects/online/online/subprocesses/events.py index c8a8001da..bec57a9a3 100644 --- a/projects/online/online/subprocesses/events.py +++ b/projects/online/online/subprocesses/events.py @@ -3,6 +3,9 @@ from queue import Queue from typing import TYPE_CHECKING +import certifi +from ligo.gracedb.kafka import GraceDbKafkaProducer + from .utils import subprocess_wrapper if TYPE_CHECKING: @@ -23,6 +26,13 @@ def event_creation_subprocess( # override with subprocesses logger gdb.logger = logger + + # Need to create the producer within the subprocess that uses it + gdb.kafka_producer = GraceDbKafkaProducer( + bootstrap_servers="kafka-dev.ligo.org:9092", + service_url=gdb.server.service_url, + ca_cert_path=certifi.where(), + ) while True: event = event_queue.get() logger.debug("Putting event in pastro queue") diff --git a/projects/online/online/utils/gdb.py b/projects/online/online/utils/gdb.py index 28eeb3bdc..44bc1c5f1 100644 --- a/projects/online/online/utils/gdb.py +++ b/projects/online/online/utils/gdb.py @@ -81,7 +81,12 @@ def __init__( **kwargs, service_url=server.service_url, use_auth="scitoken", + api_version="v2", ) + + # The kafka producer will be set in the subprocess that uses it + self.kafka_producer = None + self.server = server self.write_dir = write_dir if logger is None: @@ -99,6 +104,8 @@ def submit(self, event: Event): pipeline="aframe", filename=str(filename), search="AllSky", + kafka=self.kafka_producer, + http_fallback=True, ) self.logger.debug("Event created") diff --git a/projects/online/pyproject.toml b/projects/online/pyproject.toml index 720b2f669..2188dba2e 100644 --- a/projects/online/pyproject.toml +++ b/projects/online/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "scipy>=1.9", "matplotlib>=3.9.4", "ligo-skymap>=2.4.0,<3", - "ligo-gracedb>=2.14.1", + "ligo-gracedb[kafka]>=2.15.4", "tables>=3.9", "gwpy>=3.0.12", "jsonargparse>=4.18.0", diff --git a/projects/online/uv.lock b/projects/online/uv.lock index 2bdb1613e..fcd0d93f0 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -713,6 +713,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180, upload-time = "2024-03-12T16:53:39.226Z" }, ] +[[package]] +name = "confluent-kafka" +version = "2.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/38/f5855cae6d328fa66e689d068709f91cbbd4d72e7e03959998bd43ac6b26/confluent_kafka-2.13.2.tar.gz", hash = "sha256:619d10d1d77c9821ba913b3e42a33ade7f889f3573c7f3c17b57c3056e3310f5", size = 276068, upload-time = "2026-03-02T12:53:31.457Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/a7/7dfee75b246f5e5f0832a27e365cd9e8050591c5f4301714672bea2375ce/confluent_kafka-2.13.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e85dc2aaf08dcac610d20b24d252a24891440cf33c09396c957781b8a1f24015", size = 3629660, upload-time = "2026-03-02T12:52:42.595Z" }, + { url = "https://files.pythonhosted.org/packages/93/77/bc6bca93f455e91b41b196bb208b9cbfc517442a65abae2391f1af64cd2f/confluent_kafka-2.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:eb1b218beeaae36b3fc94927e30df5f6d662858e766eada2369b290df0b1bff0", size = 3190013, upload-time = "2026-03-02T12:52:44.193Z" }, + { url = "https://files.pythonhosted.org/packages/2e/ce/2ee04c1b2707b6dd7177eab40fced00b474671d2303e5096d96f3bf7e231/confluent_kafka-2.13.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:69b4286296504b89c0c3cd1e531d12053c633e56d2c5b477ff9000524fe24eb5", size = 3719524, upload-time = "2026-03-02T12:52:45.488Z" }, + { url = "https://files.pythonhosted.org/packages/7c/93/5c40e2f7eae52774db6b14060254d001ae8c4ef8d4385bf2f13294dd929f/confluent_kafka-2.13.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e47be4267d3feda5bf1c066f140f61e61ea28bd6ecdb60c2a52ec1a91b8903e7", size = 3976453, upload-time = "2026-03-02T12:52:47.787Z" }, + { url = "https://files.pythonhosted.org/packages/46/85/a3d25b67470abbd4835fca714a419465323ba79dceefcdda65dfa4415c80/confluent_kafka-2.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:84dd6e7f456910aa4d4763d86efa0dded7167fdf1251b51d808dec0f124f5e13", size = 4097308, upload-time = "2026-03-02T12:52:49.083Z" }, + { url = "https://files.pythonhosted.org/packages/d9/d3/a845c6993a728b8b6bdce9b500d15c3ec3663cd95d2bbf9c1b8cfd519b17/confluent_kafka-2.13.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e259c0d2b9a7e16211b45404f62869502246ac3d03e35a1f80720fd09d262457", size = 3635348, upload-time = "2026-03-02T12:52:50.927Z" }, + { url = "https://files.pythonhosted.org/packages/ab/22/1cb998f7b3ee613d5b29f4b98e4a7539776eb0819b89d7c3cdd19a685692/confluent_kafka-2.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:77ea4ceccdbb67498787b7c02cc329c32417bb730e9383f46c74eb9c5851763c", size = 3194667, upload-time = "2026-03-02T12:52:53.468Z" }, + { url = "https://files.pythonhosted.org/packages/11/38/8a1b12321068e8ae126e62600a55d7a1872f969e1de5ec7f602e0dba8394/confluent_kafka-2.13.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a64a8967734f865f54b766553d63a40f17081cd3d2c6cfe6d3217aa7494d88fb", size = 3724453, upload-time = "2026-03-02T12:52:55.187Z" }, + { url = "https://files.pythonhosted.org/packages/5c/06/3effa66c59a69e17cc48c69ae2533699f4321fac1b46741f2e4b1aefb1e7/confluent_kafka-2.13.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e4cb7d112463ec15a01a3f0e0d20392cda6e46156a6439fcaaad2267696f5cde", size = 3980919, upload-time = "2026-03-02T12:52:56.852Z" }, + { url = "https://files.pythonhosted.org/packages/98/22/f76a8b85fad652b4d5c0a0259c8f7bb66393d2d9f277631c754c9ebe5092/confluent_kafka-2.13.2-cp312-cp312-win_amd64.whl", hash = "sha256:44496777ff0104421b8f4bb269728e8a5e772c09f34ae813bc47110e0172ebe0", size = 4097817, upload-time = "2026-03-02T12:52:58.831Z" }, +] + [[package]] name = "contourpy" version = "1.3.2" @@ -2225,16 +2243,22 @@ wheels = [ [[package]] name = "ligo-gracedb" -version = "2.14.2" +version = "2.15.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "igwn-auth-utils" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ad/63/d42ee081193d7bbd0f449d3eb46566669207d35117e68d74518cf5e41c27/ligo_gracedb-2.14.2.tar.gz", hash = "sha256:cd93fd50c7a999f88ef969826e678bdf99e242090b6d439d230e5f4f8043f39b", size = 2396894, upload-time = "2025-04-11T15:44:18.597Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/7a/6fa8ee52bb99b7d510a51686e7bd935be263031464eb38a710d458893916/ligo_gracedb-2.15.4.tar.gz", hash = "sha256:06c246f76bdded0bbb4c0471e77a69b0da95cdc97f2c6736a3966ad77667a8bb", size = 2444113, upload-time = "2026-03-25T20:23:00.73Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d6/ad/e00cc582adb7d0f41363e94a894624a357bfeded7aa143ed9c1820ef1d2e/ligo_gracedb-2.14.2-py3-none-any.whl", hash = "sha256:6463af190ac27c136df0a074c4ea1261b407fc568a2a71a4fa04686797e725f5", size = 2454508, upload-time = "2025-04-11T15:44:17.177Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a6/075ee856218dcd9597fbfb829a60b7daf55ed877f2c0493735d7d9a8ea91/ligo_gracedb-2.15.4-py3-none-any.whl", hash = "sha256:3ac1b4dcac86d7f533b931a4ab140f37a09497eea2fbb314c40c34aac57211c2", size = 2516860, upload-time = "2026-03-25T20:22:59.234Z" }, +] + +[package.optional-dependencies] +kafka = [ + { name = "certifi" }, + { name = "confluent-kafka" }, ] [[package]] @@ -2505,7 +2529,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.13" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2514,9 +2538,26 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/af/bab88ca4f54386735a64502a990f1bc4edb6a0353aa2b910efd9aa244919/ml4gw-0.7.13.tar.gz", hash = "sha256:4e6264fcdb9cbf5ed6a83910a231a946770687bbc6c576f8e6dc811af7102f24", size = 121893, upload-time = "2026-04-07T00:22:25.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/7e/1499d3835e9fb9adeeb8b5b5107dcdee79c37d6a61cd9db912fe59100efa/ml4gw-0.7.13-py3-none-any.whl", hash = "sha256:53132028674bb079759f97b664ba8d1ab1e53ecb21a50660b9015ae7c518bde0", size = 132944, upload-time = "2026-04-07T00:22:24.325Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, +] + +[[package]] +name = "mldatafind" +version = "0.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3" }, + { name = "cloudpathlib" }, + { name = "gwpy" }, + { name = "htgettoken" }, + { name = "law" }, + { name = "luigi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e6/fe/8600fae835cad7c850da58c8b8765c8cefac25dfbf5e8cf87476ad6218cc/mldatafind-0.1.8.tar.gz", hash = "sha256:7e33d5da1ebeff5908c6551c5c6ae0ac02c6543f908657da3be8b4d2020ea8e8", size = 119699, upload-time = "2025-06-21T13:36:12.607Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/8f/cb9af0a0c61148afb3ac403c996da28d0ee598366ac1abcf0d6c36283fca/mldatafind-0.1.8-py3-none-any.whl", hash = "sha256:04c641f046caeb8a6e14ae9eb79d007c85aa97753fbf2b1cb594ecfcd28a4251", size = 15078, upload-time = "2025-06-21T13:36:11.37Z" }, ] [[package]] @@ -3077,7 +3118,7 @@ dependencies = [ { name = "gwpy" }, { name = "jsonargparse" }, { name = "ledger" }, - { name = "ligo-gracedb" }, + { name = "ligo-gracedb", extra = ["kafka"] }, { name = "ligo-skymap" }, { name = "matplotlib" }, { name = "ml4gw" }, @@ -3106,7 +3147,7 @@ requires-dist = [ { name = "gwpy", specifier = ">=3.0.12" }, { name = "jsonargparse", specifier = ">=4.18.0" }, { name = "ledger", editable = "../../libs/ledger" }, - { name = "ligo-gracedb", specifier = ">=2.14.1" }, + { name = "ligo-gracedb", extras = ["kafka"], specifier = ">=2.15.4" }, { name = "ligo-skymap", specifier = ">=2.4.0,<3" }, { name = "matplotlib", specifier = ">=3.9.4" }, { name = "ml4gw", specifier = ">=0.7.4" }, @@ -4964,7 +5005,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/configs/time_heterodyne.yaml b/projects/train/configs/time_heterodyne.yaml new file mode 100644 index 000000000..d2928f781 --- /dev/null +++ b/projects/train/configs/time_heterodyne.yaml @@ -0,0 +1,148 @@ +# commented args represent values filled out +# by train task at run time. To build a functional +# standalone config, add these in. + +# To start training from a checkpoint, uncomment the below argument +# and specify a path to the desired checkpoint +# ckpt_path: "" +model: + class_path: train.model.SupervisedAframe + init_args: + # architecture + arch: + class_path: architectures.supervised.SupervisedHeterodyneTimeDomainResNet + init_args: + layers: [3, 4, 6, 3] + kernel_size: 25 + norm_layer: + class_path: ml4gw.nn.norm.GroupNorm1DGetter + init_args: + groups: 16 + metric: + class_path: train.metrics.TimeSlideAUROC + init_args: + max_fpr: 1e-3 + pool_length: 8 + + # optimization params + weight_decay: 3e-5 + learning_rate: 4e-4 + pct_lr_ramp: 0.115 + # early stop +data: + class_path: train.data.supervised.HeterodyneTimeDomainSupervisedAframeDataset + init_args: + # loading args + background_dir: + waveforms_dir: + ifos: [H1,L1] + sample_rate: 2048 + batches_per_epoch: 100 + num_files_per_batch: 10 + chunk_size: 1000 + chunks_per_epoch: 20 + # preprocessing args + batch_size: 100 + kernel_length: 8 + psd_length: 20 + fduration: 2 + # augmentation args + waveform_prob: 0.277 + swap_prob: 0.014 + mute_prob: 0.055 + left_pad: 6.95 + right_pad: 0.05 + # heterodyne preprocessing args + chirp_mass_low: 1 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: null + # highpass: + # lowpass: + fftlength: 2 + snr_sampler: + class_path: ml4gw.distributions.PowerLaw + init_args: + minimum: 8 + maximum: 100 + index: -3 + # curriculum learning for snr sampler + # snr_sampler: + # class_path: train.augmentations.SnrSampler + # init_args: + # max_min_snr: 30 + # min_min_snr: 8 + # max_snr: 100 + # alpha: -3 + # decay_steps: 600 + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + training_waveform_path: + val_waveform_file: + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 6.28318 + validate_args: false + # validation args + valid_stride: 0.5 + num_valid_views: 5 + valid_livetime: 57600 + +trainer: + # by default, use a local CSV logger. + # note that you can use multiple loggers! + # Options in train task for appending + # a wandb logger for remote logging. + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: + name: + version: + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: "valid_auroc" + mode: "max" + patience: 500 + # custom model checkpoint for saving and + # tracing best model at end of traiing + # that will be used for downstream export + - class_path: train.callbacks.ModelCheckpoint + init_args: + monitor: "valid_auroc" + mode: "max" + save_top_k: 10 + save_last: true + auto_insert_metric_name: false + - class_path: train.callbacks.SaveAugmentedBatch + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "step" + # uncomment below if you want to profile + # profiler: + # class_path: lightning.pytorch.profilers.PyTorchProfiler + # dict_kwargs: + # profile_memory: true + # devices: + # strategy: set to ddp if len(devices) > 1 + #precision: 16-mixed + accelerator: auto + max_epochs: 1500 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_progress_bar: true + benchmark: true diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index 133569ab3..a5c05b4e8 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=2", "utils", - "ml4gw>=0.7.7", + "ml4gw>=0.8.0", "aframe", "ledger", "priors", diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index 3953dbc81..e0dfd1b2e 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -57,6 +57,16 @@ def add_arguments_to_parser(self, parser): except Exception: pass + # TODO: This is a workaround for linking num_chirp_masses between + # the model and the architecture for in_channels. + try: + parser.link_arguments( + "data.init_args.num_chirp_masses", + "model.init_args.arch.init_args.num_chirp_masses", + ) + except Exception: + pass + parser.link_arguments( "data.init_args.sample_rate", "data.init_args.waveform_sampler.init_args.sample_rate", diff --git a/projects/train/train/data/supervised/__init__.py b/projects/train/train/data/supervised/__init__.py index 41b42c9b9..947d9122e 100644 --- a/projects/train/train/data/supervised/__init__.py +++ b/projects/train/train/data/supervised/__init__.py @@ -6,3 +6,9 @@ SpectrogramDomainSupervisedAframeDataset, TimeSpectrogramDomainSupervisedAframeDataset, ) +from .multimodal import MultiModalSupervisedAframeDataset +from .supervised import SupervisedAframeDataset +from .time_domain import ( + TimeDomainSupervisedAframeDataset, + HeterodyneTimeDomainSupervisedAframeDataset, +) diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index a26806614..c31bbbafe 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -1,6 +1,9 @@ +import math import torch +from typing import Literal from train.data.supervised.supervised import SupervisedAframeDataset +from ml4gw.transforms import Heterodyne class TimeDomainSupervisedAframeDataset(SupervisedAframeDataset): @@ -20,3 +23,118 @@ def inject(self, X, waveforms=None): X, y, psds = super().inject(X, waveforms) X = self.whitener(X, psds) return X, y + + +class HeterodyneTimeDomainSupervisedAframeDataset(SupervisedAframeDataset): + """ + A derived class from BaseAframeDataset and SupervisedAframeDataset, it + applies heterodyning to strain data and returns heterodyned timeseries + for loading data to train Aframe models. If `keep_last_n_seconds` is + passed, returns only the final portion of the heterodyned strain. + + Args: + chirp_mass_low (float): + Lower bound of chirp mass range (in solar masses). + chirp_mass_high (float): + Upper bound of chirp mass range (in solar masses). + num_chirp_masses (int): + Number of chirp mass samples to generate. + chirp_mass_spacing (Literal["linear", "log"]): + Spacing of chirp mass grid. Use "linear" for evenly spaced + values or "log" for logarithmic spacing. + keep_last_n_seconds (float): + If provided, only the last `n` seconds of the kernel_length are + returned. Otherwise, the full kernel_length is returned. + """ + + def __init__( + self, + chirp_mass_low: float = 1.0, + chirp_mass_high: float = 2.5, + num_chirp_masses: int = 100, + chirp_mass_spacing: Literal["linear", "log"] = "log", + keep_last_n_seconds: float = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.chirp_mass_grid = self._create_chirp_mass_grid( + chirp_mass_low, + chirp_mass_high, + num_chirp_masses, + chirp_mass_spacing, + ) + + self.keep_last_n_seconds = keep_last_n_seconds + + if self.keep_last_n_seconds is not None: + self.keep_last_n_samples = int( + self.keep_last_n_seconds * self.hparams.sample_rate + ) + + def build_transforms(self, *args, **kwargs): + super().build_transforms(*args, **kwargs) + self.heterodyne_transform = Heterodyne( + sample_rate=self.hparams.sample_rate, + kernel_length=self.hparams.kernel_length, + chirp_mass=self.chirp_mass_grid, + return_type="time", + ) + + def _create_chirp_mass_grid( + self, + chirp_mass_low: float, + chirp_mass_high: float, + num_chirp_masses: int, + chirp_mass_spacing: Literal["linear", "log"], + ) -> torch.Tensor: + if chirp_mass_spacing == "linear": + return torch.linspace( + chirp_mass_low, chirp_mass_high, num_chirp_masses + ) + elif chirp_mass_spacing == "log": + return torch.logspace( + math.log10(chirp_mass_low), + math.log10(chirp_mass_high), + num_chirp_masses, + ) + else: + raise ValueError( + f"Invalid chirp mass spacing: {chirp_mass_spacing}" + ) + + def build_val_batches(self, background, signals): + X_bg, X_inj, psds = super().build_val_batches(background, signals) + X_bg = self.whitener(X_bg, psds) + X_bg = self.heterodyne_transform(X_bg) + _B_bg, _C_bg, _M_bg, _T_bg = X_bg.shape + X_bg = X_bg.view(_B_bg, _C_bg * _M_bg, _T_bg) + # whiten each view of injections + X_fg = [] + for inj in X_inj: + inj = self.whitener(inj, psds) + inj = self.heterodyne_transform(inj) + X_fg.append(inj) + X_fg = torch.stack(X_fg) + _V_fg, _B_fg, _C_fg, _M_fg, _T_fg = X_fg.shape + X_fg = X_fg.view(_V_fg, _B_fg, _C_fg * _M_fg, _T_fg) + + if self.keep_last_n_seconds is not None: + return X_bg[..., -self.keep_last_n_samples :], X_fg[ + ..., -self.keep_last_n_samples : + ] + else: + return X_bg, X_fg + + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) + X = self.whitener(X, psds) + X = self.heterodyne_transform(X) + _B, _C, _M, _T = X.shape + X = X.view(_B, _C * _M, _T) + + if self.keep_last_n_seconds is not None: + return X[..., -self.keep_last_n_samples :], y + else: + return X, y diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 09499776a..ec3713e33 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -2100,7 +2100,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2109,9 +2109,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -4101,7 +4101,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = ">=2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -4237,7 +4237,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/uv.lock b/uv.lock index dd41a31ba..729278c81 100644 --- a/uv.lock +++ b/uv.lock @@ -1723,7 +1723,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1732,9 +1732,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -3487,7 +3487,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ]