Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/architectures/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
SupervisedSpectrogramDomainResNet,
SupervisedTimeDomainResNet,
SupervisedTimeSpectrogramResNet,
SupervisedHeterodyneTimeDomainResNet,
)
40 changes: 40 additions & 0 deletions libs/architectures/architectures/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,43 @@ def forward(self, X, X_spec):
time_domain_output = self.time_domain_resnet(X)
spec_domain_output = self.spectrogram_resnet(X_spec)
return time_domain_output, spec_domain_output


class SupervisedHeterodyneTimeDomainResNet(SupervisedArchitecture):
"""
Time Domain ResNet that processes a Heterodyned timeseries.

Args:
num_chirp_masses (int):
Number of chirp masses used to define the input channel
dimension (in_channels = num_ifos x num_chirp_masses).
"""

def __init__(
self,
num_ifos: int,
num_chirp_masses: int,
layers: list[int],
kernel_size: int = 3,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
norm_layer: Optional[NormLayer] = None,
**kwargs,
) -> None:
super().__init__()
self.time_domain_resnet = ResNet1D(
in_channels=num_ifos * num_chirp_masses,
layers=layers,
classes=1,
kernel_size=kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)

def forward(self, X):
return self.time_domain_resnet(X)
2 changes: 1 addition & 1 deletion libs/utils/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies = [
"h5py~=3.6",
"numpy>=1.26.4,<2",
"s3fs>=2024,<2025",
"ml4gw>=0.7.10",
"ml4gw>=0.8.0",
"astropy>=6.0.1",
]

Expand Down
55 changes: 55 additions & 0 deletions libs/utils/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
MultiModalPreprocessor,
TimeSpectrogramPreprocessor,
)
from utils.augmentation import HeterodyneAugmentor
from ml4gw.transforms import Heterodyne
import torch
import pytest

Expand Down Expand Up @@ -247,6 +249,59 @@ def simple_augmentor(x):
kernels_aug = whitener_aug(x)
assert torch.allclose(kernels_aug, kernels * 2)

def test_with_heterodyne_augmentor(self):
heterodyne_augmentor = HeterodyneAugmentor(
sample_rate=self.sample_rate,
kernel_length=self.kernel_length,
chirp_mass_low=1.0,
chirp_mass_high=2.5,
num_chirp_masses=10,
chirp_mass_spacing="log",
keep_last_n_seconds=None,
)

whitener = BatchWhitener(
kernel_length=self.kernel_length,
sample_rate=self.sample_rate,
inference_sampling_rate=self.inference_sampling_rate,
batch_size=self.batch_size,
fduration=self.fduration,
fftlength=self.fftlength,
)

whitener_aug = BatchWhitener(
kernel_length=self.kernel_length,
sample_rate=self.sample_rate,
inference_sampling_rate=self.inference_sampling_rate,
batch_size=self.batch_size,
fduration=self.fduration,
fftlength=self.fftlength,
augmentor=heterodyne_augmentor,
)

channels = 2
total_samples = int(
(
(self.batch_size - 1) * whitener.stride_size
+ whitener.kernel_size
)
* 2
)
x = torch.randn(channels, total_samples)

heterodyne = Heterodyne(
sample_rate=self.sample_rate,
kernel_length=self.kernel_length,
chirp_mass=heterodyne_augmentor.chirp_mass_grid,
return_type="time",
)

kernels_heterodyned = heterodyne(whitener(x))
_B, _C, _M, _T = kernels_heterodyned.shape
kernels = kernels_heterodyned.reshape(_B, _C * _M, _T)
kernels_aug = whitener_aug(x)
assert torch.allclose(kernels_aug, kernels)


class TestMultiModalPreprocessor:
"""Test suite for MultiModalPreprocessor module."""
Expand Down
135 changes: 135 additions & 0 deletions libs/utils/utils/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import math
import torch
from torch import Tensor
from typing import Literal

from ml4gw.transforms import Heterodyne


class HeterodyneAugmentor(torch.nn.Module):
"""
Apply a heterodyne transform over a grid of chirp masses to a batch
of time-series data.

Args:
sample_rate (float): Input sampling rate in Hz.
kernel_length (float): Length of output kernels in seconds.
chirp_mass_low (float):
Lower bound of chirp mass range (in solar masses).
chirp_mass_high (float):
Upper bound of chirp mass range (in solar masses).
num_chirp_masses (int):
Number of chirp mass samples to generate.
chirp_mass_spacing (Literal["linear", "log"]):
Spacing of chirp mass grid. Use "linear" for evenly spaced
values or "log" for logarithmic spacing.
keep_last_n_seconds (float):
If provided, only the last `n` seconds of the kernel_length are
returned. Otherwise, the full kernel_length is returned.
Shape:
Input: (batch_size, channels, time)
Output: (batch_size, channels * num_chirp_masses, time_out)

where `time_out = time` unless `keep_last_n_seconds` is set.

Note:
The output shape of the `BatchWhitener` is
(batch_size, channels, kernel_size).
The `HeterodyneAugmentor` changes the output shape to
(batch_size, channels * num_chirp_masses, time_out)
where time_out is determined by the `keep_last_n_seconds` parameter.

Example:
>>> augmentor = HeterodyneAugmentor(
... sample_rate=2048, kernel_length=8, chirp_mass_low=1.0,
... chirp_mass_high=2.5, num_chirp_masses=100,
... chirp_mass_spacing="log", keep_last_n_seconds=4.0,
... )
>>> x = torch.randn(8, 2, 16384) # (batch, channels, time)
>>> y = augmentor(x)
>>> # y: (8, 2 * 100, 8192) since we keep the last 4 seconds at 2048 Hz
"""

