Add ShaplEIG: Bayesian experimental design approximator#548
Draft
david-rundel wants to merge 5 commits into
Draft
Add ShaplEIG: Bayesian experimental design approximator#548david-rundel wants to merge 5 commits into
david-rundel wants to merge 5 commits into
Conversation
…alues ShaplEIG fits a Gaussian process surrogate with a weighted Hamming product kernel on queried coalition values and sequentially selects coalitions by maximizing the closed-form expected information gain about the Shapley values. The Shapley structure is handled via elementary-symmetric-polynomial identities of the kernel's generating polynomials, so the 2^n coalition space is never enumerated. - new optional extra 'shapleig' (torch/gpytorch/botorch/linear_operator) with lazy import and an informative placeholder when not installed - shapiq/approximator/shapleig/: approximator + flat Hamming GP surrogate + self-contained ESP math core (controlled copy of the validated reference implementation, kept in sync) - order-1 SV InteractionValues output (min_order=0, baseline = empty coalition value); SV-variance return left as a commented scaffold until InteractionValues supports uncertainty information - tqdm progress bar over the BED iterations - registered in SV_APPROXIMATORS and the public API; unit tests covering output format, determinism, and convergence to ExactComputer ground truth
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
- approximate() now delegates to the public approximate_with_variance(), which returns the SV estimates together with their marginal posterior variances (diagonal of the SV posterior covariance, original game scale); verified bit-identical against the reference implementation - shapleig extra: exact pins relaxed to lower bounds (>=); the bounds are the versions the reference implementation was validated against - float64 is enforced (set/restore) only around GP model construction in _surrogate.py: parameters, priors, and constraint initial values are created in torch's default dtype before botorch casts to the data dtype, and float32 creation changes the fitted hyperparameters and selections (verified empirically); everything downstream is float64 via data/params - infeasibility guards softened to warnings: exhaustive candidates with n > 16 warn about cost; a candidate set smaller than the required iterations warns and caps the iterations (estimation_budget reports the budget actually spent) - initial design and candidate set now reuse the base class sampler (uniform size weights + pairing trick), reset to its seeded initial state before each draw, instead of constructing fresh samplers - docstrings: refit cost note, max_candidates-vs-budget guidance
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
ShaplEIG fits a Gaussian process surrogate with a weighted Hamming product kernel on queried coalition values and sequentially selects coalitions by maximizing the closed-form expected information gain about the Shapley values. The Shapley structure is handled via elementary-symmetric-polynomial identities of the kernel's generating polynomials, so the 2^n coalition space is never enumerated.
Motivation and Context
This PR adds ShaplEIG, an approximator for Shapley values based on Bayesian experimental design (BED), accompanying the paper "ShaplEIG: Bayesian Experimental Design for Shapley Value Estimation" (Rundel et al., 2026).
The scope was pre-agreed with @mmschlk and @Advueu963: a black-box, end-to-end approximator in
shapiq/approximator/shapleig/behind a new optional dependency groupshapleig.Open questions for the maintainers:
shapleigextra is pinned to the exact versions used for the published experiments(
torch==2.9.1,gpytorch==1.14,botorch==0.14.0,linear_operator==0.6). Is this ok?scaffold to expose them via an
approximate_with_varianceonceInteractionValuessupportsuncertainty information (future work, as discussed).
Public API Changes
No Public API changes
Yes, Public API changes (Details below)
New public class
shapiq.ShaplEIG(=shapiq.approximator.shapleig.ShaplEIG), registered inSV_APPROXIMATORS.New optional dependency group
shapleiginpyproject.toml.No changes to any existing API.
import shapiqworks without the extra installed (the optionaldependencies are imported in
ShaplEIG.__init__, which raises anImportErrorpointing topip install 'shapiq[shapleig]'when they are missing).How Has This Been Tested?
tests/shapiq/tests_unit/tests_approximators/test_approximator_shapleig.py:output format (
InteractionValues, order 1,min_order=0, baseline = empty-coalition value),determinism under
random_state, convergence toExactComputerground truth with growing budget,budget/candidate-set validation, sampled-candidate-subset and warmstart configurations.
same game/seed/protocol (default and warmstart configurations).
[spex]testscaused by the
sparseextra not being installed locally, identical onmain).Checklist
ShaplEIGup via theshapiq.approximatorautosummary?)CHANGELOG.md(if relevant for users).