diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py index c188113..b325434 100644 --- a/cuthbert/factorial/filtering.py +++ b/cuthbert/factorial/filtering.py @@ -58,7 +58,9 @@ def filter( init_factorial_state = filter_obj.init_prepare( init_model_input, key=prepare_keys[0] ) - init_factorial_state = factorializer.factorize_initial_state(init_factorial_state) + init_factorial_state = factorializer.factorialize_init_state( + init_factorial_state, init_model_input + ) prep_model_inputs = tree.map(lambda x: x[1:], model_inputs) diff --git a/cuthbert/factorial/smc.py b/cuthbert/factorial/smc.py index dfd277c..285da82 100644 --- a/cuthbert/factorial/smc.py +++ b/cuthbert/factorial/smc.py @@ -10,47 +10,13 @@ from cuthbert.smc.marginal_particle_filter import MarginalParticleFilterState from cuthbert.smc.particle_filter import ParticleFilterState from cuthbertlib.resampling import Resampling, ess_decorator -from cuthbertlib.types import Array, ArrayLike +from cuthbertlib.types import Array, ArrayLike, ArrayTreeLike GeneralParticleFilterState = TypeVar( "GeneralParticleFilterState", ParticleFilterState, MarginalParticleFilterState ) -class SMCFactorializer(Factorializer): - """Factorializer for particle filter states.""" - - def factorize_initial_state( - self, initial_state: GeneralParticleFilterState - ) -> GeneralParticleFilterState: - """Convert initial SMC state particles from `(N, F, ...)` to `(F, N, ...)`. - - Generic SMC filters sample initial particles with a leading particle axis. - The factorial SMC machinery expects the factor axis to lead instead. - Initial weights and particle filter ancestor indices are broadcast from - `(N,)` to `(F, N)`, matching the factorial SMC state layout. 
- """ - particles = tree.map(lambda x: jnp.moveaxis(x, 0, 1), initial_state.particles) - n_factors = tree.leaves(particles)[0].shape[0] - n_particles = initial_state.log_weights.shape[0] - - new_state = initial_state._replace( - particles=particles, - log_weights=jnp.broadcast_to( - initial_state.log_weights, (n_factors, n_particles) - ), - ) - - if isinstance(initial_state, ParticleFilterState): - new_state = new_state._replace( - ancestor_indices=jnp.broadcast_to( - initial_state.ancestor_indices, (n_factors, n_particles) - ) - ) - - return new_state - - def build_factorializer( get_factorial_indices: GetFactorialIndices, resampling_fn: Resampling, @@ -77,14 +43,48 @@ def build_factorializer( Returns: Factorializer for SMC states with extract, join, marginalize, and insert. """ - return SMCFactorializer( + return Factorializer( get_factorial_indices=get_factorial_indices, extract=extract, join=lambda local_factorial_state: join(local_factorial_state, resampling_fn), marginalize=marginalize, insert=insert, + factorialize_init_state=factorialize_init_state, + ) + + +def factorialize_init_state( + init_state: GeneralParticleFilterState, model_inputs: ArrayTreeLike +) -> GeneralParticleFilterState: + """Convert initial SMC state particles from `(N, F, ...)` to `(F, N, ...)`. + + Generic SMC filters sample initial particles with a leading particle axis. + The factorial SMC machinery expects the factor axis to lead instead. + Initial weights and particle filter ancestor indices are broadcast from + `(N,)` to `(F, N)`, matching the factorial SMC state layout. + + Args: + init_state: Output from particle filter `init_prepare` + model_inputs: The model inputs at the first time point - unused. 
+ """ + particles = tree.map(lambda x: jnp.moveaxis(x, 0, 1), init_state.particles) + n_factors = tree.leaves(particles)[0].shape[0] + n_particles = init_state.log_weights.shape[0] + + new_state = init_state._replace( + particles=particles, + log_weights=jnp.broadcast_to(init_state.log_weights, (n_factors, n_particles)), ) + if isinstance(init_state, ParticleFilterState): + new_state = new_state._replace( + ancestor_indices=jnp.broadcast_to( + init_state.ancestor_indices, (n_factors, n_particles) + ) + ) + + return new_state + def extract( factorial_state: GeneralParticleFilterState, diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py index 731741f..ea67b49 100644 --- a/cuthbert/factorial/types.py +++ b/cuthbert/factorial/types.py @@ -140,6 +140,21 @@ def __call__( ... +class FactorializeInitState(Protocol): + """Protocol for factorial post-processing of `init_prepare`.""" + + def __call__( + self, init_state: ArrayTreeLike, model_inputs: ArrayTreeLike + ) -> ArrayTree: + """Post-process the output of `init_prepare` for factorial inference. + + Args: + init_state: Output from the base inference method's `init_prepare`. + model_inputs: The model inputs at the first time point. + """ + ... + + class Factorializer(NamedTuple): """Factorializer object. @@ -152,6 +167,14 @@ class Factorializer(NamedTuple): join: Function to combine factorial states into a joint state. marginalize: Function to marginalize a joint state into a factored state. insert: Function to insert a local factorial state into a factorial state. + factorialize_init_state: Optional post-processing of the `init_prepare` output. + By default leaves the output of `init_prepare` unchanged. + extract_and_join: Apply extract and then join. + Input: Global factorial state. + Output: Local joint state. + marginalize_and_insert: Apply marginalize and then insert. + Input: Local joint state and global factorial state. + Output: Global factorial state. 
""" get_factorial_indices: GetFactorialIndices @@ -159,15 +182,9 @@ class Factorializer(NamedTuple): join: Join marginalize: Marginalize insert: Insert - - def factorize_initial_state(self, initial_state: ArrayTreeLike) -> ArrayTreeLike: - """Convert an initial filter state to factorial layout if needed. - - Most inference methods already construct initial states in factorial - layout. Inference-specific factorializers can override this when their - generic initial state layout differs from the factorial convention. - """ - return initial_state + factorialize_init_state: FactorializeInitState = lambda init_state, model_inputs: ( + init_state + ) def extract_and_join( self, factorial_state: ArrayTreeLike, model_inputs: ArrayTreeLike