Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ compiled_model = nutpie.compile_pymc_model(pymc_model)
trace_pymc = nutpie.sample(compiled_model)
```

`trace_pymc` now contains an ArviZ `InferenceData` object, including sampling
`trace_pymc` now contains a `DataTree` object, including sampling
statistics and the posterior of the variables defined above.

We can also control the sampler in a non-blocking way:
Expand All @@ -111,7 +111,7 @@ sampler.resume()
# Wait for the sampler to finish (up to timeout seconds)
sampler.wait(timeout=0.1)
# Note that not passing any timeout to `wait` will
# wait until the sampler finishes, then return the InferenceData object:
# wait until the sampler finishes, then return the DataTree object:
idata = sampler.wait()

# or we can also abort the sampler (and return the incomplete trace)
Expand Down
4 changes: 2 additions & 2 deletions docs/pymc-usage.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ information about how each chain is doing:
(100 to 1000) are a sign that the parameterization of the model is not ideal,
and the sampler is very inefficient.

After sampling, this returns an `arviz` InferenceData object that you can use to
After sampling, this returns a `DataTree` object that follows the InferenceData schema, which you can use to
analyze the trace.

For example, we should check the effective sample size:
Expand All @@ -110,7 +110,7 @@ az.ess(trace)
and take a look at a trace plot:

```{python}
az.plot_trace(trace);
az.plot_trace_dist(trace);
```

### Choosing the backend
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"arro3-core >= 0.6.0",
"pandas >= 2.0",
"xarray >= 2025.01.2",
"arviz >= 0.20.0,<1.0",
"arviz >= 1.0",
"obstore >= 0.8.0",
"zarr >= 3.1.0",
]
Expand Down
2 changes: 1 addition & 1 deletion python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def compile_pymc_model(
freeze_model : bool | None
Freeze all dimensions and shared variables to treat them as compile time
constants.

Returns
-------
compiled_model : CompiledPyMCModel
Expand Down
22 changes: 13 additions & 9 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
import pyarrow
from xarray import DataTree

from nutpie import _lib # type: ignore

Expand Down Expand Up @@ -96,11 +97,14 @@ def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs):
stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars
)

data = {
"posterior": data_posterior,
"sample_stats": stats_posterior,
"warmup_posterior": data_tune,
"warmup_sample_stats": stats_tune,
}
return arviz.from_dict(
data_posterior,
sample_stats=stats_posterior,
warmup_posterior=data_tune,
warmup_sample_stats=stats_tune,
data,
dims=dims,
**kwargs,
)
Expand Down Expand Up @@ -639,7 +643,7 @@ def sample(
progress_style: str | None = None,
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
) -> arviz.InferenceData: ...
) -> DataTree: ...


@overload
Expand All @@ -663,7 +667,7 @@ def sample(
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
**kwargs,
) -> arviz.InferenceData: ...
) -> DataTree: ...


@overload
Expand Down Expand Up @@ -710,7 +714,7 @@ def sample(
progress_rate: int = 100,
zarr_store: _ZarrStoreType | None = None,
**kwargs,
) -> arviz.InferenceData | _BackgroundSampler:
) -> DataTree | _BackgroundSampler:
"""Sample the posterior distribution for a compiled model.

Parameters
Expand Down Expand Up @@ -804,8 +808,8 @@ def sample(

Returns
-------
trace : arviz.InferenceData
An ArviZ ``InferenceData`` object that contains the samples.
trace : DataTree
A `DataTree` following the InferenceData schema that contains the samples.
"""

if low_rank_modified_mass_matrix and transform_adapt:
Expand Down
Loading