diff --git a/Cargo.toml b/Cargo.toml index 985bf3a..70f2fdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,8 +41,11 @@ pyo3-object_store = "0.9.0" zarrs = { version = "0.23.2", features = ["async"] } zarrs_object_store = "0.6.0" tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } -pyo3-arrow = "0.16.0" -arrow = { version = "57.0.0", features = ["json"] } +pyo3-arrow = "0.17.0" +arrow = { version = "58.0.0", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +pythonize = "0.28.0" [dependencies.pyo3] version = "0.28.0" diff --git a/pyproject.toml b/pyproject.toml index ad8a07c..2f4dc5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "pyarrow >= 12.0.0", "arro3-core >= 0.6.0", "pandas >= 2.0", + "platformdirs >= 3.0.0", "xarray >= 2025.01.2", "arviz >= 0.20.0,<1.0", "obstore >= 0.8.0", diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index d77063f..22f9164 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,7 +1,7 @@ from nutpie import _lib from nutpie._lib import store as zarr_store from nutpie.compile_pymc import compile_pymc_model -from nutpie.compile_stan import compile_stan_model +from nutpie.compile_stan import compile_stan_model, prune_stan_cache from nutpie.sample import sample ChainProgress = _lib.PyChainProgress @@ -12,6 +12,7 @@ "ChainProgress", "compile_pymc_model", "compile_stan_model", + "prune_stan_cache", "sample", "zarr_store", ] diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 14cc13a..7240e74 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,3 +1,7 @@ +import datetime +import hashlib +import json +import shutil import tempfile from dataclasses import dataclass, replace from importlib.util import find_spec @@ -144,6 +148,105 @@ def coords(self): return self._coords +def _stan_cache_key( + code: str, + extra_compile_args: Optional[list[str]], + extra_stanc_args: Optional[list[str]], +) -> str: + """Return a SHA-256 hex digest identifying a unique compilation job.""" + import bridgestan + + fingerprint = json.dumps( + { + "code": code, + "extra_compile_args": sorted(extra_compile_args or []), + "extra_stanc_args": sorted(extra_stanc_args or []), + "bridgestan_version": bridgestan.__version__, + }, + sort_keys=True, + ) + return hashlib.sha256(fingerprint.encode()).hexdigest() + + +def _stan_cache_dir() -> Path: + """Return (and create) the directory where compiled Stan models are cached.""" + import platformdirs + + cache_dir = Path(platformdirs.user_cache_dir("nutpie")) / "stan" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def prune_stan_cache( + max_entries: int = 16, + min_age: datetime.timedelta = datetime.timedelta(weeks=2), +) -> None: + """Remove old entries from the Stan compilation cache. + + Entries are only considered for removal if they are older than *min_age*. + Among those, the oldest ones are removed until at most *max_entries* + entries remain. + + Parameters + ---------- + max_entries: + Maximum number of cache entries to keep. Defaults to 16. + min_age: + Entries younger than this are never removed, regardless of how many + entries exist. Defaults to 2 weeks. + """ + cache_dir = _stan_cache_dir() + now = datetime.datetime.now(tz=datetime.timezone.utc) + + # Collect all valid (marker exists) entries with their mtime. + entries = [] + for entry_dir in cache_dir.iterdir(): + if not entry_dir.is_dir(): + continue + marker = entry_dir / "ok" + if not marker.exists(): + continue + mtime = datetime.datetime.fromtimestamp( + marker.stat().st_mtime, tz=datetime.timezone.utc + ) + entries.append((mtime, entry_dir)) + + if len(entries) <= max_entries: + return + + # Only entries older than min_age are candidates for eviction. + candidates = sorted( + [(mtime, d) for mtime, d in entries if (now - mtime) >= min_age] + ) + + n_to_remove = len(entries) - max_entries + for _, entry_dir in candidates[:n_to_remove]: + shutil.rmtree(entry_dir, ignore_errors=True) + + +def _compile_stan_model( + model_name: str, + code: str, + build_dir: Path, + make_args: list[str], + stanc_args: list[str], +) -> Path: + """Write *code* into *build_dir*, compile it, and return the path to the shared library.""" + import bridgestan + + model_path = ( + build_dir.joinpath("name") + .with_name(model_name) # This verifies that it is a valid filename + .with_suffix(".stan") + ) + model_path.write_text(code) + so_path = bridgestan.compile_model( + model_path, make_args=make_args, stanc_args=stanc_args + ) + bridgestan.compile.windows_dll_path_setup() + return so_path + + def compile_stan_model( *, code: Optional[str] = None, @@ -154,7 +257,46 @@ def compile_stan_model( coords: Optional[dict[str, Any]] = None, model_name: Optional[str] = None, cleanup: bool = True, + cache: bool = False, + prune_cache: bool = True, ) -> CompiledStanModel: + """Compile a Stan model and return a :class:`CompiledStanModel`. + + Parameters + ---------- + code: + Stan model source code as a string. + filename: + Path to a ``.stan`` file. Mutually exclusive with *code*. + extra_compile_args: + Extra arguments forwarded to the C++ compiler via BridgeStan's + ``make_args``. + extra_stanc_args: + Extra arguments forwarded to the Stan compiler (``stanc``). + dims: + Variable dimension names, e.g. ``{"alpha": ["county"]}``. + coords: + Coordinate labels for each dimension, e.g. + ``{"county": ["Hennepin", "Ramsey", ...]}``. + model_name: + Base name used for the ``.stan`` file. Defaults to ``"model"``. + cleanup: + Remove the temporary build directory after compilation. Has no + effect when *cache* is ``True`` (the build directory is the cache + entry and is never deleted). + cache: + When ``True``, compile the model into a persistent directory under + the user cache directory (``~/.cache/nutpie/stan`` on Linux/macOS, + ``%LOCALAPPDATA%\\nutpie\\stan`` on Windows) and reuse it on + subsequent calls with identical arguments and the same BridgeStan + version. A marker file ``ok`` is written only after a successful + build, so interrupted or failed compilations are never reused. + Defaults to ``False``. + prune_cache: + When ``True`` (the default), call :func:`prune_stan_cache` after + each new compilation to evict old cache entries. Has no effect + when *cache* is ``False``. + """ if find_spec("bridgestan") is None: raise ImportError( "BridgeStan is not installed in the current environment. " @@ -180,33 +322,58 @@ def compile_stan_model( if model_name is None: model_name = "model" - basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) - try: - model_path = ( - Path(basedir.name) - .joinpath("name") - .with_name(model_name) # This verifies that it is a valid filename - .with_suffix(".stan") - ) - model_path.write_text(code) - make_args = ["STAN_THREADS=true"] - if extra_compile_args: - make_args.extend(extra_compile_args) - stanc_args = [] - if extra_stanc_args: - stanc_args.extend(extra_stanc_args) - so_path = bridgestan.compile_model( - model_path, make_args=make_args, stanc_args=stanc_args - ) - # Set necessary library loading paths - bridgestan.compile.windows_dll_path_setup() - library = _lib.StanLibrary(so_path) - finally: + make_args = ["STAN_THREADS=true"] + if extra_compile_args: + make_args.extend(extra_compile_args) + stanc_args = [] + if extra_stanc_args: + stanc_args.extend(extra_stanc_args) + + if cache: + digest = _stan_cache_key(code, extra_compile_args, extra_stanc_args) + entry_dir = _stan_cache_dir() / digest + marker = entry_dir / "ok" + + so_path_file = entry_dir / "so_path.txt" + + if marker.exists(): + # Cache hit: touch the marker to record recent use, then load. + marker.touch() + so_path = Path(so_path_file.read_text()) + if not so_path.exists(): + raise FileNotFoundError( + f"Cached Stan library not found: {so_path}. " + "The cache entry may be corrupt; delete it and recompile." + ) + bridgestan.compile.windows_dll_path_setup() + library = _lib.StanLibrary(str(so_path)) + else: + # Cache miss: compile directly into the cache entry directory so + # that all relative loading paths inside the .so remain valid. + entry_dir.mkdir(parents=True, exist_ok=True) + so_path = _compile_stan_model( + model_name, code, entry_dir, make_args, stanc_args + ) + # Write the .so path before the marker so the marker is only + # ever present once so_path.txt is fully written. + so_path_file.write_text(str(so_path)) + marker.write_text("") + library = _lib.StanLibrary(str(so_path)) + if prune_cache: + prune_stan_cache() + else: + basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) try: - if cleanup: - basedir.cleanup() - except Exception: # noqa: BLE001 - pass + so_path = _compile_stan_model( + model_name, code, Path(basedir.name), make_args, stanc_args + ) + library = _lib.StanLibrary(str(so_path)) + finally: + try: + if cleanup: + basedir.cleanup() + except Exception: # noqa: BLE001 + pass return CompiledStanModel( code=code, diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index d52d93b..fcb222b 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -8,7 +8,7 @@ import pandas as pd import pyarrow -from nutpie import _lib # type: ignore +from nutpie import _lib @dataclass(frozen=True) @@ -135,6 +135,8 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va if name not in data_dict: if dtype in [np.float64, np.float32]: data = np.full(total_shape, np.nan, dtype=dtype) + elif dtype == np.dtype("O"): + data = np.full(total_shape, None, dtype=dtype) else: data = np.zeros(total_shape, dtype=dtype) data_dict[name] = data @@ -148,6 +150,8 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va ) else: is_null = is_null.to_numpy(False) + if values.shape[0] == num_draws: + values = values[~is_null] data_dict[name][chain, :num_draws][~is_null] = values.reshape( ((~is_null).sum(),) + tuple(item_shape) ) @@ -297,7 +301,7 @@ def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_va def in_marimo_notebook() -> bool: try: - import marimo as mo + import marimo as mo # ty:ignore[unresolved-import] return mo.running_in_notebook() except ImportError: @@ -306,17 +310,25 @@ def in_marimo_notebook() -> bool: def _mo_write_internal(cell_id, stream, value: object) -> None: """Write to marimo cell given cell_id and stream.""" - import marimo + import marimo # ty:ignore[unresolved-import] if marimo.__version__ < "0.19.0": # The old CellOp API is identical to new CellNotificationUtils - from marimo._messaging.ops import CellOp as CellNotificationUtils + from marimo._messaging.ops import ( # ty:ignore[unresolved-import] + CellOp as CellNotificationUtils, + ) else: - from marimo._messaging.notification_utils import CellNotificationUtils + from marimo._messaging.notification_utils import ( # ty:ignore[unresolved-import] + CellNotificationUtils, + ) - from marimo._messaging.cell_output import CellChannel - from marimo._messaging.tracebacks import write_traceback - from marimo._output import formatting + from marimo._messaging.cell_output import ( # ty:ignore[unresolved-import] + CellChannel, + ) + from marimo._messaging.tracebacks import ( # ty:ignore[unresolved-import] + write_traceback, + ) + from marimo._output import formatting # ty:ignore[unresolved-import] output = formatting.try_format(value) if output.traceback is not None: @@ -333,9 +345,11 @@ def _mo_write_internal(cell_id, stream, value: object) -> None: def _mo_create_replace(): """Create mo.output.replace with current context pinned.""" - from marimo._output import formatting - from marimo._runtime.context import get_context - from marimo._runtime.context.types import ContextNotInitializedError + from marimo._output import formatting # ty:ignore[unresolved-import] + from marimo._runtime.context import get_context # ty:ignore[unresolved-import] + from marimo._runtime.context.types import ( # ty:ignore[unresolved-import] + ContextNotInitializedError, + ) try: ctx = get_context() @@ -359,7 +373,7 @@ def in_notebook(): def in_colab(): "Check if the code is running in Google Colaboratory" try: - from google import colab # noqa: F401 + from google import colab # noqa: F401 # ty:ignore[unresolved-import] return True except ImportError: @@ -371,7 +385,7 @@ def in_colab(): shell = get_ipython().__class__.__name__ # type: ignore if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole try: - from IPython.display import ( + from IPython.display import ( # ty:ignore[unresolved-import] HTML, # noqa: F401 clear_output, # noqa: F401 display, # noqa: F401 @@ -457,7 +471,7 @@ def __init__( if progress_style is None: progress_style = _progress_style - import IPython + import IPython # ty:ignore[unresolved-import] self._html = "" @@ -483,7 +497,7 @@ def callback(formatted): progress_rate, progress_template, cores, callback ) elif in_marimo_notebook(): - import marimo as mo + import marimo as mo # ty:ignore[unresolved-import] if progress_template is None: progress_template = _progress_template @@ -532,6 +546,7 @@ def wait(self, *, timeout=None): return self._extract(results) def _extract(self, results): + settings_dict = self._settings.as_dict() if self._return_raw_trace: return results else: @@ -548,7 +563,7 @@ def _extract(self, results): store = cls(*args, **kwargs) obj_store = ObjectStore(store, read_only=True) - ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) + ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) # ty:ignore[invalid-argument-type] return arviz.from_datatree(ds) elif results.is_arrow(): @@ -556,7 +571,7 @@ def _extract(self, results): skips = { "store_gradient": ["gradient"], "store_unconstrained": ["unconstrained_draw"], - "store_mass_matrix": [ + "adapt_options.mass_matrix_options.store_mass_matrix": [ "mass_matrix_inv", "mass_matrix_eigvals", "mass_matrix_stds", @@ -567,10 +582,23 @@ def _extract(self, results): "divergence_momentum", "divergence_start_gradient", ], + "store_transformed": [ + "transformed_position", + "transformed_gradient", + "transformed_mu", + ], } + def _get_nested(settings, name, default): + parts = name.split(".") + for part in parts: + if part not in settings: + return default + settings = settings[part] + return settings + for setting, names in skips.items(): - if not getattr(self._settings, setting, False): + if not _get_nested(settings_dict["settings"], setting, False): skip_vars.extend(names) draw_batches, stat_batches = results.get_arrow_trace() @@ -638,11 +666,13 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, + blocking: Literal[True], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, + **kwargs, ) -> arviz.InferenceData: ... @@ -660,14 +690,14 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, - blocking: Literal[True], + blocking: Literal[False], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData: ... +) -> _BackgroundSampler: ... @overload @@ -684,14 +714,13 @@ def sample( adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, - blocking: Literal[False], progress_callback: Any | None = None, progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> _BackgroundSampler: ... +) -> arviz.InferenceData: ... def sample( @@ -704,6 +733,7 @@ def sample( seed: int | None = None, save_warmup: bool = True, progress_bar: bool = True, + sampler: Literal["nuts", "mclmc"] = "nuts", adaptation: Literal["diag", "draw_diag", "low_rank", "flow"] = "diag", init_mean: np.ndarray | None = None, return_raw_trace: bool = False, @@ -769,6 +799,12 @@ def sample( return_raw_trace: bool, default=False Return the raw trace object (an apache arrow structure) instead of converting to arviz. + sampler: str, default="nuts" + The sampler to use. One of: + + - ``"nuts"`` (default): No-U-Turn Sampler. + - ``"mclmc"``: Microcanonical Langevin Monte Carlo. + adaptation: str, default="diag" The mass matrix adaptation strategy to use. One of: @@ -873,29 +909,48 @@ def sample( stacklevel=2, ) - if adaptation == "low_rank": - settings = _lib.PyNutsSettings.LowRank(seed) - elif adaptation == "flow": - settings = _lib.PyNutsSettings.Transform(seed) - elif adaptation in ("diag", "draw_diag"): - settings = _lib.PyNutsSettings.Diag(seed) - if adaptation == "draw_diag" or _use_grad_based is False: - settings.use_grad_based_mass_matrix = False + if sampler == "nuts": + if adaptation == "low_rank": + settings = _lib.PyNutsSettings.LowRank(seed) + elif adaptation == "flow": + settings = _lib.PyNutsSettings.Flow(seed) + elif adaptation in ("diag", "draw_diag"): + settings = _lib.PyNutsSettings.Diag(seed) + if adaptation == "draw_diag" or _use_grad_based is False: + settings.use_grad_based_mass_matrix = False + else: + raise ValueError( + f"Unknown adaptation strategy '{adaptation}'. " + f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + ) + elif sampler == "mclmc": + if adaptation == "low_rank": + settings = _lib.PyMclmcSettings.LowRank(seed) + elif adaptation == "flow": + settings = _lib.PyMclmcSettings.Flow(seed) + elif adaptation in ("diag", "draw_diag"): + settings = _lib.PyMclmcSettings.Diag(seed) + if adaptation == "draw_diag" or _use_grad_based is False: + settings.use_grad_based_mass_matrix = False + else: + raise ValueError( + f"Unknown adaptation strategy '{adaptation}'. " + f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + ) else: raise ValueError( - f"Unknown adaptation strategy '{adaptation}'. " - f"Expected one of: 'diag', 'draw_diag', 'low_rank', 'flow'." + f"Unknown sampler '{sampler}'. Expected one of: 'nuts', 'mclmc'." ) + updates = dict(kwargs) if tune is not None: - settings.num_tune = tune + updates["num_tune"] = tune if draws is not None: - settings.num_draws = draws + updates["num_draws"] = draws if chains is not None: - settings.num_chains = chains + updates["num_chains"] = chains - for name, val in kwargs.items(): - setattr(settings, name, val) + settings.update(updates) if cores is None: try: @@ -911,7 +966,7 @@ def sample( if init_mean is None: init_mean = np.zeros(compiled_model.n_dim) - sampler = _BackgroundSampler( + background_sampler = _BackgroundSampler( compiled_model, settings, init_mean, @@ -927,14 +982,14 @@ def sample( ) if not blocking: - return sampler + return background_sampler try: - result = sampler.wait() + result = background_sampler.wait() except KeyboardInterrupt: - result = sampler.abort() + result = background_sampler.abort() except: - sampler.cancel() + background_sampler.cancel() raise return result diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 26ba41e..899381d 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -6,8 +6,8 @@ use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, - Bound, Py, PyAny, PyErr, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyNone}, + Bound, BoundObject, Py, PyAny, PyErr, Python, }; use rand::Rng; use rand_distr::{Distribution, Uniform}; @@ -477,7 +477,7 @@ impl CpuLogpFunc for PyDensity { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -492,6 +492,18 @@ impl CpuLogpFunc for PyDensity { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/pymc.rs b/src/pymc.rs index 44b2e81..b61c5e3 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -6,8 +6,8 @@ use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods}, - Py, PyAny, PyErr, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods, PyNone}, + BoundObject, Py, PyAny, PyErr, PyResult, Python, }; use rand::Rng; @@ -371,7 +371,7 @@ impl CpuLogpFunc for PyMcModelRef<'_> { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -386,6 +386,18 @@ impl CpuLogpFunc for PyMcModelRef<'_> { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/stan.rs b/src/stan.rs index f56fbe5..3c21354 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -7,9 +7,9 @@ use bridgestan::open_library; use itertools::Itertools; use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::exceptions::PyRuntimeError; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyTuple}; +use pyo3::types::{PyDict, PyNone, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; +use pyo3::{prelude::*, BoundObject}; use rand::prelude::Distribution; use rand::{rng, Rng}; use rand_distr::StandardNormal; @@ -202,7 +202,7 @@ where let (mut shape, is_complex) = group .iter() - .map(|&(_, is_complex, ref idx)| (idx, is_complex)) + .map(|&(_, is_complex, idx)| (idx, is_complex)) .fold(None, |acc, (elem_index, &elem_is_complex)| { let (mut shape, is_complex) = acc.unwrap_or((elem_index.clone(), elem_is_complex)); assert!( @@ -630,7 +630,7 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { Ok(()) } - fn new_transformation( + fn init_transformation( &mut self, rng: &mut R, untransformed_position: &[f64], @@ -646,6 +646,18 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { Ok(trafo) } + fn new_transformation( + &mut self, + _rng: &mut R, + _dim: usize, + _chain: u64, + ) -> std::result::Result { + Python::attach(|py| { + let params = PyNone::get(py); + Ok(params.unbind().into()) + }) + } + fn transformation_id(&self, params: &Py) -> std::result::Result { let id = self .transform_adapter diff --git a/src/wrapper.rs b/src/wrapper.rs index 8ba7e50..cabebb6 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -16,19 +16,21 @@ use crate::{ use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, Model, - ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, TransformedNutsSettings, - ZarrAsyncConfig, + ArrowConfig, ArrowTrace, ChainProgress, DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, + FlowNutsSettings, KineticEnergyKind, LowRankMclmcSettings, LowRankNutsSettings, Model, + ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, ZarrAsyncConfig, }; use pyo3::{ - exceptions::{PyTimeoutError, PyValueError}, + exceptions::{PyAttributeError, PyTimeoutError, PyValueError}, intern, prelude::*, - types::PyList, + types::{PyDict, PyList}, }; use pyo3_arrow::PyRecordBatch; use pyo3_object_store::AnyObjectStore; +use pythonize::{depythonize, pythonize}; use rand::{rng, Rng}; +use serde_json::Value as JsonValue; use tokio::runtime::Runtime; use zarrs_object_store::{object_store::limit::LimitStore, AsyncObjectStore}; @@ -103,742 +105,705 @@ impl PyChainProgress { #[pyclass(from_py_object)] #[derive(Clone)] pub struct PyNutsSettings { - inner: Settings, + inner: NutsSettingsKind, +} + +#[derive(Clone, FromPyObject)] +enum PySamplerSettings { + Nuts(PyNutsSettings), + Mclmc(PyMclmcSettings), } #[derive(Clone, Debug)] -enum Settings { - Diag(DiagGradNutsSettings), +enum NutsSettingsKind { + Diag(DiagNutsSettings), LowRank(LowRankNutsSettings), - Transforming(TransformedNutsSettings), + Flow(FlowNutsSettings), } -impl PyNutsSettings { - fn new_diag(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() - }); - let settings = DiagGradNutsSettings { - seed, - ..Default::default() - }; +#[pyclass(from_py_object)] +#[derive(Clone)] +pub struct PyMclmcSettings { + inner: MclmcSettingsKind, +} - Self { - inner: Settings::Diag(settings), - } - } +#[derive(Clone, Debug)] +enum MclmcSettingsKind { + Diag(DiagMclmcSettings), + LowRank(LowRankMclmcSettings), + Flow(FlowMclmcSettings), +} - fn new_low_rank(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() - }); - let settings = LowRankNutsSettings { - seed, - ..Default::default() - }; +macro_rules! unsupported_option_error { + ($option:expr, $adaptation:expr) => { + PyValueError::new_err(format!( + "Option {} not available for {} adaptation", + $option, $adaptation + )) + }; +} - Self { - inner: Settings::LowRank(settings), +macro_rules! with_all_settings_mut { + ($self:expr, $enum_name:ident, $settings:ident => $body:block) => {{ + match &mut $self.inner { + $enum_name::Diag($settings) => $body, + $enum_name::LowRank($settings) => $body, + $enum_name::Flow($settings) => $body, } - } + }}; +} - fn new_tranform_adapt(seed: Option) -> Self { - let seed = seed.unwrap_or_else(|| { - let mut rng = rng(); - rng.next_u64() +macro_rules! set_all_settings_field { + ($self:expr, $enum_name:ident, $field:ident = $value:expr) => {{ + with_all_settings_mut!($self, $enum_name, settings => { + settings.$field = $value; }); - let settings = TransformedNutsSettings { - seed, - ..Default::default() - }; - - Self { - inner: Settings::Transforming(settings), - } - } + }}; + ($self:expr, $enum_name:ident, $field:ident $(. $rest:ident)+ = $value:expr) => {{ + with_all_settings_mut!($self, $enum_name, settings => { + settings.$field$(.$rest)+ = $value; + }); + }}; } -// TODO switch to serde to expose all the options... -#[pymethods] -impl PyNutsSettings { - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn Diag(seed: Option) -> Self { - PyNutsSettings::new_diag(seed) - } - - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn LowRank(seed: Option) -> Self { - PyNutsSettings::new_low_rank(seed) - } - - #[staticmethod] - #[allow(non_snake_case)] - #[pyo3(signature = (seed=None))] - fn Transform(seed: Option) -> Self { - PyNutsSettings::new_tranform_adapt(seed) - } - - #[getter] - fn num_tune(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_tune, - Settings::LowRank(nuts_settings) => nuts_settings.num_tune, - Settings::Transforming(nuts_settings) => nuts_settings.num_tune, - } - } - - #[setter(num_tune)] - fn set_num_tune(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_tune = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_tune = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_tune = val, +macro_rules! with_diag_or_low_rank_settings_mut { + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + $enum_name::Diag($settings) => $body, + $enum_name::LowRank($settings) => $body, + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } - } + }}; +} - #[getter] - fn num_chains(&self) -> usize { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_chains, - Settings::LowRank(nuts_settings) => nuts_settings.num_chains, - Settings::Transforming(nuts_settings) => nuts_settings.num_chains, +macro_rules! with_diag_settings_mut { + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + $enum_name::Diag($settings) => $body, + $enum_name::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } - } + }}; +} - #[setter(num_chains)] - fn set_num_chains(&mut self, val: usize) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_chains = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_chains = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_chains = val, +macro_rules! with_low_rank_settings_mut { + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + $enum_name::LowRank($settings) => $body, + $enum_name::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + $enum_name::Flow(_) => return Err(unsupported_option_error!($option, "flow")), } - } + }}; +} - #[getter] - fn num_draws(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_draws, - Settings::LowRank(nuts_settings) => nuts_settings.num_draws, - Settings::Transforming(nuts_settings) => nuts_settings.num_draws, +macro_rules! with_flow_settings_mut { + ($self:expr, $enum_name:ident, $option:expr, $settings:ident => $body:block) => {{ + match &mut $self.inner { + $enum_name::Flow($settings) => $body, + $enum_name::Diag(_) => return Err(unsupported_option_error!($option, "diag")), + $enum_name::LowRank(_) => return Err(unsupported_option_error!($option, "low-rank")), } - } + }}; +} - #[setter(num_draws)] - fn set_num_draws(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.num_draws = val, - Settings::LowRank(nuts_settings) => nuts_settings.num_draws = val, - Settings::Transforming(nuts_settings) => nuts_settings.num_draws = val, - } - } +macro_rules! try_shared_euclidean_adapt_update { + ($self:expr, $enum_name:ident, $name:expr, $value:expr) => {{ + match $name { + "window_switch_freq" => { + let value: u64 = $value.extract()?; + match &mut $self.inner { + $enum_name::Diag(settings) => { + settings.adapt_options.mass_matrix_switch_freq = value + } + $enum_name::LowRank(settings) => { + settings.adapt_options.mass_matrix_switch_freq = value + } + $enum_name::Flow(settings) => { + settings.adapt_options.transform_update_freq = value + } + } + true + } + "early_window_switch_freq" => { + let value: u64 = $value.extract()?; + with_diag_or_low_rank_settings_mut!( + $self, + $enum_name, + "early_window_switch_freq", + settings => { + settings.adapt_options.early_mass_matrix_switch_freq = value; + } + ); + true + } + "initial_step" => { + let value: f64 = $value.extract()?; + set_all_settings_field!( + $self, + $enum_name, + adapt_options.step_size_settings.initial_step = value + ); + true + } + "target_accept" => { + let value: f64 = $value.extract()?; + set_all_settings_field!( + $self, + $enum_name, + adapt_options.step_size_settings.target_accept = value + ); + true + } + "max_step_size" => { + let value: f64 = $value.extract()?; + set_all_settings_field!( + $self, + $enum_name, + adapt_options + .step_size_settings + .adapt_options + .dual_average + .max_step_size = value + ); + true + } + "store_mass_matrix" => { + let value: bool = $value.extract()?; + with_diag_or_low_rank_settings_mut!( + $self, + $enum_name, + "store_mass_matrix", + settings => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = value; + } + ); + true + } + "use_grad_based_mass_matrix" => { + let value: bool = $value.extract()?; + with_diag_settings_mut!( + $self, + $enum_name, + "use_grad_based_mass_matrix", + settings => { + settings.adapt_options.mass_matrix_options.use_grad_based_estimate = value; + } + ); + true + } + "mass_matrix_switch_freq" => { + let value: u64 = $value.extract()?; + with_diag_or_low_rank_settings_mut!( + $self, + $enum_name, + "mass_matrix_switch_freq", + settings => { + settings.adapt_options.mass_matrix_switch_freq = value; + } + ); + true + } + "mass_matrix_eigval_cutoff" => { + let value: Option = $value.extract()?; + if let Some(value) = value { + with_low_rank_settings_mut!( + $self, + $enum_name, + "mass_matrix_eigval_cutoff", + settings => { + settings.adapt_options.mass_matrix_options.eigval_cutoff = value; + } + ); + } + true + } + "mass_matrix_gamma" => { + let value: Option = $value.extract()?; + if let Some(value) = value { + with_low_rank_settings_mut!( + $self, + $enum_name, + "mass_matrix_gamma", + settings => { + settings.adapt_options.mass_matrix_options.gamma = value; + } + ); + } + true + } + "train_on_orbit" => { + let value: bool = $value.extract()?; + with_flow_settings_mut!( + $self, + $enum_name, + "train_on_orbit", + settings => { + settings.adapt_options.use_orbit_for_training = value; + } + ); + true + } + "step_size_adapt_method" => { + let method = match $value.extract::() { + Ok(method) => match method.as_str() { + "dual_average" => StepSizeAdaptMethod::DualAverage, + "adam" => StepSizeAdaptMethod::Adam, + _ => { + if let Ok(step_size) = method.parse::() { + StepSizeAdaptMethod::Fixed(step_size) + } else { + return Err(PyValueError::new_err( + "step_size_adapt_method must be a positive float when using fixed step size", + )); + } + } + }, + _ => { + return Err(PyValueError::new_err( + "step_size_adapt_method must be a string", + )); + } + }; - #[getter] - fn window_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(nuts_settings) => { - Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) - } - Settings::LowRank(nuts_settings) => { - Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) - } - Settings::Transforming(nuts_settings) => { - Ok(nuts_settings.adapt_options.transform_update_freq) + set_all_settings_field!( + $self, + $enum_name, + adapt_options.step_size_settings.adapt_options.method = method + ); + true + } + "step_size_adam_learning_rate" => { + let value: Option = $value.extract()?; + if let Some(value) = value { + set_all_settings_field!( + $self, + $enum_name, + adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = value + ); + } + true } + "step_size_jitter" => { + let mut value: Option = $value.extract()?; + if let Some(jitter) = value { + if jitter < 0.0 { + return Err(PyValueError::new_err("step_size_jitter must be positive")); + } + if jitter == 0.0 { + value = None; + } + } + set_all_settings_field!( + $self, + $enum_name, + adapt_options.step_size_settings.jitter = value + ); + true + } + "store_unconstrained" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_unconstrained = value); + true + } + "store_gradient" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_gradient = value); + true + } + "num_tune" => { + let value: u64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_tune = value); + true + } + "num_chains" => { + let value: usize = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_chains = value); + true + } + "num_draws" => { + let value: u64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, num_draws = value); + true + } + "store_transformed" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_transformed = value); + true + } + "store_divergences" => { + let value: bool = $value.extract()?; + set_all_settings_field!($self, $enum_name, store_divergences = value); + true + } + "max_energy_error" => { + let value: f64 = $value.extract()?; + set_all_settings_field!($self, $enum_name, max_energy_error = value); + true + } + _ => false, } - } + }}; +} - #[setter(window_switch_freq)] - fn set_window_switch_freq(&mut self, val: u64) -> Result<()> { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.mass_matrix_switch_freq = val; - Ok(()) - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.mass_matrix_switch_freq = val; - Ok(()) - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.transform_update_freq = val; - Ok(()) - } - } - } +fn random_seed(seed: Option) -> u64 { + seed.unwrap_or_else(|| { + let mut rng = rng(); + rng.next_u64() + }) +} - #[getter] - fn early_window_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(nuts_settings) => { - Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) - } - Settings::LowRank(nuts_settings) => { - Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) - } - Settings::Transforming(_) => { - bail!("Option early_window_switch_freq not availbale for transformation adaptation") - } +fn update_nuts_from_nested_dict( + inner: &mut NutsSettingsKind, + value: &Bound<'_, PyAny>, +) -> PyResult<()> { + match inner { + NutsSettingsKind::Diag(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } - } - - #[setter(early_window_switch_freq)] - fn set_early_window_switch_freq(&mut self, val: u64) -> Result<()> { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; - Ok(()) - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; - Ok(()) - } - Settings::Transforming(_) => { - bail!("Option early_window_switch_freq not availbale for transformation adaptation") - } + NutsSettingsKind::LowRank(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } - } - - #[getter] - fn initial_step(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step - } + NutsSettingsKind::Flow(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } } + Ok(()) +} - #[setter(initial_step)] - fn set_initial_step(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.initial_step = val; - } +fn update_mclmc_from_nested_dict( + inner: &mut MclmcSettingsKind, + value: &Bound<'_, PyAny>, +) -> PyResult<()> { + match inner { + MclmcSettingsKind::Diag(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } - } - - #[getter] - fn maxdepth(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.maxdepth, - Settings::LowRank(nuts_settings) => nuts_settings.maxdepth, - Settings::Transforming(nuts_settings) => nuts_settings.maxdepth, + MclmcSettingsKind::LowRank(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } - } - - #[setter(maxdepth)] - fn set_maxdepth(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.maxdepth = val, - Settings::LowRank(nuts_settings) => nuts_settings.maxdepth = val, - Settings::Transforming(nuts_settings) => nuts_settings.maxdepth = val, + MclmcSettingsKind::Flow(settings) => { + *settings = depythonize(value).map_err(|err| PyValueError::new_err(err.to_string()))?; } } + Ok(()) +} - #[getter] - fn mindepth(&self) -> u64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.mindepth, - Settings::LowRank(nuts_settings) => nuts_settings.mindepth, - Settings::Transforming(nuts_settings) => nuts_settings.mindepth, +fn nuts_to_nested_json(inner: &NutsSettingsKind) -> PyResult { + match inner { + NutsSettingsKind::Diag(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } - } - - #[setter(mindepth)] - fn set_mindepth(&mut self, val: u64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.mindepth = val, - Settings::LowRank(nuts_settings) => nuts_settings.mindepth = val, - Settings::Transforming(nuts_settings) => nuts_settings.mindepth = val, + NutsSettingsKind::LowRank(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } - } - - #[getter] - fn store_gradient(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_gradient, - Settings::LowRank(nuts_settings) => nuts_settings.store_gradient, - Settings::Transforming(nuts_settings) => nuts_settings.store_gradient, + NutsSettingsKind::Flow(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } } +} - #[setter(store_gradient)] - fn set_store_gradient(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_gradient = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_gradient = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_gradient = val, +fn mclmc_to_nested_json(inner: &MclmcSettingsKind) -> PyResult { + match inner { + MclmcSettingsKind::Diag(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } - } - - #[getter] - fn store_unconstrained(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained, - Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained, - Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained, + MclmcSettingsKind::LowRank(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) + } + MclmcSettingsKind::Flow(settings) => { + serde_json::to_value(settings).map_err(|err| PyValueError::new_err(err.to_string())) } } +} - #[setter(store_unconstrained)] - fn set_store_unconstrained(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained = val, +impl PyNutsSettings { + fn new_diag(seed: Option) -> Self { + let settings = DiagNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::Diag(settings), } } - #[getter] - fn store_divergences(&self) -> bool { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_divergences, - Settings::LowRank(nuts_settings) => nuts_settings.store_divergences, - Settings::Transforming(nuts_settings) => nuts_settings.store_divergences, + fn new_low_rank(seed: Option) -> Self { + let settings = LowRankNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::LowRank(settings), } } - #[setter(store_divergences)] - fn set_store_divergences(&mut self, val: bool) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.store_divergences = val, - Settings::LowRank(nuts_settings) => nuts_settings.store_divergences = val, - Settings::Transforming(nuts_settings) => nuts_settings.store_divergences = val, + fn new_flow(seed: Option) -> Self { + let settings = FlowNutsSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: NutsSettingsKind::Flow(settings), } } - #[getter] - fn max_energy_error(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => nuts_settings.max_energy_error, - Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error, - Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error, - } + fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { + update_nuts_from_nested_dict(&mut self.inner, value) } - #[setter(max_energy_error)] - fn set_max_energy_error(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => nuts_settings.max_energy_error = val, - Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error = val, - Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error = val, - } + fn to_nested_json(&self) -> PyResult { + nuts_to_nested_json(&self.inner) } - #[getter] - fn set_target_accept(&self) -> f64 { - match &self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept + fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + match name { + "maxdepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, maxdepth = value); } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept + "mindepth" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, mindepth = value); } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept + "check_turning" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, check_turning = value); } - } - } - - #[setter(target_accept)] - fn target_accept(&mut self, val: f64) { - match &mut self.inner { - Settings::Diag(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val + "target_integration_time" => { + let value: Option = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, target_integration_time = value); } - Settings::LowRank(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val + "extra_doublings" => { + let value: u64 = value.extract()?; + set_all_settings_field!(self, NutsSettingsKind, extra_doublings = value); } - Settings::Transforming(nuts_settings) => { - nuts_settings.adapt_options.step_size_settings.target_accept = val + _ => { + if try_shared_euclidean_adapt_update!(self, NutsSettingsKind, name, value) { + // handled above + } else { + match name { + "microcanonical_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!( + self, + NutsSettingsKind, + trajectory_kind = KineticEnergyKind::Microcanonical + ); + } + } + "exact_normal_trajectory" => { + let value: bool = value.extract()?; + if value { + set_all_settings_field!( + self, + NutsSettingsKind, + trajectory_kind = KineticEnergyKind::ExactNormal + ); + } + } + _ => { + return Err(PyAttributeError::new_err(format!( + "Unknown settings attribute: {name}", + ))); + } + } + } } } + Ok(()) } +} - #[getter] - fn store_mass_matrix(&self) -> Result { - match &self.inner { - Settings::LowRank(settings) => { - Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) - } - Settings::Diag(settings) => { - Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) - } - Settings::Transforming(_) => Ok(false), +impl PyMclmcSettings { + fn new_diag(seed: Option) -> Self { + let settings = DiagMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::Diag(settings), } } - #[setter(store_mass_matrix)] - fn set_store_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(settings) => { - settings.adapt_options.mass_matrix_options.store_mass_matrix = val; - Ok(()) - } - Settings::Diag(settings) => { - settings.adapt_options.mass_matrix_options.store_mass_matrix = val; - Ok(()) - } - Settings::Transforming(_) => { - bail!("Option store_mass_matrix not availbale for transformation adaptation") - } + fn new_low_rank(seed: Option) -> Self { + let settings = LowRankMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::LowRank(settings), } } - #[getter] - fn use_grad_based_mass_matrix(&self) -> Result { - match &self.inner { - Settings::LowRank(_) => { - bail!("non-grad based mass matrix not available for low-rank adaptation") - } - Settings::Transforming(_) => { - bail!("non-grad based mass matrix not available for transforming adaptation") - } - Settings::Diag(diag) => Ok(diag - .adapt_options - .mass_matrix_options - .use_grad_based_estimate), + fn new_flow(seed: Option) -> Self { + let settings = FlowMclmcSettings { + seed: random_seed(seed), + ..Default::default() + }; + Self { + inner: MclmcSettingsKind::Flow(settings), } } - #[setter(use_grad_based_mass_matrix)] - fn set_use_grad_based_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(_) => { - bail!("non-grad based mass matrix not available for low-rank adaptation"); - } - Settings::Transforming(_) => { - bail!("non-grad based mass matrix not available for transforming adaptation"); - } - Settings::Diag(diag) => { - diag.adapt_options - .mass_matrix_options - .use_grad_based_estimate = val; - } - } - Ok(()) + fn update_from_nested_dict(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { + update_mclmc_from_nested_dict(&mut self.inner, value) } - #[getter] - fn mass_matrix_switch_freq(&self) -> Result { - match &self.inner { - Settings::Diag(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), - Settings::LowRank(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), - Settings::Transforming(_) => { - bail!("mass_matrix_switch_freq not available for transforming adaptation"); - } - } + fn to_nested_json(&self) -> PyResult { + mclmc_to_nested_json(&self.inner) } - #[setter(mass_matrix_switch_freq)] - fn set_mass_matrix_switch_freq(&mut self, val: u64) -> Result<()> { - match &mut self.inner { - Settings::Diag(settings) => settings.adapt_options.mass_matrix_switch_freq = val, - Settings::LowRank(settings) => settings.adapt_options.mass_matrix_switch_freq = val, - Settings::Transforming(_) => { - bail!("mass_matrix_switch_freq not available for transforming adaptation"); + fn apply_update(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + match name { + "step_size" => { + let value: f64 = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, step_size = value); } - } - Ok(()) - } - - #[getter] - fn mass_matrix_eigval_cutoff(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.eigval_cutoff), - Settings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + "momentum_decoherence_length" => { + let value: f64 = value.extract()?; + set_all_settings_field!( + self, + MclmcSettingsKind, + momentum_decoherence_length = value + ); } - Settings::Transforming(_) => { - bail!("eigenvalue cutoff not available for transfor adaptation"); + "subsample_frequency" => { + let value: f64 = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, subsample_frequency = value); } - } - } - - #[setter(mass_matrix_eigval_cutoff)] - fn set_mass_matrix_eigval_cutoff(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, - Settings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + "dynamic_step_size" => { + let value: bool = value.extract()?; + set_all_settings_field!(self, MclmcSettingsKind, dynamic_step_size = value); } - Settings::Transforming(_) => { - bail!("eigenvalue cutoff not available for transfor adaptation"); + _ => { + if try_shared_euclidean_adapt_update!(self, MclmcSettingsKind, name, value) { + // handled above + } else { + return Err(PyAttributeError::new_err(format!( + "Unknown settings attribute: {name}", + ))); + } } } Ok(()) } +} - #[getter] - fn mass_matrix_gamma(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.gamma), - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("gamma not available for transform adaptation"); - } - } +// TODO switch to serde to expose all the options... +#[pymethods] +impl PyNutsSettings { + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Diag(seed: Option) -> Self { + PyNutsSettings::new_diag(seed) } - #[setter(mass_matrix_gamma)] - fn set_mass_matrix_gamma(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => { - inner.adapt_options.mass_matrix_options.gamma = val; - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(_) => { - bail!("gamma not available for transform adaptation"); - } - } - Ok(()) + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn LowRank(seed: Option) -> Self { + PyNutsSettings::new_low_rank(seed) } - #[getter] - fn train_on_orbit(&self) -> Result { - match &self.inner { - Settings::LowRank(_) => { - bail!("gamma not available for low rank mass matrix adaptation"); - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(inner) => Ok(inner.adapt_options.use_orbit_for_training), - } + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Flow(seed: Option) -> Self { + PyNutsSettings::new_flow(seed) } - #[setter(train_on_orbit)] - fn set_train_on_orbit(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(_) => { - bail!("gamma not available for low rank mass matrix adaptation"); - } - Settings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation"); - } - Settings::Transforming(inner) => inner.adapt_options.use_orbit_for_training = val, + fn update(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> { + for (key, value) in kwargs.iter() { + let key: String = key.extract()?; + self.apply_update(&key, &value)?; } Ok(()) } - #[getter] - fn check_turning(&self) -> Result { - match &self.inner { - Settings::LowRank(inner) => Ok(inner.check_turning), - Settings::Diag(inner) => Ok(inner.check_turning), - Settings::Transforming(inner) => Ok(inner.check_turning), - } + fn __setattr__(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + self.apply_update(name, value) } - #[setter(check_turning)] - fn set_check_turning(&mut self, val: bool) -> Result<()> { - match &mut self.inner { - Settings::LowRank(inner) => { - inner.check_turning = val; - } - Settings::Diag(inner) => { - inner.check_turning = val; - } - Settings::Transforming(inner) => { - inner.check_turning = val; - } - } - Ok(()) + fn update_settings(&mut self, settings: &Bound<'_, PyDict>) -> PyResult<()> { + self.update_from_nested_dict(settings.as_any()) } - #[getter] - fn step_size_adapt_method(&self) -> String { - let method = match &self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.adapt_options.method, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.adapt_options.method, - Settings::Transforming(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method - } + fn as_dict(&self, py: Python<'_>) -> PyResult> { + let settings = self.to_nested_json()?; + let adaptation = match self.inner { + NutsSettingsKind::Diag(_) => "diag", + NutsSettingsKind::LowRank(_) => "low_rank", + NutsSettingsKind::Flow(_) => "flow", }; + let value = serde_json::json!({ + "sampler": "nuts", + "adaptation": adaptation, + "settings": settings, + }); + let obj = pythonize(py, &value).map_err(|err| PyValueError::new_err(err.to_string()))?; + Ok(obj.unbind()) + } +} - match method { - nuts_rs::StepSizeAdaptMethod::DualAverage => "dual_average", - nuts_rs::StepSizeAdaptMethod::Adam => "adam", - nuts_rs::StepSizeAdaptMethod::Fixed(_) => "fixed", - } - .to_string() +#[pymethods] +impl PyMclmcSettings { + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Diag(seed: Option) -> Self { + PyMclmcSettings::new_diag(seed) } - #[setter(step_size_adapt_method)] - fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { - let method = Python::attach(|py| { - if let Ok(method) = method.extract::(py) { - match method.as_str() { - "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), - "adam" => Ok(StepSizeAdaptMethod::Adam), - _ => { - if let Ok(step_size) = method.parse::() { - Ok(StepSizeAdaptMethod::Fixed(step_size)) - } else { - bail!("step_size_adapt_method must be a positive float when using fixed step size"); - } - } - } - } else { - bail!("step_size_adapt_method must be a string"); - } - })?; + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn LowRank(seed: Option) -> Self { + PyMclmcSettings::new_low_rank(seed) + } - match &mut self.inner { - Settings::LowRank(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - Settings::Diag(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - Settings::Transforming(inner) => { - inner.adapt_options.step_size_settings.adapt_options.method = method - } - }; - Ok(()) + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Flow(seed: Option) -> Self { + PyMclmcSettings::new_flow(seed) } - #[getter] - fn step_size_adam_learning_rate(&self) -> Option { - match &self.inner { - Settings::LowRank(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } - Settings::Diag(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } - Settings::Transforming(inner) => { - if let StepSizeAdaptMethod::Adam = - inner.adapt_options.step_size_settings.adapt_options.method - { - Some( - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate, - ) - } else { - None - } - } + fn update(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> { + for (key, value) in kwargs.iter() { + let key: String = key.extract()?; + self.apply_update(&key, &value)?; } + Ok(()) } - #[setter(step_size_adam_learning_rate)] - fn set_step_size_adam_learning_rate(&mut self, val: Option) -> Result<()> { - let Some(val) = val else { - return Ok(()); - }; - match &mut self.inner { - Settings::LowRank(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - Settings::Diag(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - Settings::Transforming(inner) => { - inner - .adapt_options - .step_size_settings - .adapt_options - .adam - .learning_rate = val - } - }; - Ok(()) + fn __setattr__(&mut self, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { + self.apply_update(name, value) } - #[getter(step_size_jitter)] - fn step_size_jitter(&self) -> Option { - match &self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter, - Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter, - } + fn update_settings(&mut self, settings: &Bound<'_, PyDict>) -> PyResult<()> { + self.update_from_nested_dict(settings.as_any()) } - #[setter(step_size_jitter)] - fn set_step_size_jitter(&mut self, mut val: Option) -> PyResult<()> { - if let Some(val) = val { - if val < 0.0 { - return Err(PyValueError::new_err("step_size_jitter must be positive")); - } - } - if let Some(jitter) = val { - if jitter == 0.0 { - val = None; - } - } - match &mut self.inner { - Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter = val, - Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter = val, - Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter = val, - } - Ok(()) + fn as_dict(&self, py: Python<'_>) -> PyResult> { + let settings = self.to_nested_json()?; + let adaptation = match self.inner { + MclmcSettingsKind::Diag(_) => "diag", + MclmcSettingsKind::LowRank(_) => "low_rank", + MclmcSettingsKind::Flow(_) => "flow", + }; + let value = serde_json::json!({ + "sampler": "mclmc", + "adaptation": adaptation, + "settings": settings, + }); + let obj = pythonize(py, &value).map_err(|err| PyValueError::new_err(err.to_string()))?; + Ok(obj.unbind()) } } @@ -972,7 +937,7 @@ struct PySampler(Mutex<(SamplerState, Runtime)>); impl PySampler { fn new( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: M, progress_type: ProgressType, @@ -987,31 +952,59 @@ impl PySampler { match &mut store.0 { InnerPyStorage::Arrow => { let storage_config = ArrowConfig::new(); - match settings.inner { - Settings::LowRank(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } - Settings::Diag(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } - Settings::Transforming(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningArrow(sampler).into(), - tokio_rt, - )))) - } + match settings { + PySamplerSettings::Nuts(settings) => match settings.inner { + NutsSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + }, + PySamplerSettings::Mclmc(settings) => match settings.inner { + MclmcSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + }, } } InnerPyStorage::Zarr(store) => { @@ -1025,31 +1018,59 @@ impl PySampler { let store = Arc::new(store); let storage_config = ZarrAsyncConfig::new(tokio_rt.handle().clone(), store); let storage_config = storage_config.with_chunk_size(16); - match settings.inner { - Settings::LowRank(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } - Settings::Diag(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } - Settings::Transforming(settings) => { - let sampler = - Sampler::new(model, settings, storage_config, cores, callback)?; - Ok(PySampler(Mutex::new(( - SamplerState::RunningZarr(sampler).into(), - tokio_rt, - )))) - } + match settings { + PySamplerSettings::Nuts(settings) => match settings.inner { + NutsSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + NutsSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + }, + PySamplerSettings::Mclmc(settings) => match settings.inner { + MclmcSettingsKind::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + MclmcSettingsKind::Flow(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + }, } } } @@ -1149,7 +1170,7 @@ impl PySampler { impl PySampler { #[staticmethod] fn from_pymc( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: PyMcModel, progress_type: ProgressType, @@ -1170,7 +1191,7 @@ impl PySampler { #[staticmethod] fn from_stan( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: StanModel, progress_type: ProgressType, @@ -1191,7 +1212,7 @@ impl PySampler { #[staticmethod] fn from_pyfunc( - settings: PyNutsSettings, + settings: PySamplerSettings, cores: usize, model: PyModel, progress_type: ProgressType, @@ -1705,6 +1726,7 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;