def __init__(
self,
sample_rate: float,
kernel_length: float,
chirp_mass_low: float = 1.0,
chirp_mass_high: float = 2.5,
num_chirp_masses: int = 100,
chirp_mass_spacing: Literal["linear", "log"] = "log",
keep_last_n_seconds: float = None,
):
super().__init__()
self.sample_rate = sample_rate
self.kernel_length = kernel_length
self.keep_last_n_seconds = keep_last_n_seconds
self.num_chirp_masses = num_chirp_masses
self.keep_last_n_seconds = keep_last_n_seconds

self.chirp_mass_grid = self._create_chirp_mass_grid(
chirp_mass_low,
chirp_mass_high,
num_chirp_masses,
chirp_mass_spacing,
)

if self.keep_last_n_seconds is not None:
self.keep_last_n_samples = int(
self.keep_last_n_seconds * sample_rate
)

self.heterodyne_transform = Heterodyne(
sample_rate=sample_rate,
kernel_length=kernel_length,
chirp_mass=self.chirp_mass_grid,
return_type="time",
)

def _create_chirp_mass_grid(
self,
chirp_mass_low: float,
chirp_mass_high: float,
num_chirp_masses: int,
chirp_mass_spacing: Literal["linear", "log"],
) -> torch.Tensor:
if chirp_mass_spacing == "linear":
return torch.linspace(
chirp_mass_low, chirp_mass_high, num_chirp_masses
)
elif chirp_mass_spacing == "log":
return torch.logspace(
math.log10(chirp_mass_low),
math.log10(chirp_mass_high),
num_chirp_masses,
)
else:
raise ValueError(
f"Invalid chirp mass spacing: {chirp_mass_spacing}"
)

def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Input data of shape (batch, channels, time).

Returns:
Tensor: Output data of shape
(batch, channels * num_chirp_masses, time_out),
where `time_out` is either the same as input time
or determined by `keep_last_n_seconds`.
"""
_B, _C, _T = x.shape
x_heterodyned = torch.empty((_B, _C * self.num_chirp_masses, _T))
# Heterodyne the whitened timeseries
x = self.heterodyne_transform(x)
# Reshaping x from (batch_size, channels, num_chirp_mass, kernel_size)
# to (batch_size, channels x num_chirp_mass, kernel_size)
x = x.reshape(_B, _C * self.num_chirp_masses, _T)
x_heterodyned[:, :, :] = x
# Returning the desired length of heterodyned strain in the
# time dimension
if self.keep_last_n_seconds is not None:
return x_heterodyned[..., -self.keep_last_n_samples :]
else:
return x_heterodyned
8 changes: 4 additions & 4 deletions libs/utils/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions projects/data/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions projects/export/export_heterodyne.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# commented args (i.e. comments with colon at end
# to illustrate a value needs to be filled out)
# represent values filled out
# by the export task at run time. To build a functional
# standalone config, add these in yourself.

logfile:
weights:
batch_file:
repository_directory:
num_ifos: 2
kernel_length: 4.0
inference_sampling_rate: 4
sample_rate: 2048
batch_size: 128
fduration: 2
psd_length: 64
preprocessor:
class_path: utils.preprocessing.BatchWhitener
init_args:
kernel_length: 4.0
sample_rate: 2048
inference_sampling_rate: 4
batch_size: 128
fduration: 2
fftlength: 2
augmentor:
class_path: utils.augmentation.HeterodyneAugmentor
init_args:
sample_rate: 2048
kernel_length: 4.0
chirp_mass_low: 1.0
chirp_mass_high: 2.5
num_chirp_masses: 100
chirp_mass_spacing: "log"
keep_last_n_seconds: null
highpass: 32.0
streams_per_gpu: 12
# num_outputs:
aframe_instances: 1
platform: TENSORRT
clean: true
verbose: true
2 changes: 1 addition & 1 deletion projects/export/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = [{ name = "Ethan Jacob Marx", email = "ethan.marx@ligo.org" }]
requires-python = ">=3.10,<3.13"
license = "MIT"
dependencies = [
"ml4gw>=0.7.7",
"ml4gw>=0.8.0",
"boto3~=1.30",
"fsspec[s3]>=2024,<2025",
"ml4gw-hermes[torch]>=0.2.1",
Expand Down
Loading
Loading