[plotting] Allow user to specify type of posterior data visualisation#680
[plotting] Allow user to specify type of posterior data visualisation#680AllenDowney wants to merge 4 commits intomainfrom
Conversation
- Add kind parameter: 'ribbon', 'histogram', 'spaghetti' - Add ci_kind parameter: 'hdi' or 'eti' (default 'hdi') - Add ci_prob parameter (default 0.94, matching current behavior) - Add num_samples parameter for spaghetti plots - Implement ribbon plots with HDI and ETI support - Implement histogram visualization (basic) - Implement spaghetti plot visualization - Maintain backward compatibility with hdi_prob parameter Addresses #671
- Add histogram visualization (2D heatmap) with global y-bins - Add spaghetti plot visualization with configurable num_samples - Add ETI (Equal-Tailed Interval) support in addition to HDI - Add comprehensive test suite for all visualization types - Maintain backward compatibility with hdi_prob parameter - Update BaseExperiment.plot() to pass through new parameters Addresses #671
The badge is auto-generated by pre-commit and should not be committed. Addresses #671
|
Thanks @AllenDowney ! Hoping to look at this very soon. Just flagging up that there might be some conflicts to resolve if #643 gets merged first. That one allows the user to decide if they are looking at things related to the posterior expectation or the posterior predictive. So these PR's are nicely complementary, but it's possible there could be some code overlap/conflicts. |
There was a problem hiding this comment.
Thanks for this — the API direction looks promising. I reviewed this mainly from a maintainability/reviewability perspective and have a few blocking + process suggestions.
Blocking issues
- CI currently fails early due to import error in
causalpy/plot_utils.py:ImportError: cannot import name 'eti' from 'arviz_stats'
- Because this import fails at module load, reviewers cannot run plotting tests or examples yet.
Requested updates before final review
- Fix ETI import/implementation path so test/docs/sdist are green.
- Add visual review artifacts to the PR description/comments, since GitHub diff does not show rendered plots:
kind="ribbon"withci_kind="hdi"kind="ribbon"withci_kind="eti"kind="histogram"kind="spaghetti"
- Include one minimal reproducible script/snippet (not notebook-only) to generate mock posterior data and produce all plot kinds. This allows reviewers to validate quickly without rerunning full notebooks.
API/compatibility checks to confirm
plot_xY()now returns either(Line2D, PolyCollection)or(list[Line2D], None)depending onkind.- Please confirm experiment-level
.plot()legend handling remains consistent for non-ribbonkinds, since several call sites still assume tuple-style handles. - Please clarify whether
hdi_probdeprecation timeline is planned, or if it remains indefinitely as compatibility alias.
Efficient review request
Could you add 3–5 static PNGs (before/after where useful) generated from the same mock posterior dataset and style settings? That will make visual review possible without manual notebook execution.
Suggested mock posterior script for reviewers
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from causalpy.plot_utils import plot_xY
# Reproducible synthetic posterior: (chain, draw, obs_ind)
rng = np.random.default_rng(42)
n_chains, n_draws, n_t = 2, 200, 40
x = pd.date_range("2022-01-01", periods=n_t, freq="D")
trend = 10 + 0.05 * np.arange(n_t) + 0.01 * np.arange(n_t) ** 2
samples = np.empty((n_chains, n_draws, n_t))
for c in range(n_chains):
for d in range(n_draws):
draw_mean = trend + rng.normal(0, 0.4, n_t)
samples[c, d, :] = draw_mean + rng.normal(0, 0.8, n_t)
Y = xr.DataArray(
samples,
dims=["chain", "draw", "obs_ind"],
coords={"chain": np.arange(n_chains), "draw": np.arange(n_draws), "obs_ind": x},
)
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
plot_xY(x, Y, ax=axes[0, 0], kind="ribbon", ci_kind="hdi", ci_prob=0.94, label="HDI")
axes[0, 0].set_title("Ribbon (HDI)")
plot_xY(x, Y, ax=axes[0, 1], kind="ribbon", ci_kind="eti", ci_prob=0.94, label="ETI")
axes[0, 1].set_title("Ribbon (ETI)")
plot_xY(x, Y, ax=axes[1, 0], kind="histogram", label="Histogram")
axes[1, 0].set_title("Histogram")
plot_xY(x, Y, ax=axes[1, 1], kind="spaghetti", num_samples=60, label="Spaghetti")
axes[1, 1].set_title("Spaghetti")
for ax in axes.ravel():
ax.legend(loc="best")
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()Conflict check with #643
I checked this directly with a branch-to-branch merge simulation.
There are real merge conflicts between #643 and #680 in these files:
causalpy/experiments/base.pycausalpy/plot_utils.pycausalpy/tests/test_plot_utils.py
So yes, conflict risk is concrete, not just theoretical.
Best sequencing: merge #643 first, then rebase/update #680 on top and re-run tests.
|
Including images is optional - especially if there's a temp script to generate plots for dev/review purposes. If the conflicts become gnarly when #643 is merged, then I'm happy to give that a stab - I feel guilty about that kind of thing :) |
|
Thanks for the nudge — adding a minimal reproducible script here to speed visual/API review without needing notebook execution. This script creates synthetic posterior draws and renders all current plot kinds:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from causalpy.plot_utils import plot_xY
# Reproducible synthetic posterior: (chain, draw, obs_ind)
rng = np.random.default_rng(42)
n_chains, n_draws, n_t = 2, 200, 40
x = pd.date_range("2022-01-01", periods=n_t, freq="D")
trend = 10 + 0.05 * np.arange(n_t) + 0.01 * np.arange(n_t) ** 2
samples = np.empty((n_chains, n_draws, n_t))
for c in range(n_chains):
for d in range(n_draws):
draw_mean = trend + rng.normal(0, 0.4, n_t)
samples[c, d, :] = draw_mean + rng.normal(0, 0.8, n_t)
Y = xr.DataArray(
samples,
dims=["chain", "draw", "obs_ind"],
coords={"chain": np.arange(n_chains), "draw": np.arange(n_draws), "obs_ind": x},
)
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
plot_xY(x, Y, ax=axes[0, 0], kind="ribbon", ci_kind="hdi", ci_prob=0.94, label="HDI")
axes[0, 0].set_title("Ribbon (HDI)")
plot_xY(x, Y, ax=axes[0, 1], kind="ribbon", ci_kind="eti", ci_prob=0.94, label="ETI")
axes[0, 1].set_title("Ribbon (ETI)")
plot_xY(x, Y, ax=axes[1, 0], kind="histogram", label="Histogram")
axes[1, 0].set_title("Histogram")
plot_xY(x, Y, ax=axes[1, 1], kind="spaghetti", num_samples=60, label="Spaghetti")
axes[1, 1].set_title("Spaghetti")
for ax in axes.ravel():
ax.legend(loc="best")
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()If useful, I can follow up with static PNGs from this same dataset/style setup. |
Sync PR #680 with the current main branch so the author is not blocked by stale conflicts or outdated workflow changes. Replace the ETI helper import with a local quantile-based interval calculation so plotting, docs, and package builds keep working against current dependencies. Made-with: Cursor
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #680 +/- ##
==========================================
+ Coverage 93.44% 93.55% +0.10%
==========================================
Files 74 74
Lines 11199 11414 +215
Branches 657 676 +19
==========================================
+ Hits 10465 10678 +213
- Misses 544 545 +1
- Partials 190 191 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|

