diff --git a/Cargo.lock b/Cargo.lock index f774bd81..039b6c1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2228,7 +2228,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.16.7" +version = "0.16.8" dependencies = [ "anyhow", "arrow", diff --git a/README.md b/README.md index ef5f7105..6a8ec207 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ compiled_model = nutpie.compile_pymc_model(pymc_model) trace_pymc = nutpie.sample(compiled_model) ``` -`trace_pymc` now contains an ArviZ `InferenceData` object, including sampling +`trace_pymc` now contains a `DataTree` object, including sampling statistics and the posterior of the variables defined above. We can also control the sampler in a non-blocking way: @@ -111,7 +111,7 @@ sampler.resume() # Wait for the sampler to finish (up to timeout seconds) sampler.wait(timeout=0.1) # Note that not passing any timeout to `wait` will -# wait until the sampler finishes, then return the InferenceData object: +# wait until the sampler finishes, then return the DataTree object: idata = sampler.wait() # or we can also abort the sampler (and return the incomplete trace) diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd index a045490e..28ed998e 100644 --- a/docs/pymc-usage.qmd +++ b/docs/pymc-usage.qmd @@ -97,7 +97,7 @@ information about how each chain is doing: (100 to 1000) are a sign that the parameterization of the model is not ideal, and the sampler is very inefficient. -After sampling, this returns an `arviz` InferenceData object that you can use to +After sampling, this returns a `DataTree` object that follows the InferenceData schema, which you can use to analyze the trace. For example, we should check the effective sample size: @@ -110,7 +110,7 @@ az.ess(trace) and take a look at a trace plot: ```{python} -az.plot_trace(trace); +az.plot_trace_dist(trace); ``` ### Choosing the backend diff --git a/pyproject.toml b/pyproject.toml index ad8a07c4..468e2821 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "arro3-core >= 0.6.0", "pandas >= 2.0", "xarray >= 2025.01.2", - "arviz >= 0.20.0,<1.0", + "arviz >= 1.0", "obstore >= 0.8.0", "zarr >= 3.1.0", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 336bf34a..484fd0af 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -511,7 +511,7 @@ def compile_pymc_model( freeze_model : bool | None Freeze all dimensions and shared variables to treat them as compile time constants. - + Returns ------- compiled_model : CompiledPyMCModel diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 908977b9..c3109b68 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pyarrow +from xarray import DataTree from nutpie import _lib # type: ignore @@ -96,11 +97,14 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars ) + data = { + "posterior": data_posterior, + "sample_stats": stats_posterior, + "warmup_posterior": data_tune, + "warmup_sample_stats": stats_tune, + } return arviz.from_dict( - data_posterior, - sample_stats=stats_posterior, - warmup_posterior=data_tune, - warmup_sample_stats=stats_tune, + data, dims=dims, **kwargs, ) @@ -639,7 +643,7 @@ def sample( progress_style: str | None = None, progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, -) -> arviz.InferenceData: ... +) -> DataTree: ... @overload @@ -663,7 +667,7 @@ def sample( progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData: ... +) -> DataTree: ... @overload @@ -710,7 +714,7 @@ def sample( progress_rate: int = 100, zarr_store: _ZarrStoreType | None = None, **kwargs, -) -> arviz.InferenceData | _BackgroundSampler: +) -> DataTree | _BackgroundSampler: """Sample the posterior distribution for a compiled model. Parameters @@ -804,8 +808,8 @@ def sample( Returns ------- - trace : arviz.InferenceData - An ArviZ ``InferenceData`` object that contains the samples. + trace : DataTree + A `DataTree` following the InferenceData schema that contains the samples. """ if low_rank_modified_mass_matrix and transform_adapt: