Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions bin/pycbc_inspiral_fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python
"""
pycbc_inspiral_ratio: A pipeline-compatible script for inspiral analysis
using the Ratio Matched Filter method.
Expand Down Expand Up @@ -279,8 +279,10 @@ def main():
delta_f=delta_f,
high_frequency_cutoff=args.high_frequency_cutoff,
fir_fft_length=args.fir_length,
batch_size=args.batch_size
)
batch_size=args.batch_size,
tap_sample_rate=bank.sample_rate,
engine_sample_rate=args.sample_rate
)

power_chisq = pycbc.vetoes.SingleDetPowerChisq(args.chisq_bins,
args.chisq_snr_threshold)
Expand Down
67 changes: 49 additions & 18 deletions pycbc/filter/matched_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ class RatioMatchedFilterControl(object):
Uses mkl_fft for ALL FFT operations to maximize throughput and consistency.
"""

def __init__(self, snr_threshold, delta_f,
high_frequency_cutoff=None, fir_fft_length=4096, batch_size=64):
def __init__(self, snr_threshold, delta_f,
high_frequency_cutoff=None, fir_fft_length=4096, batch_size=64, tap_sample_rate=2048, engine_sample_rate=2048):
self.delta_f = delta_f
self.snr_threshold = snr_threshold
self.f_high = high_frequency_cutoff


self.tap_sr = int(tap_sample_rate)
self.engine_sr = int(engine_sample_rate)

self.threshold_sq = float(snr_threshold**2)

self.fir_fft_len = fir_fft_length
Expand Down Expand Up @@ -67,7 +70,9 @@ def process_segment(self, stilde, psd, ref_template, filters_f, n_taps, indices,
high_frequency_cutoff=self.f_high,
h_norm=h_norm
)
self.ref_snr = snr.numpy() * (norm * stilde.delta_t)

decimate = int(np.round(self.tap_sr / self.engine_sr))
self.ref_snr = snr.numpy() * (norm * stilde.delta_t) / decimate

# 3. Execute Blocked Kernel
local_idxs, t_idxs, snr_vals, tstarts = self._execute_blocked_kernel(
Expand All @@ -86,35 +91,61 @@ def _fft_all_filters(self, taps, counts):
n_filters, n_taps_alloc = taps.shape
filters_f = np.zeros((n_filters, self.fir_fft_len), dtype=np.complex64)

padded_view = self.filters_padded.data
padded_reshaped = padded_view.reshape(self.batch_size, self.fir_fft_len)
# 1. Read metadata from the bank to determine the source generation rate
bank_sample_rate = self.tap_sr
engine_sample_rate = self.engine_sr
# Alternatively, determine the downsampling factor directly:
exact_ratio = (bank_sample_rate / engine_sample_rate)
decimation_factor = int(np.round(exact_ratio))

if abs(exact_ratio - decimation_factor) > 1e-5 or decimation_factor < 1:
raise ValueError(
f"Multi-rate Error: The bank sample rate ({self.tap_sr} Hz) must be "
f"an exact integer multiple of the engine sample "
f"rate ({self.engine_sr} Hz).\n"
f"Calculated ratio was {exact_ratio:.4f}. Please use standard power-of-2 "
f"downsampling scales (e.g., 2048/512)."
)

# 2. Establish the high-resolution FFT padding length to preserve delta_f
# 512/4096 = 0.125 2048/(4*4096) = 0.125 preserving delta_f
# 4096/4096 = 1 2048/(1/2*4096) = 1 for 4096 engine 2048 bank
high_res_fft_len = self.fir_fft_len * decimation_factor

# Temp allocations for high-resolution processing
high_res_padded = np.zeros((self.batch_size, high_res_fft_len), dtype=np.complex64)

for start in range(0, n_filters, self.batch_size):
end = min(start + self.batch_size, n_filters)
batch_len = end - start

# 1. Zero out and Fill
padded_reshaped[:batch_len, :] = 0
# Zero out processing buffer for next call
high_res_padded[:batch_len, :] = 0.0

# Copy raw 2048 Hz taps into the start of the buffer
tmp_taps = taps[start:end]
padded_reshaped[:batch_len, :n_taps_alloc] = tmp_taps
high_res_padded[:batch_len, :n_taps_alloc] = tmp_taps

# 2. Variable Roll Logic
# 3. Handle Variable Time-Domain Roll Logic at the native 2048 Hz rate
current_counts = counts[start:end]
roll_offsets = -(current_counts // 2)

cols = np.arange(self.fir_fft_len)
cols_high = np.arange(high_res_fft_len)
rows = np.arange(batch_len)[:, None]
shifted_cols = (cols[None, :] - roll_offsets[:, None]) % self.fir_fft_len
shifted_cols_high = (cols_high[None, :] - roll_offsets[:, None]) % high_res_fft_len

current_data = padded_reshaped[:batch_len].copy()
padded_reshaped[:batch_len] = current_data[rows, shifted_cols]
current_data = high_res_padded[:batch_len].copy()
high_res_padded[:batch_len] = current_data[rows, shifted_cols_high]

# 3. Execute FFT (Direct MKL call)
fft_out = self.fft_lib.fft(padded_reshaped[:batch_len], axis=-1)
# 4. Transform to Frequency Domain at native resolution
fft_high_res = self.fft_lib.fft(high_res_padded[:batch_len], axis=-1)

# 5. Brick-Wall Frequency Slicing (Anti-Aliasing & Decimation Match)
# Because the data engine goes up to 256Hz (the 512Hz Nyquist limit), only need the first 4096 bins of that spectrum
fft_sliced = fft_high_res[:batch_len, :self.fir_fft_len]

# 4. Conjugate & Store
filters_f[start:end] = np.conj(fft_out)
# 6. Conjugate & Store back into the 512 Hz buffer block
filters_f[start:end] = np.conj(fft_sliced)

return filters_f

Expand Down
Loading