Summary
This draft PR extends the plotting capabilities of CausalPy to support multiple visualization types for posterior data. The purpose of this draft PR is to discuss the API design and approach, not to finalize implementation or testing.
Currently, CausalPy only supports CI ribbon visualizations using Highest Density Intervals (HDI). This PR adds:
The API design aligns with ArviZ's naming conventions (
ci_prob,ci_kind) while maintaining backward compatibility with existing code.Fixes #671
Changes
Extended
plot_xY()function (causalpy/plot_utils.py):kindparameter:"ribbon","histogram", or"spaghetti"(default:"ribbon")ci_kindparameter:"hdi"or"eti"(default:"hdi"to match current behavior)ci_probparameter (default:0.94to match current behavior)num_samplesparameter for spaghetti plots (default:50)hdi_probparameterUpdated
BaseExperiment.plot()method (causalpy/experiments/base.py):_bayesian_plot()and_ols_plot()methodsTesting
causalpy_test.py) demonstrating all visualization typesAPI Design Rationale
kindparameter: Uses familiar naming convention from seaborn/pandasci_probandci_kindnaming: Aligns with ArviZ's naming conventions for ecosystem consistencyhdi_probcontinues to work unchangedOpen Questions for Reviewers
_bayesian_plot()methods) to explicitly accept and pass through these parameters, or is passing via**kwargssufficient?Checklist
📚 Documentation preview 📚: https://causalpy--680.org.readthedocs.build/en/680/