class SupervisedHeterodyneTimeDomainResNet(SupervisedArchitecture):
    """
    Time-domain ResNet classifier operating on a heterodyned timeseries.

    The wrapped ``ResNet1D`` is built with
    ``in_channels = num_ifos * num_chirp_masses`` and a single output
    class (``classes=1``).

    Args:
        num_ifos (int):
            Number of interferometers contributing to the input.
        num_chirp_masses (int):
            Number of chirp masses used to define the input channel
            dimension (in_channels = num_ifos x num_chirp_masses).
        layers (list[int]):
            Forwarded unchanged to ``ResNet1D`` as ``layers``.
        kernel_size (int):
            Forwarded unchanged to ``ResNet1D``.
        zero_init_residual (bool):
            Forwarded unchanged to ``ResNet1D``.
        groups (int):
            Forwarded unchanged to ``ResNet1D``.
        width_per_group (int):
            Forwarded unchanged to ``ResNet1D``.
        stride_type (Optional[list[Literal["stride", "dilation"]]]):
            Forwarded unchanged to ``ResNet1D``.
        norm_layer (Optional[NormLayer]):
            Forwarded unchanged to ``ResNet1D``.
        **kwargs:
            Accepted but not used by this class; nothing in it is
            forwarded to ``ResNet1D``.
    """

    def __init__(
        self,
        num_ifos: int,
        num_chirp_masses: int,
        layers: list[int],
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # One input channel per (interferometer, chirp mass) pair.
        resnet_config = dict(
            in_channels=num_ifos * num_chirp_masses,
            layers=layers,
            classes=1,
            kernel_size=kernel_size,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            stride_type=stride_type,
            norm_layer=norm_layer,
        )
        self.time_domain_resnet = ResNet1D(**resnet_config)

    def forward(self, X):
        """
        Run the heterodyned timeseries through the ResNet.

        NOTE(review): X is presumably shaped
        (batch, num_ifos * num_chirp_masses, time) -- confirm against
        ``ResNet1D``.
        """
        network_output = self.time_domain_resnet(X)
        return network_output
class HeterodyneAugmentor(torch.nn.Module):
    """
    Apply a heterodyne transform over a grid of chirp masses to a batch
    of time-series data.

    Args:
        sample_rate (float): Input sampling rate in Hz.
        kernel_length (float): Length of output kernels in seconds.
        chirp_mass_low (float):
            Lower bound of chirp mass range (in solar masses).
        chirp_mass_high (float):
            Upper bound of chirp mass range (in solar masses).
        num_chirp_masses (int):
            Number of chirp mass samples to generate.
        chirp_mass_spacing (Literal["linear", "log"]):
            Spacing of chirp mass grid. Use "linear" for evenly spaced
            values or "log" for logarithmic spacing.
        keep_last_n_seconds (float):
            If provided, only the last `n` seconds of the kernel_length are
            returned. Otherwise (``None``), the full kernel_length is
            returned.

    Raises:
        ValueError:
            If `chirp_mass_spacing` is not "linear" or "log", or if
            `keep_last_n_seconds` corresponds to fewer than one sample.

    Shape:
        Input: (batch_size, channels, time)
        Output: (batch_size, channels * num_chirp_masses, time_out)

        where `time_out = time` unless `keep_last_n_seconds` is set.

    Note:
        The output shape of the `BatchWhitener` is
        (batch_size, channels, kernel_size).
        The `HeterodyneAugmentor` changes the output shape to
        (batch_size, channels * num_chirp_masses, time_out)
        where time_out is determined by the `keep_last_n_seconds` parameter.

    Example:
        >>> augmentor = HeterodyneAugmentor(
        ...     sample_rate=2048, kernel_length=8, chirp_mass_low=1.0,
        ...     chirp_mass_high=2.5, num_chirp_masses=100,
        ...     chirp_mass_spacing="log", keep_last_n_seconds=4.0,
        ... )
        >>> x = torch.randn(8, 2, 16384)  # (batch, channels, time)
        >>> y = augmentor(x)
        >>> # y: (8, 2 * 100, 8192) since we keep the last 4 seconds at 2048 Hz
    """

    def __init__(
        self,
        sample_rate: float,
        kernel_length: float,
        chirp_mass_low: float = 1.0,
        chirp_mass_high: float = 2.5,
        num_chirp_masses: int = 100,
        chirp_mass_spacing: Literal["linear", "log"] = "log",
        keep_last_n_seconds: float = None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.kernel_length = kernel_length
        self.num_chirp_masses = num_chirp_masses
        self.keep_last_n_seconds = keep_last_n_seconds

        self.chirp_mass_grid = self._create_chirp_mass_grid(
            chirp_mass_low,
            chirp_mass_high,
            num_chirp_masses,
            chirp_mass_spacing,
        )

        if self.keep_last_n_seconds is not None:
            self.keep_last_n_samples = int(
                self.keep_last_n_seconds * sample_rate
            )
            # A crop of zero samples would be written `x[..., -0:]`, which
            # silently returns the *whole* kernel; reject it up front.
            if self.keep_last_n_samples < 1:
                raise ValueError(
                    "keep_last_n_seconds must correspond to at least one "
                    f"sample, got {keep_last_n_seconds} seconds at "
                    f"{sample_rate} Hz"
                )

        self.heterodyne_transform = Heterodyne(
            sample_rate=sample_rate,
            kernel_length=kernel_length,
            chirp_mass=self.chirp_mass_grid,
            return_type="time",
        )

    def _create_chirp_mass_grid(
        self,
        chirp_mass_low: float,
        chirp_mass_high: float,
        num_chirp_masses: int,
        chirp_mass_spacing: Literal["linear", "log"],
    ) -> torch.Tensor:
        """
        Build the 1D grid of chirp masses used by the heterodyne transform.

        Returns:
            torch.Tensor: `num_chirp_masses` values running from
            `chirp_mass_low` to `chirp_mass_high`, spaced linearly or
            logarithmically.

        Raises:
            ValueError: If `chirp_mass_spacing` is not "linear" or "log".
        """
        if chirp_mass_spacing == "linear":
            return torch.linspace(
                chirp_mass_low, chirp_mass_high, num_chirp_masses
            )
        elif chirp_mass_spacing == "log":
            return torch.logspace(
                math.log10(chirp_mass_low),
                math.log10(chirp_mass_high),
                num_chirp_masses,
            )
        else:
            raise ValueError(
                f"Invalid chirp mass spacing: {chirp_mass_spacing}"
            )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input data of shape (batch, channels, time).

        Returns:
            Tensor: Output data of shape
                (batch, channels * num_chirp_masses, time_out),
                where `time_out` is either the same as input time
                or determined by `keep_last_n_seconds`. The result keeps
                the dtype and device produced by the heterodyne transform.
        """
        _B, _C, _T = x.shape
        # Heterodyne the whitened timeseries
        x = self.heterodyne_transform(x)
        # Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size)
        # to (batch_size, channels x num_chirp_mass, kernel_size).
        # The reshaped tensor is returned directly: copying it into a
        # freshly allocated `torch.empty` buffer would cast it to the
        # default dtype and move it to the default (CPU) device.
        x = x.reshape(_B, _C * self.num_chirp_masses, _T)
        # Returning the desired length of heterodyned strain in the
        # time dimension
        if self.keep_last_n_seconds is not None:
            return x[..., -self.keep_last_n_samples :]
        return x
"sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1124,7 +1124,7 @@ dev = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/data/uv.lock b/projects/data/uv.lock index d456db315..4008a1c04 100644 --- a/projects/data/uv.lock +++ b/projects/data/uv.lock @@ -333,6 +333,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = 
"https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -1494,7 +1496,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -1503,9 +1505,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = 
"sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -2706,7 +2708,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/export/export_heterodyne.yaml b/projects/export/export_heterodyne.yaml new file mode 100644 index 000000000..1e3414634 --- /dev/null +++ b/projects/export/export_heterodyne.yaml @@ -0,0 +1,43 @@ +# commented args (i.e. comments with colon at end +# to illustrate a value needs to be filled out) +# represent values filled out +# by the export task at run time. To build a functional +# standalone config, add these in yourself. 
+ +logfile: +weights: +batch_file: +repository_directory: +num_ifos: 2 +kernel_length: 4.0 +inference_sampling_rate: 4 +sample_rate: 2048 +batch_size: 128 +fduration: 2 +psd_length: 64 +preprocessor: + class_path: utils.preprocessing.BatchWhitener + init_args: + kernel_length: 4.0 + sample_rate: 2048 + inference_sampling_rate: 4 + batch_size: 128 + fduration: 2 + fftlength: 2 + augmentor: + class_path: utils.augmentation.HeterodyneAugmentor + init_args: + sample_rate: 2048 + kernel_length: 4.0 + chirp_mass_low: 1.0 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: null + highpass: 32.0 +streams_per_gpu: 12 +# num_outputs: +aframe_instances: 1 +platform: TENSORRT +clean: true +verbose: true diff --git a/projects/export/pyproject.toml b/projects/export/pyproject.toml index 7be57c0ae..b95355d18 100644 --- a/projects/export/pyproject.toml +++ b/projects/export/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }] requires-python = ">=3.10,<3.13" license = "MIT" dependencies = [ - "ml4gw>=0.7.7", + "ml4gw>=0.8.0", "boto3~=1.30", "fsspec[s3]>=2024,<2025", "ml4gw-hermes[torch]>=0.2.1", diff --git a/projects/export/uv.lock b/projects/export/uv.lock index 53c7bf108..a57e9deec 100644 --- a/projects/export/uv.lock +++ b/projects/export/uv.lock @@ -528,7 +528,7 @@ requires-dist = [ { name = "boto3", specifier = "~=1.30" }, { name = "fsspec", extras = ["s3"], specifier = ">=2024,<2025" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "ml4gw-hermes", extras = ["torch"], specifier = ">=0.2.1" }, { name = "nvidia-cudnn-cu11", specifier = "==8.9.6.50" }, { name = "tensorrt", specifier = "==8.5.2.2" }, @@ -972,7 +972,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.10" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = 
"jaxtyping" }, @@ -981,9 +981,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/4b/d04a579fe0009f29f336af329d231dfd0da32c1f03560cbfdcd1d3081b86/ml4gw-0.7.10.tar.gz", hash = "sha256:20c54524b9f44669ef2b320832c62fac9766d2374019cfd798ec89a83adf1e99", size = 119139, upload-time = "2025-12-03T01:34:50.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/a8/174f4112c2dae1c61f0ed1c5eb32b7b7dd0c34b8050ec1a67d45480f34b3/ml4gw-0.7.10-py3-none-any.whl", hash = "sha256:f9edaa40365ee4441ffb42b4bf542a70522990495659bef441fbb2eba0efbcf4", size = 129192, upload-time = "2025-12-03T01:34:49.491Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -1980,7 +1980,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/online/uv.lock b/projects/online/uv.lock index f97583ee8..4eaab91bb 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -478,6 +478,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = 
"sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851, upload-time = "2024-08-23T15:22:07.895Z" }, { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120, upload-time = "2024-08-23T15:22:09.942Z" }, { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2276,7 +2278,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2285,9 +2287,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -4432,7 +4434,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/projects/train/configs/time_heterodyne.yaml b/projects/train/configs/time_heterodyne.yaml new file mode 100644 index 000000000..d2928f781 --- /dev/null +++ b/projects/train/configs/time_heterodyne.yaml @@ -0,0 +1,148 @@ +# commented args represent values filled out +# by train task at run time. To build a functional +# standalone config, add these in. 
+ +# To start training from a checkpoint, uncomment the below argument +# and specify a path to the desired checkpoint +# ckpt_path: "" +model: + class_path: train.model.SupervisedAframe + init_args: + # architecture + arch: + class_path: architectures.supervised.SupervisedHeterodyneTimeDomainResNet + init_args: + layers: [3, 4, 6, 3] + kernel_size: 25 + norm_layer: + class_path: ml4gw.nn.norm.GroupNorm1DGetter + init_args: + groups: 16 + metric: + class_path: train.metrics.TimeSlideAUROC + init_args: + max_fpr: 1e-3 + pool_length: 8 + + # optimization params + weight_decay: 3e-5 + learning_rate: 4e-4 + pct_lr_ramp: 0.115 + # early stop +data: + class_path: train.data.supervised.HeterodyneTimeDomainSupervisedAframeDataset + init_args: + # loading args + background_dir: + waveforms_dir: + ifos: [H1,L1] + sample_rate: 2048 + batches_per_epoch: 100 + num_files_per_batch: 10 + chunk_size: 1000 + chunks_per_epoch: 20 + # preprocessing args + batch_size: 100 + kernel_length: 8 + psd_length: 20 + fduration: 2 + # augmentation args + waveform_prob: 0.277 + swap_prob: 0.014 + mute_prob: 0.055 + left_pad: 6.95 + right_pad: 0.05 + # heterodyne preprocessing args + chirp_mass_low: 1 + chirp_mass_high: 2.5 + num_chirp_masses: 100 + chirp_mass_spacing: "log" + keep_last_n_seconds: null + # highpass: + # lowpass: + fftlength: 2 + snr_sampler: + class_path: ml4gw.distributions.PowerLaw + init_args: + minimum: 8 + maximum: 100 + index: -3 + # curriculum learning for snr sampler + # snr_sampler: + # class_path: train.augmentations.SnrSampler + # init_args: + # max_min_snr: 30 + # min_min_snr: 8 + # max_snr: 100 + # alpha: -3 + # decay_steps: 600 + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader + init_args: + training_waveform_path: + val_waveform_file: + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + 
init_args: + low: 0 + high: 6.28318 + validate_args: false + # validation args + valid_stride: 0.5 + num_valid_views: 5 + valid_livetime: 57600 + +trainer: + # by default, use a local TensorBoard logger. + # note that you can use multiple loggers! + # Options in train task for appending + # a wandb logger for remote logging. + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: + name: + version: + + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: "valid_auroc" + mode: "max" + patience: 500 + # custom model checkpoint for saving and + # tracing best model at end of training + # that will be used for downstream export + - class_path: train.callbacks.ModelCheckpoint + init_args: + monitor: "valid_auroc" + mode: "max" + save_top_k: 10 + save_last: true + auto_insert_metric_name: false + - class_path: train.callbacks.SaveAugmentedBatch + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "step" + # uncomment below if you want to profile + # profiler: + # class_path: lightning.pytorch.profilers.PyTorchProfiler + # dict_kwargs: + # profile_memory: true + # devices: + # strategy: set to ddp if len(devices) > 1 + #precision: 16-mixed + accelerator: auto + max_epochs: 1500 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_progress_bar: true + benchmark: true diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index 8adbbc07e..d81f1834a 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.7.7", + "ml4gw>=0.8.0", "aframe", "ledger", "priors", diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index 3953dbc81..e0dfd1b2e 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -57,6 +57,16 @@ def add_arguments_to_parser(self, parser): except Exception: 
class HeterodyneTimeDomainSupervisedAframeDataset(SupervisedAframeDataset):
    """
    Supervised dataset that whitens strain data, heterodynes it over a
    grid of chirp masses, and returns the heterodyned timeseries for
    training Aframe models. If `keep_last_n_seconds` is passed, only the
    final portion of the heterodyned strain is returned.

    Args:
        chirp_mass_low (float):
            Lower bound of chirp mass range (in solar masses).
        chirp_mass_high (float):
            Upper bound of chirp mass range (in solar masses).
        num_chirp_masses (int):
            Number of chirp mass samples to generate.
        chirp_mass_spacing (Literal["linear", "log"]):
            Spacing of chirp mass grid. Use "linear" for evenly spaced
            values or "log" for logarithmic spacing.
        keep_last_n_seconds (float):
            If provided, only the last `n` seconds of the kernel_length are
            returned. Otherwise (``None``), the full kernel_length is
            returned.
        *args, **kwargs:
            Forwarded unchanged to `SupervisedAframeDataset`.

    Raises:
        ValueError:
            If `chirp_mass_spacing` is not "linear" or "log", or if
            `keep_last_n_seconds` corresponds to fewer than one sample.
    """

    def __init__(
        self,
        chirp_mass_low: float = 1.0,
        chirp_mass_high: float = 2.5,
        num_chirp_masses: int = 100,
        chirp_mass_spacing: Literal["linear", "log"] = "log",
        keep_last_n_seconds: float = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.chirp_mass_grid = self._create_chirp_mass_grid(
            chirp_mass_low,
            chirp_mass_high,
            num_chirp_masses,
            chirp_mass_spacing,
        )

        self.keep_last_n_seconds = keep_last_n_seconds

        if self.keep_last_n_seconds is not None:
            self.keep_last_n_samples = int(
                self.keep_last_n_seconds * self.hparams.sample_rate
            )
            # A crop of zero samples would be written `X[..., -0:]`, which
            # silently returns the *whole* kernel; reject it up front.
            if self.keep_last_n_samples < 1:
                raise ValueError(
                    "keep_last_n_seconds must correspond to at least one "
                    f"sample, got {keep_last_n_seconds} seconds at "
                    f"{self.hparams.sample_rate} Hz"
                )

    def build_transforms(self, *args, **kwargs):
        """
        Build the parent transforms, then add the heterodyne transform
        configured from this dataset's chirp mass grid.
        """
        super().build_transforms(*args, **kwargs)
        self.heterodyne_transform = Heterodyne(
            sample_rate=self.hparams.sample_rate,
            kernel_length=self.hparams.kernel_length,
            chirp_mass=self.chirp_mass_grid,
            return_type="time",
        )

    def _create_chirp_mass_grid(
        self,
        chirp_mass_low: float,
        chirp_mass_high: float,
        num_chirp_masses: int,
        chirp_mass_spacing: Literal["linear", "log"],
    ) -> torch.Tensor:
        """
        Build the 1D grid of chirp masses used by the heterodyne transform.

        Raises:
            ValueError: If `chirp_mass_spacing` is not "linear" or "log".
        """
        if chirp_mass_spacing == "linear":
            return torch.linspace(
                chirp_mass_low, chirp_mass_high, num_chirp_masses
            )
        elif chirp_mass_spacing == "log":
            return torch.logspace(
                math.log10(chirp_mass_low),
                math.log10(chirp_mass_high),
                num_chirp_masses,
            )
        else:
            raise ValueError(
                f"Invalid chirp mass spacing: {chirp_mass_spacing}"
            )

    def _whiten_and_heterodyne(self, X, psds):
        """Whiten `X` with `psds`, then apply the heterodyne transform."""
        return self.heterodyne_transform(self.whitener(X, psds))

    def _flatten_and_crop(self, X):
        """
        Merge the (channel, chirp mass) axes of heterodyned data and
        optionally keep only the trailing samples.

        Expects the last three axes of `X` to be
        (channels, num_chirp_masses, time); any leading axes
        (batch, and views for validation injections) are preserved.
        """
        # `flatten` falls back to a copy when the memory layout does not
        # permit a view, whereas `.view` would raise.
        X = X.flatten(-3, -2)
        if self.keep_last_n_seconds is not None:
            X = X[..., -self.keep_last_n_samples :]
        return X

    def build_val_batches(self, background, signals):
        """
        Build whitened, heterodyned validation batches.

        Returns:
            Tuple of (X_bg, X_fg): background of shape
            (batch, channels * num_chirp_masses, time_out) and injections
            of shape (views, batch, channels * num_chirp_masses, time_out).
        """
        X_bg, X_inj, psds = super().build_val_batches(background, signals)
        X_bg = self._flatten_and_crop(self._whiten_and_heterodyne(X_bg, psds))
        # whiten and heterodyne each view of injections
        X_fg = torch.stack(
            [self._whiten_and_heterodyne(inj, psds) for inj in X_inj]
        )
        X_fg = self._flatten_and_crop(X_fg)
        return X_bg, X_fg

    def inject(self, X, waveforms=None):
        """
        Inject waveforms via the parent class, then whiten and heterodyne.

        Returns:
            Tuple of (X, y) where X has shape
            (batch, channels * num_chirp_masses, time_out).
        """
        X, y, psds = super().inject(X, waveforms)
        X = self._flatten_and_crop(self._whiten_and_heterodyne(X, psds))
        return X, y
"https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801, upload-time = "2024-08-23T15:54:20.941Z" }, + { url = "https://files.pythonhosted.org/packages/73/b9/e8a78c082d8708ea4cc9c65b53dfed9d1d6bc9b3a44d712811b9e55022ee/bilby_cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d39ad43c8962a32b7c561ee07f0f9fb9e656a7847b30176695007b31426d2474", size = 363731, upload-time = "2026-02-23T16:52:49.722Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a2/6a8e2a8a0721b758745e2a35f91c5ff380cf0f795408bc74b9aa8c589f0a/bilby_cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:9aabbcce359c63c78cf1c1bf4d714c438a2936ddd4e061fe90b3320415dd12f6", size = 361366, upload-time = "2026-02-23T16:52:50.964Z" }, ] [[package]] @@ -2319,7 +2321,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2328,9 +2330,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = 
"sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -4370,7 +4372,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.7" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -4509,7 +4511,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] diff --git a/uv.lock b/uv.lock index ec36dea15..b3700fae0 100644 --- a/uv.lock +++ b/uv.lock @@ -1299,7 +1299,7 @@ name = "importlib-metadata" version = "8.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp" }, + { name = "zipp", marker = "python_full_version < '3.12'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/33/08/c1395a292bb23fd03bdf572a1357c5a733d3eecbab877641ceacab23db6e/importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580", size = 55767, upload-time = "2025-01-20T22:21:30.429Z" } wheels = [ @@ -1911,7 +1911,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.11" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies 
= [ { name = "jaxtyping" }, @@ -1920,9 +1920,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0a/722f553635ffc91b32623e69a4c93591c11ce2c24a10e4bda35ab0d8e6ae/ml4gw-0.7.11.tar.gz", hash = "sha256:8df9ebecd97ed6a6e8ba07fab40882f5966e646897f5187a9ccf7913faf6464e", size = 119593, upload-time = "2026-01-29T20:34:30.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/56/abb490d353f989802f918ee92cf6c9a37336483aa20c7f17df7730d81744/ml4gw-0.8.0.tar.gz", hash = "sha256:43a2411ae348f8f911fdc0e2defd4fa54370414fa8b51c63518de3cb805754ba", size = 121709, upload-time = "2026-04-17T13:15:20.347Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/7d/f8c3e695d52cd9e70fd3f7bb51efd29848a3eb481dc1b94228f481dd05f8/ml4gw-0.7.11-py3-none-any.whl", hash = "sha256:0a6645f27444d266fb94afe988450bc2d00e24bd70328b0a5903194e1900acdb", size = 129588, upload-time = "2026-01-29T20:34:29.357Z" }, + { url = "https://files.pythonhosted.org/packages/2a/22/64102f10ad7f9043083d8bafcf84b8f2e8abc84bfc214fd04dd8d243ece9/ml4gw-0.8.0-py3-none-any.whl", hash = "sha256:0b4377541d5a90dcf9c728efa4008b4d57b942452acf564856ac58f599273070", size = 132926, upload-time = "2026-04-17T13:15:18.751Z" }, ] [[package]] @@ -3688,7 +3688,7 @@ dependencies = [ requires-dist = [ { name = "astropy", specifier = ">=6.0.1" }, { name = "h5py", specifier = "~=3.6" }, - { name = "ml4gw", specifier = ">=0.7.10" }, + { name = "ml4gw", specifier = ">=0.8.0" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ]