Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ pyo3-object_store = "0.9.0"
zarrs = { version = "0.23.2", features = ["async"] }
zarrs_object_store = "0.6.0"
tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] }
pyo3-arrow = "0.16.0"
arrow = { version = "57.0.0", features = ["json"] }
pyo3-arrow = "0.17.0"
arrow = { version = "58.0.0", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
pythonize = "0.28.0"

[dependencies.pyo3]
version = "0.28.0"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"pyarrow >= 12.0.0",
"arro3-core >= 0.6.0",
"pandas >= 2.0",
"platformdirs >= 3.0.0",
"xarray >= 2025.01.2",
"arviz >= 0.20.0,<1.0",
"obstore >= 0.8.0",
Expand Down
3 changes: 2 additions & 1 deletion python/nutpie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from nutpie import _lib
from nutpie._lib import store as zarr_store
from nutpie.compile_pymc import compile_pymc_model
from nutpie.compile_stan import compile_stan_model
from nutpie.compile_stan import compile_stan_model, prune_stan_cache
from nutpie.sample import sample

ChainProgress = _lib.PyChainProgress
Expand All @@ -12,6 +12,7 @@
"ChainProgress",
"compile_pymc_model",
"compile_stan_model",
"prune_stan_cache",
"sample",
"zarr_store",
]
219 changes: 193 additions & 26 deletions python/nutpie/compile_stan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import datetime
import hashlib
import json
import shutil
import tempfile
from dataclasses import dataclass, replace
from importlib.util import find_spec
Expand Down Expand Up @@ -144,6 +148,105 @@ def coords(self):
return self._coords


def _stan_cache_key(
code: str,
extra_compile_args: Optional[list[str]],
extra_stanc_args: Optional[list[str]],
) -> str:
"""Return a SHA-256 hex digest identifying a unique compilation job."""
import bridgestan

fingerprint = json.dumps(
{
"code": code,
"extra_compile_args": sorted(extra_compile_args or []),
"extra_stanc_args": sorted(extra_stanc_args or []),
"bridgestan_version": bridgestan.__version__,
},
sort_keys=True,
)
return hashlib.sha256(fingerprint.encode()).hexdigest()


def _stan_cache_dir() -> Path:
"""Return (and create) the directory where compiled Stan models are cached."""
import platformdirs

cache_dir = Path(platformdirs.user_cache_dir("nutpie")) / "stan"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir


def prune_stan_cache(
max_entries: int = 16,
min_age: datetime.timedelta = datetime.timedelta(weeks=2),
) -> None:
"""Remove old entries from the Stan compilation cache.

Entries are only considered for removal if they are older than *min_age*.
Among those, the oldest ones are removed until at most *max_entries*
entries remain.

Parameters
----------
max_entries:
Maximum number of cache entries to keep. Defaults to 16.
min_age:
Entries younger than this are never removed, regardless of how many
entries exist. Defaults to 2 weeks.
"""
cache_dir = _stan_cache_dir()
now = datetime.datetime.now(tz=datetime.timezone.utc)

# Collect all valid (marker exists) entries with their mtime.
entries = []
for entry_dir in cache_dir.iterdir():
if not entry_dir.is_dir():
continue
marker = entry_dir / "ok"
if not marker.exists():
continue
mtime = datetime.datetime.fromtimestamp(
marker.stat().st_mtime, tz=datetime.timezone.utc
)
entries.append((mtime, entry_dir))

if len(entries) <= max_entries:
return

# Only entries older than min_age are candidates for eviction.
candidates = sorted(
[(mtime, d) for mtime, d in entries if (now - mtime) >= min_age]
)

n_to_remove = len(entries) - max_entries
for _, entry_dir in candidates[:n_to_remove]:
shutil.rmtree(entry_dir, ignore_errors=True)


def _compile_stan_model(
model_name: str,
code: str,
build_dir: Path,
make_args: list[str],
stanc_args: list[str],
) -> Path:
"""Write *code* into *build_dir*, compile it, and return the path to the shared library."""
import bridgestan

model_path = (
build_dir.joinpath("name")
.with_name(model_name) # This verifies that it is a valid filename
.with_suffix(".stan")
)
model_path.write_text(code)
so_path = bridgestan.compile_model(
model_path, make_args=make_args, stanc_args=stanc_args
)
bridgestan.compile.windows_dll_path_setup()
return so_path


def compile_stan_model(
*,
code: Optional[str] = None,
Expand All @@ -154,7 +257,46 @@ def compile_stan_model(
coords: Optional[dict[str, Any]] = None,
model_name: Optional[str] = None,
cleanup: bool = True,
cache: bool = False,
prune_cache: bool = True,
) -> CompiledStanModel:
"""Compile a Stan model and return a :class:`CompiledStanModel`.

Parameters
----------
code:
Stan model source code as a string.
filename:
Path to a ``.stan`` file. Mutually exclusive with *code*.
extra_compile_args:
Extra arguments forwarded to the C++ compiler via BridgeStan's
``make_args``.
extra_stanc_args:
Extra arguments forwarded to the Stan compiler (``stanc``).
dims:
Variable dimension names, e.g. ``{"alpha": ["county"]}``.
coords:
Coordinate labels for each dimension, e.g.
``{"county": ["Hennepin", "Ramsey", ...]}``.
model_name:
Base name used for the ``.stan`` file. Defaults to ``"model"``.
cleanup:
Remove the temporary build directory after compilation. Has no
effect when *cache* is ``True`` (the build directory is the cache
entry and is never deleted).
cache:
When ``True``, compile the model into a persistent directory under
the user cache directory (``~/.cache/nutpie/stan`` on Linux/macOS,
``%LOCALAPPDATA%\\nutpie\\stan`` on Windows) and reuse it on
subsequent calls with identical arguments and the same BridgeStan
version. A marker file ``ok`` is written only after a successful
build, so interrupted or failed compilations are never reused.
Defaults to ``False``.
prune_cache:
When ``True`` (the default), call :func:`prune_stan_cache` after
each new compilation to evict old cache entries. Has no effect
when *cache* is ``False``.
"""
if find_spec("bridgestan") is None:
raise ImportError(
"BridgeStan is not installed in the current environment. "
Expand All @@ -180,33 +322,58 @@ def compile_stan_model(
if model_name is None:
model_name = "model"

basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
try:
model_path = (
Path(basedir.name)
.joinpath("name")
.with_name(model_name) # This verifies that it is a valid filename
.with_suffix(".stan")
)
model_path.write_text(code)
make_args = ["STAN_THREADS=true"]
if extra_compile_args:
make_args.extend(extra_compile_args)
stanc_args = []
if extra_stanc_args:
stanc_args.extend(extra_stanc_args)
so_path = bridgestan.compile_model(
model_path, make_args=make_args, stanc_args=stanc_args
)
# Set necessary library loading paths
bridgestan.compile.windows_dll_path_setup()
library = _lib.StanLibrary(so_path)
finally:
make_args = ["STAN_THREADS=true"]
if extra_compile_args:
make_args.extend(extra_compile_args)
stanc_args = []
if extra_stanc_args:
stanc_args.extend(extra_stanc_args)

if cache:
digest = _stan_cache_key(code, extra_compile_args, extra_stanc_args)
entry_dir = _stan_cache_dir() / digest
marker = entry_dir / "ok"

so_path_file = entry_dir / "so_path.txt"

if marker.exists():
# Cache hit: touch the marker to record recent use, then load.
marker.touch()
so_path = Path(so_path_file.read_text())
if not so_path.exists():
raise FileNotFoundError(
f"Cached Stan library not found: {so_path}. "
"The cache entry may be corrupt; delete it and recompile."
)
bridgestan.compile.windows_dll_path_setup()
library = _lib.StanLibrary(str(so_path))
else:
# Cache miss: compile directly into the cache entry directory so
# that all relative loading paths inside the .so remain valid.
entry_dir.mkdir(parents=True, exist_ok=True)
so_path = _compile_stan_model(
model_name, code, entry_dir, make_args, stanc_args
)
# Write the .so path before the marker so the marker is only
# ever present once so_path.txt is fully written.
so_path_file.write_text(str(so_path))
marker.write_text("")
library = _lib.StanLibrary(str(so_path))
if prune_cache:
prune_stan_cache()
else:
basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
try:
if cleanup:
basedir.cleanup()
except Exception: # noqa: BLE001
pass
so_path = _compile_stan_model(
model_name, code, Path(basedir.name), make_args, stanc_args
)
library = _lib.StanLibrary(str(so_path))
finally:
try:
if cleanup:
basedir.cleanup()
except Exception: # noqa: BLE001
pass

return CompiledStanModel(
code=code,
Expand Down
Loading
Loading