Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cuthbert/factorial/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def filter(
init_factorial_state = filter_obj.init_prepare(
init_model_input, key=prepare_keys[0]
)
init_factorial_state = factorializer.factorize_initial_state(init_factorial_state)
init_factorial_state = factorializer.factorialize_init_state(
init_factorial_state, init_model_input
)

prep_model_inputs = tree.map(lambda x: x[1:], model_inputs)

Expand Down
72 changes: 36 additions & 36 deletions cuthbert/factorial/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,13 @@
from cuthbert.smc.marginal_particle_filter import MarginalParticleFilterState
from cuthbert.smc.particle_filter import ParticleFilterState
from cuthbertlib.resampling import Resampling, ess_decorator
from cuthbertlib.types import Array, ArrayLike
from cuthbertlib.types import Array, ArrayLike, ArrayTreeLike

GeneralParticleFilterState = TypeVar(
"GeneralParticleFilterState", ParticleFilterState, MarginalParticleFilterState
)


class SMCFactorializer(Factorializer):
"""Factorializer for particle filter states."""

def factorize_initial_state(
self, initial_state: GeneralParticleFilterState
) -> GeneralParticleFilterState:
"""Convert initial SMC state particles from `(N, F, ...)` to `(F, N, ...)`.

Generic SMC filters sample initial particles with a leading particle axis.
The factorial SMC machinery expects the factor axis to lead instead.
Initial weights and particle filter ancestor indices are broadcast from
`(N,)` to `(F, N)`, matching the factorial SMC state layout.
"""
particles = tree.map(lambda x: jnp.moveaxis(x, 0, 1), initial_state.particles)
n_factors = tree.leaves(particles)[0].shape[0]
n_particles = initial_state.log_weights.shape[0]

new_state = initial_state._replace(
particles=particles,
log_weights=jnp.broadcast_to(
initial_state.log_weights, (n_factors, n_particles)
),
)

if isinstance(initial_state, ParticleFilterState):
new_state = new_state._replace(
ancestor_indices=jnp.broadcast_to(
initial_state.ancestor_indices, (n_factors, n_particles)
)
)

return new_state


def build_factorializer(
get_factorial_indices: GetFactorialIndices,
resampling_fn: Resampling,
Expand All @@ -77,14 +43,48 @@ def build_factorializer(
Returns:
Factorializer for SMC states with extract, join, marginalize, and insert.
"""
return SMCFactorializer(
return Factorializer(
get_factorial_indices=get_factorial_indices,
extract=extract,
join=lambda local_factorial_state: join(local_factorial_state, resampling_fn),
marginalize=marginalize,
insert=insert,
factorialize_init_state=factorialize_init_state,
)


def factorialize_init_state(
init_state: GeneralParticleFilterState, model_inputs: ArrayTreeLike
) -> GeneralParticleFilterState:
"""Convert initial SMC state particles from `(N, F, ...)` to `(F, N, ...)`.

Generic SMC filters sample initial particles with a leading particle axis.
The factorial SMC machinery expects the factor axis to lead instead.
Initial weights and particle filter ancestor indices are broadcast from
`(N,)` to `(F, N)`, matching the factorial SMC state layout.

Args:
init_state: Output from particle filter `init_prepare`
model_inputs: The model inputs at the first time point - unused.
"""
particles = tree.map(lambda x: jnp.moveaxis(x, 0, 1), init_state.particles)
n_factors = tree.leaves(particles)[0].shape[0]
n_particles = init_state.log_weights.shape[0]

new_state = init_state._replace(
particles=particles,
log_weights=jnp.broadcast_to(init_state.log_weights, (n_factors, n_particles)),
)

if isinstance(init_state, ParticleFilterState):
new_state = new_state._replace(
ancestor_indices=jnp.broadcast_to(
init_state.ancestor_indices, (n_factors, n_particles)
)
)

return new_state


def extract(
factorial_state: GeneralParticleFilterState,
Expand Down
35 changes: 26 additions & 9 deletions cuthbert/factorial/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ def __call__(
...


class FactorializeInitState(Protocol):
"""Protocol for factorial post-processing of `init_prepare`."""

def __call__(
self, init_state: ArrayTreeLike, model_inputs: ArrayTreeLike
) -> ArrayTree:
"""Any processing of the output of `init_prepare` for factorial inference.

Args:
init_state: Output from base inference method's `init_prepare`
model_inputs: The model inputs at the first time point.
"""
...


class Factorializer(NamedTuple):
"""Factorializer object.

Expand All @@ -152,22 +167,24 @@ class Factorializer(NamedTuple):
join: Function to combine factorial states into a joint state.
marginalize: Function to marginalize a joint state into a factored state.
insert: Function to insert a local factorial state into a factorial state.
factorialize_init_state: Optional post-processing function to `init_prepare`.
By default leaves the output of `init_prepare` unchanged.
extract_and_join: Apply extract and then join.
Input: Global factorial state.
Output: Local joint state.
marginalize_and_insert: Apply marginalize and then insert.
Input: Local joint state and global factorial state.
Output: Global factorial state.
"""

get_factorial_indices: GetFactorialIndices
extract: Extract
join: Join
marginalize: Marginalize
insert: Insert

def factorize_initial_state(self, initial_state: ArrayTreeLike) -> ArrayTreeLike:
"""Convert an initial filter state to factorial layout if needed.

Most inference methods already construct initial states in factorial
layout. Inference-specific factorializers can override this when their
generic initial state layout differs from the factorial convention.
"""
return initial_state
factorialize_init_state: FactorializeInitState = lambda init_state, model_inputs: (
init_state
)

def extract_and_join(
self, factorial_state: ArrayTreeLike, model_inputs: ArrayTreeLike
Expand Down