From b7de77fb50fe34e8b358f77349eee8cb2f1fd15a Mon Sep 17 00:00:00 2001 From: Sahel Iqbal Date: Tue, 26 May 2026 17:29:57 +0100 Subject: [PATCH 1/5] factor out pf step --- cuthbert/smc/particle_filter.py | 84 ++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/cuthbert/smc/particle_filter.py b/cuthbert/smc/particle_filter.py index d445af4..c8096c3 100644 --- a/cuthbert/smc/particle_filter.py +++ b/cuthbert/smc/particle_filter.py @@ -161,6 +161,49 @@ def filter_prepare( ) +def pf_step( + key: KeyArray, + particles: ArrayTree, + log_weights: Array, + log_normalizing_constant: ScalarArray, + model_inputs: ArrayTree, + propagate_sample: PropagateSample, + log_potential: LogPotential, + resampling_fn: Resampling, +): + """Performs a single particle filter step.""" + N = log_weights.shape[0] + keys = random.split(key, N + 1) + + # Resample - resampling_fn is expected to handle adaptivity if desired + ancestor_indices, log_weights, ancestors = resampling_fn( + keys[0], log_weights, particles, N + ) + + # Propagate + next_particles = jax.vmap(propagate_sample, (0, 0, None))( + keys[1:], ancestors, model_inputs + ) + + # Reweight + log_potentials = jax.vmap(log_potential, (0, 0, None))( + ancestors, next_particles, model_inputs + ) + next_log_weights = log_potentials + log_weights + + # Compute the log normalizing constant + logsum_weights = jax.nn.logsumexp(next_log_weights) + log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) + log_normalizing_constant = log_normalizing_constant_incr + log_normalizing_constant + + return { + "particles": next_particles, + "log_weights": next_log_weights, + "ancestor_indices": ancestor_indices, + "log_normalizing_constant": log_normalizing_constant, + } + + def filter_combine( state_1: ParticleFilterState, state_2: ParticleFilterState, @@ -183,37 +226,22 @@ def filter_combine( Returns: The filtered state at the current time step. """ - N = state_1.log_weights.shape[0] - keys = random.split(state_1.key, N + 1) - - # Resample - resampling_fn is expected to handle adaptivity if desired - ancestor_indices, log_weights, ancestors = resampling_fn( - keys[0], state_1.log_weights, state_1.particles, N - ) - - # Propagate - next_particles = jax.vmap(propagate_sample, (0, 0, None))( - keys[1:], ancestors, state_2.model_inputs - ) - - # Reweight - log_potentials = jax.vmap(log_potential, (0, 0, None))( - ancestors, next_particles, state_2.model_inputs - ) - next_log_weights = log_potentials + log_weights - - # Compute the log normalizing constant - logsum_weights = jax.nn.logsumexp(next_log_weights) - log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) - log_normalizing_constant = ( - log_normalizing_constant_incr + state_1.log_normalizing_constant + next_state = pf_step( + key=state_1.key, + particles=state_1.particles, + log_weights=state_1.log_weights, + log_normalizing_constant=state_1.log_normalizing_constant, + model_inputs=state_2.model_inputs, + propagate_sample=propagate_sample, + log_potential=log_potential, + resampling_fn=resampling_fn, ) return ParticleFilterState( state_2.key, - next_particles, - next_log_weights, - ancestor_indices, + next_state["particles"], + next_state["log_weights"], + next_state["ancestor_indices"], state_2.model_inputs, - log_normalizing_constant, + next_state["log_normalizing_constant"], ) From c5560e84ab020465fb5b5e7e14ebe31f5713f710 Mon Sep 17 00:00:00 2001 From: Sahel Iqbal Date: Sun, 31 May 2026 11:40:30 +0100 Subject: [PATCH 2/5] Revert "factor out pf step" This reverts commit b7de77fb50fe34e8b358f77349eee8cb2f1fd15a. --- cuthbert/smc/particle_filter.py | 84 +++++++++++---------------------- 1 file changed, 28 insertions(+), 56 deletions(-) diff --git a/cuthbert/smc/particle_filter.py b/cuthbert/smc/particle_filter.py index c8096c3..d445af4 100644 --- a/cuthbert/smc/particle_filter.py +++ b/cuthbert/smc/particle_filter.py @@ -161,49 +161,6 @@ def filter_prepare( ) -def pf_step( - key: KeyArray, - particles: ArrayTree, - log_weights: Array, - log_normalizing_constant: ScalarArray, - model_inputs: ArrayTree, - propagate_sample: PropagateSample, - log_potential: LogPotential, - resampling_fn: Resampling, -): - """Performs a single particle filter step.""" - N = log_weights.shape[0] - keys = random.split(key, N + 1) - - # Resample - resampling_fn is expected to handle adaptivity if desired - ancestor_indices, log_weights, ancestors = resampling_fn( - keys[0], log_weights, particles, N - ) - - # Propagate - next_particles = jax.vmap(propagate_sample, (0, 0, None))( - keys[1:], ancestors, model_inputs - ) - - # Reweight - log_potentials = jax.vmap(log_potential, (0, 0, None))( - ancestors, next_particles, model_inputs - ) - next_log_weights = log_potentials + log_weights - - # Compute the log normalizing constant - logsum_weights = jax.nn.logsumexp(next_log_weights) - log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) - log_normalizing_constant = log_normalizing_constant_incr + log_normalizing_constant - - return { - "particles": next_particles, - "log_weights": next_log_weights, - "ancestor_indices": ancestor_indices, - "log_normalizing_constant": log_normalizing_constant, - } - - def filter_combine( state_1: ParticleFilterState, state_2: ParticleFilterState, @@ -226,22 +183,37 @@ def filter_combine( Returns: The filtered state at the current time step. """ - next_state = pf_step( - key=state_1.key, - particles=state_1.particles, - log_weights=state_1.log_weights, - log_normalizing_constant=state_1.log_normalizing_constant, - model_inputs=state_2.model_inputs, - propagate_sample=propagate_sample, - log_potential=log_potential, - resampling_fn=resampling_fn, + N = state_1.log_weights.shape[0] + keys = random.split(state_1.key, N + 1) + + # Resample - resampling_fn is expected to handle adaptivity if desired + ancestor_indices, log_weights, ancestors = resampling_fn( + keys[0], state_1.log_weights, state_1.particles, N + ) + + # Propagate + next_particles = jax.vmap(propagate_sample, (0, 0, None))( + keys[1:], ancestors, state_2.model_inputs + ) + + # Reweight + log_potentials = jax.vmap(log_potential, (0, 0, None))( + ancestors, next_particles, state_2.model_inputs + ) + next_log_weights = log_potentials + log_weights + + # Compute the log normalizing constant + logsum_weights = jax.nn.logsumexp(next_log_weights) + log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) + log_normalizing_constant = ( + log_normalizing_constant_incr + state_1.log_normalizing_constant ) return ParticleFilterState( state_2.key, - next_state["particles"], - next_state["log_weights"], - next_state["ancestor_indices"], + next_particles, + next_log_weights, + ancestor_indices, state_2.model_inputs, - next_state["log_normalizing_constant"], + log_normalizing_constant, ) From 60225da2667a9f80a0d837e81abfe2bf60e32133 Mon Sep 17 00:00:00 2001 From: Sahel Iqbal Date: Sun, 31 May 2026 12:14:07 +0100 Subject: [PATCH 3/5] first implementation --- cuthbert/npf/__init__.py | 0 cuthbert/npf/filter.py | 325 +++++++++++++++++++++++++++++++++++++++ cuthbert/npf/types.py | 85 ++++++++++ 3 files changed, 410 insertions(+) create mode 100644 cuthbert/npf/__init__.py create mode 100644 cuthbert/npf/filter.py create mode 100644 cuthbert/npf/types.py diff --git a/cuthbert/npf/__init__.py b/cuthbert/npf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuthbert/npf/filter.py b/cuthbert/npf/filter.py new file mode 100644 index 0000000..fa795b5 --- /dev/null +++ b/cuthbert/npf/filter.py @@ -0,0 +1,325 @@ +"""Implements the nested particle filter of Crisan and Miguez (2018) for parameter estimation. + +Reference: + Crisan and Miguez (2018) - https://doi.org/10.3150/17-BEJ954 +""" + +from functools import partial +from typing import NamedTuple, Protocol + +import jax +import jax.numpy as jnp +from jax import random, tree + +from cuthbert.inference import Filter +from cuthbert.npf.types import ( + JitteringKernel, + LogPotential, + PropagateSample, + SampleParam, +) +from cuthbert.smc.types import InitSample +from cuthbert.utils import dummy_tree_like +from cuthbertlib.resampling import Resampling +from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray, ScalarArray + + +class NPFState(NamedTuple): + """Nested particle filter state. + + Attributes: + key: JAX PRNG key. + param_particles: Parameter particles for the outer filter. + state_particles: State particles for the inner filters. The leading + axes index parameter particles and state particles, respectively. + param_log_weights: Log weights for the parameter particles. + state_log_weights: Log weights for the state particles, with shape + `(n_param_particles, n_state_particles)`. + model_inputs: Model inputs associated with this filter state. + log_normalizing_constant: Current estimate of the log normalizing + constant. + """ + + key: KeyArray + param_particles: ArrayTree + state_particles: ArrayTree + param_log_weights: Array + state_log_weights: Array + model_inputs: ArrayTree + log_normalizing_constant: ScalarArray + + +def build_filter( + init_param_sample: SampleParam, + init_sample: InitSample, + propagate_sample: PropagateSample, + log_potential: LogPotential, + n_param_particles: int, + n_state_particles: int, + resampling_fn: Resampling, + kernel_fn: JitteringKernel, +) -> Filter: + r"""Builds a nested particle filter object. + + Args: + init_param_sample: Function to sample from the initial parameter + distribution $\mu(\theta_0)$. + init_sample: Function to sample from the initial state distribution + $M_0(x_0)$. + propagate_sample: Function to sample from the Markov kernel $M_t(x_t \mid x_{t-1}, \theta)$. + log_potential: Function to compute the log potential $\log G_t(x_{t-1}, x_t, \theta)$. + n_param_particles: Number of parameter particles for the outer filter. + n_state_particles: Number of state particles for the inner filters. + resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial). + The resampling function may be decorated with adaptive behaviour + (using cuthbertlib.resampling.adaptive.adaptive_resampling_decorator) + before being passed to the filter. The same resampling function is + used for both the outer and inner filters. + kernel_fn: The jittering kernel to use for the parameter particles. + See Section 4.2 in Crisan and Miguez (2018) for different choices. + + Returns: + Filter object for the nested particle filter. + """ + return Filter( + init_prepare=partial( + init_prepare, + init_param_sample=init_param_sample, + init_state_sample=init_sample, + n_param_particles=n_param_particles, + n_state_particles=n_state_particles, + ), + filter_prepare=partial( + filter_prepare, + init_param_sample=init_param_sample, + init_state_sample=init_sample, + n_param_particles=n_param_particles, + n_state_particles=n_state_particles, + ), + filter_combine=partial( + filter_combine, + propagate_sample=propagate_sample, + log_potential=log_potential, + resampling_fn=resampling_fn, + kernel_fn=kernel_fn, + ), + ) + + +def init_prepare( + model_inputs: ArrayTreeLike, + init_param_sample: SampleParam, + init_state_sample: InitSample, + n_param_particles: int, + n_state_particles: int, + key: KeyArray | None = None, +) -> NPFState: + r"""Prepares the initial state for the nested particle filter. + + Args: + model_inputs: Model inputs for the initial state distribution. + init_param_sample: Function to sample from the initial parameter + distribution $\mu(\theta_0)$. + init_state_sample: Function to sample from the initial state + distribution $M_0(x_0)$. + n_param_particles: Number of parameter particles for the outer filter. + n_state_particles: Number of state particles for each inner filter. + key: JAX PRNG key. + + Returns: + Initial nested particle filter state. + + Raises: + ValueError: If `key` is None. + """ + model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs) + if key is None: + raise ValueError("A JAX PRNG key must be provided.") + + param_key, state_key = random.split(key) + + # Sample parameters + param_keys = random.split(param_key, n_param_particles) + param_particles = jax.vmap(init_param_sample)(param_keys) + param_log_weights = jnp.zeros(n_param_particles) + + # Sample states + state_keys = random.split(state_key, n_param_particles * n_state_particles) + state_particles = jax.vmap(init_state_sample, (0, None))(state_keys, model_inputs) + state_log_weights = jnp.zeros((n_param_particles, n_state_particles)) + + return NPFState( + key=key, + param_particles=param_particles, + state_particles=state_particles, + param_log_weights=param_log_weights, + state_log_weights=state_log_weights, + model_inputs=model_inputs, + log_normalizing_constant=jnp.array(0.0), + ) + + +def filter_prepare( + model_inputs: ArrayTreeLike, + init_param_sample: SampleParam, + init_state_sample: InitSample, + n_param_particles: int, + n_state_particles: int, + key: KeyArray | None = None, +) -> NPFState: + r"""Prepares a state for a nested particle filter step. + + Args: + model_inputs: Model inputs for the current filtering step. + init_param_sample: Function to sample from the initial parameter + distribution $\mu(\theta_0)$. + init_state_sample: Function to sample from the initial state + distribution $M_0(x_0)$. + n_param_particles: Number of parameter particles for the outer filter. + n_state_particles: Number of state particles for each inner filter. + key: JAX PRNG key. + + Returns: + Prepared nested particle filter state with placeholder parameter and + state particles. + + Raises: + ValueError: If `key` is None. + """ + model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs) + if key is None: + raise ValueError("A JAX PRNG key must be provided.") + + param_key, state_key = random.split(key) + dummy_param_particle = jax.eval_shape(init_param_sample, param_key) + particles = tree.map( + lambda x: jnp.empty((n_param_particles,) + x.shape), dummy_param_particle + ) + param_particles = dummy_tree_like(particles) + + dummy_state_particle = jax.eval_shape(init_state_sample, state_key, model_inputs) + state_particles = tree.map( + lambda x: jnp.empty((n_param_particles, n_state_particles) + x.shape), + dummy_state_particle, + ) + + return NPFState( + key=key, + param_particles=param_particles, + state_particles=state_particles, + param_log_weights=jnp.zeros(n_param_particles), + state_log_weights=jnp.zeros((n_param_particles, n_state_particles)), + model_inputs=model_inputs, + log_normalizing_constant=jnp.array(0.0), + ) + + +def _pf_step( + key: KeyArray, + particles: ArrayTree, + log_weights: Array, + param: ArrayTree, + model_inputs: ArrayTree, + propagate_sample: PropagateSample, + log_potential: LogPotential, + resampling_fn: Resampling, +): + """Performs a single particle filter step.""" + N = log_weights.shape[0] + keys = random.split(key, N + 1) + + # Resample - resampling_fn is expected to handle adaptivity if desired + _, log_weights, ancestors = resampling_fn(keys[0], log_weights, particles, N) + + # Propagate + next_particles = jax.vmap(propagate_sample, (0, 0, None, None))( + keys[1:], ancestors, param, model_inputs + ) + + # Reweight + log_potentials = jax.vmap(log_potential, (0, 0, None, None))( + ancestors, next_particles, param, model_inputs + ) + next_log_weights = log_potentials + log_weights + + # Compute the log normalizing constant + logsum_weights = jax.nn.logsumexp(next_log_weights) + log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) + + return { + "particles": next_particles, + "log_weights": next_log_weights, + "log_normalizing_constant_incr": log_normalizing_constant_incr, + } + + +def filter_combine( + state_1: NPFState, + state_2: NPFState, + propagate_sample: PropagateSample, + log_potential: LogPotential, + resampling_fn: Resampling, + kernel_fn: JitteringKernel, +) -> NPFState: + r"""Combines previous filter state with the state prepared for the current step. + + See Algorithm 3 in Crisan and Miguez (2018) for details. + + Args: + state_1: Nested particle filter state from the previous time step. + state_2: Nested particle filter state prepared for the current time + step. + propagate_sample: Function to sample from the state Markov kernel + $M_t(x_t \mid x_{t-1}, \theta)$. + log_potential: Function to compute the log potential + $\log G_t(x_{t-1}, x_t, \theta)$. + resampling_fn: Resampling algorithm to use for the outer and inner + filters. + kernel_fn: Jittering kernel applied to resampled parameter particles. + + Returns: + Nested particle filter state for the current time step. + """ + N, M = state_1.state_log_weights.shape + + # Resample + key, sub_key = random.split(state_1.key) + _, log_weights, ancestors = resampling_fn( + sub_key, state_1.param_log_weights, state_1.param_particles, N + ) + + # Jitter + keys = random.split(key, N + 1) + # TODO: We should only jitter if the particles were resampled. + param_particles = jax.vmap(kernel_fn)(keys[1:], ancestors) + + # Perform the inner particle filter step for each parameter particle + keys = random.split(keys[0], N) + next_inner_state = jax.vmap(_pf_step, (0, 0, 0, 0, None, None, None, None))( + keys, + state_1.state_particles, + state_1.state_log_weights, + param_particles, + state_2.model_inputs, + propagate_sample, + log_potential, + resampling_fn, + ) + + # The log potentials are the log normalizing constants increments from the inner particle filters + next_log_weights = log_weights + next_inner_state["log_normalizing_constant_incr"] + logsum_weights = jax.nn.logsumexp(next_log_weights) + log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights) + log_normalizing_constant = ( + log_normalizing_constant_incr + state_1.log_normalizing_constant + ) + + return NPFState( + key=state_2.key, + param_particles=ancestors, + state_particles=next_inner_state["particles"], + param_log_weights=next_log_weights, + state_log_weights=next_inner_state["log_weights"], + model_inputs=state_2.model_inputs, + log_normalizing_constant=log_normalizing_constant, + ) diff --git a/cuthbert/npf/types.py b/cuthbert/npf/types.py new file mode 100644 index 0000000..99d7c18 --- /dev/null +++ b/cuthbert/npf/types.py @@ -0,0 +1,85 @@ +"""Provides types for the nested particle filter.""" + +from typing import Protocol + +from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray + + +class SampleParam(Protocol): + r"""Protocol for sampling from the initial distribution $\mu(\theta_0)$.""" + + def __call__(self, key: KeyArray) -> ArrayTree: + r"""Samples from the initial distribution $\mu(\theta_0)$. + + Args: + key: JAX PRNG key. + + Returns: + A sample $\theta_0$. + """ + ... + + +class PropagateSample(Protocol): + r"""Protocol for sampling from the Markov kernel $M_t(x_t \mid x_{t-1}, \theta)$.""" + + def __call__( + self, + key: KeyArray, + state: ArrayTreeLike, + param: ArrayTreeLike, + model_inputs: ArrayTreeLike, + ) -> ArrayTree: + r"""Samples from the Markov kernel $M_t(x_t \mid x_{t-1}, \theta)$. + + Args: + key: JAX PRNG key. + state: State at the previous step $x_{t-1}$. + param: Hidden parameter $\theta$. + model_inputs: Model inputs. + + Returns: + A sample $x_t$. + """ + ... + + +class LogPotential(Protocol): + r"""Protocol for computing the log potential function $\log G_t(x_{t-1}, x_t, \theta)$.""" + + def __call__( + self, + state_prev: ArrayTreeLike, + state: ArrayTreeLike, + param: ArrayTreeLike, + model_inputs: ArrayTreeLike, + ) -> ScalarArray: + r"""Computes the log potential function $\log G_t(x_{t-1}, x_t, \theta)$. + + Args: + state_prev: State at the previous step $x_{t-1}$. + state: State at the current step $x_{t}$. + param: Hidden parameter $\theta$. + model_inputs: Model inputs. + + Returns: + A scalar value $\log G_t(x_{t-1}, x_t, \theta)$. + """ + ... + + +class JitteringKernel(Protocol): + """Protocol for a jittering kernel used in the nested particle filter.""" + + def __call__(self, key: KeyArray, particle: ArrayTree, **kwargs) -> ArrayTree: + """Applies the jittering kernel to a particle. + + Args: + key: JAX PRNG key. + particle: Particle to be jittered. + kwargs: Additional arguments for the jittering kernel. + + Returns: + Jittered particle. + """ + ... From 21ea1ab41f532bd287120f86953c4426585acf04 Mon Sep 17 00:00:00 2001 From: Sahel Iqbal Date: Sun, 31 May 2026 12:43:02 +0100 Subject: [PATCH 4/5] fix some bugs --- cuthbert/npf/__init__.py | 1 + cuthbert/npf/filter.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cuthbert/npf/__init__.py b/cuthbert/npf/__init__.py index e69de29..656bc80 100644 --- a/cuthbert/npf/__init__.py +++ b/cuthbert/npf/__init__.py @@ -0,0 +1 @@ +from cuthbert.npf.filter import build_filter diff --git a/cuthbert/npf/filter.py b/cuthbert/npf/filter.py index fa795b5..be73a00 100644 --- a/cuthbert/npf/filter.py +++ b/cuthbert/npf/filter.py @@ -146,6 +146,10 @@ def init_prepare( # Sample states state_keys = random.split(state_key, n_param_particles * n_state_particles) state_particles = jax.vmap(init_state_sample, (0, None))(state_keys, model_inputs) + state_particles = tree.map( + lambda x: x.reshape((n_param_particles, n_state_particles) + x.shape[1:]), + state_particles, + ) state_log_weights = jnp.zeros((n_param_particles, n_state_particles)) return NPFState( @@ -202,6 +206,7 @@ def filter_prepare( lambda x: jnp.empty((n_param_particles, n_state_particles) + x.shape), dummy_state_particle, ) + state_particles = dummy_tree_like(state_particles) return NPFState( key=key, @@ -263,7 +268,8 @@ def filter_combine( ) -> NPFState: r"""Combines previous filter state with the state prepared for the current step. - See Algorithm 3 in Crisan and Miguez (2018) for details. + This is Algorithm 3 from Crisan and Miguez (2018) with one difference: we + perform resampling before jittering and return weighted particles. Args: state_1: Nested particle filter state from the previous time step. @@ -284,9 +290,13 @@ def filter_combine( # Resample key, sub_key = random.split(state_1.key) - _, log_weights, ancestors = resampling_fn( + ancestor_indices, log_weights, ancestors = resampling_fn( sub_key, state_1.param_log_weights, state_1.param_particles, N ) + state_particles, state_log_weights = tree.map( + lambda x: x[ancestor_indices], + (state_1.state_particles, state_1.state_log_weights), + ) # Jitter keys = random.split(key, N + 1) @@ -297,8 +307,8 @@ def filter_combine( keys = random.split(keys[0], N) next_inner_state = jax.vmap(_pf_step, (0, 0, 0, 0, None, None, None, None))( keys, - state_1.state_particles, - state_1.state_log_weights, + state_particles, + state_log_weights, param_particles, state_2.model_inputs, propagate_sample, @@ -316,7 +326,7 @@ def filter_combine( return NPFState( key=state_2.key, - param_particles=ancestors, + param_particles=param_particles, state_particles=next_inner_state["particles"], param_log_weights=next_log_weights, state_log_weights=next_inner_state["log_weights"], From cd80051bf769a993722af3daf48f356b4f90afb0 Mon Sep 17 00:00:00 2001 From: Sahel Iqbal Date: Sun, 31 May 2026 12:52:50 +0100 Subject: [PATCH 5/5] add npf smoke test --- tests/cuthbert/npf/__init__.py | 1 + tests/cuthbert/npf/test_filter.py | 84 +++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 tests/cuthbert/npf/__init__.py create mode 100644 tests/cuthbert/npf/test_filter.py diff --git a/tests/cuthbert/npf/__init__.py b/tests/cuthbert/npf/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/cuthbert/npf/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/cuthbert/npf/test_filter.py b/tests/cuthbert/npf/test_filter.py new file mode 100644 index 0000000..cc98c71 --- /dev/null +++ b/tests/cuthbert/npf/test_filter.py @@ -0,0 +1,84 @@ +import chex +import jax.numpy as jnp +from jax import random + +from cuthbert import filter +from cuthbert.npf import build_filter +from cuthbertlib.resampling import systematic + + +class TestNestedParticleFilter(chex.TestCase): + @chex.variants(with_jit=True, without_jit=True) + def test_filter_runs_and_returns_expected_shapes(self): + n_param_particles = 5 + n_state_particles = 7 + param_dim = 2 + state_dim = 2 + num_time_steps = 3 + + def init_param_sample(key): + return random.normal(key, (param_dim,)) + + def init_sample(key, model_inputs): + return jnp.full((state_dim,), model_inputs) + 0.1 * random.normal( + key, (state_dim,) + ) + + def propagate_sample(key, state, param, model_inputs): + noise = 0.1 * random.normal(key, (state_dim,)) + return 0.8 * state + 0.2 * param + model_inputs + noise + + def log_potential(state_prev, state, param, model_inputs): + del state_prev, param + residual = model_inputs - state + return -0.5 * jnp.sum(residual**2) + + def kernel_fn(key, particle): + return particle + 0.01 * random.normal(key, particle.shape) + + filter_obj = build_filter( + init_param_sample=init_param_sample, + init_sample=init_sample, + propagate_sample=propagate_sample, + log_potential=log_potential, + n_param_particles=n_param_particles, + n_state_particles=n_state_particles, + resampling_fn=systematic.resampling, + kernel_fn=kernel_fn, + ) + + model_inputs = jnp.arange(num_time_steps + 1, dtype=jnp.float32) + + states = self.variant(filter, static_argnames=("filter_obj", "parallel"))( + filter_obj, model_inputs, parallel=False, key=random.key(0) + ) + + expected_time_shape = (num_time_steps + 1,) + chex.assert_shape(states.key, expected_time_shape) + chex.assert_shape( + states.param_particles, + expected_time_shape + (n_param_particles, param_dim), + ) + chex.assert_shape( + states.state_particles, + expected_time_shape + (n_param_particles, n_state_particles, state_dim), + ) + chex.assert_shape( + states.param_log_weights, expected_time_shape + (n_param_particles,) + ) + chex.assert_shape( + states.state_log_weights, + expected_time_shape + (n_param_particles, n_state_particles), + ) + chex.assert_shape(states.model_inputs, expected_time_shape) + chex.assert_shape(states.log_normalizing_constant, expected_time_shape) + chex.assert_tree_all_finite( + ( + states.param_particles, + states.state_particles, + states.param_log_weights, + states.state_log_weights, + states.model_inputs, + states.log_normalizing_constant, + ) + )