Skip to content

Remove GPJax dependency from the test suite#37

Merged
jessegrabowski merged 2 commits into
mainfrom
remove-gpjax-test-dep
Jun 12, 2026
Merged

Remove GPJax dependency from the test suite#37
jessegrabowski merged 2 commits into
mainfrom
remove-gpjax-test-dep

Conversation

@bwengals

@bwengals bwengals commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Closes #36.

Replaces the GPJax reference comparisons with self-contained correctness checks, so the suite no longer breaks on GPJax API changes (the reason #35 had to pin gpjax==0.13.6).

What changed

  • Kernels (test_stationary.py): compared against closed-form analytic references in NumPy.
  • Gaussian / Bernoulli likelihoods: closed-form VE, and an independent NumPy Gauss-Hermite quadrature.
  • SVGP Bernoulli/Poisson ELBO (test_svgp.py): reuse the numpy+scipy reference that already backs the StudentT/NegBinom tests, factored into tests/_svgp_ref.py.
  • Kernel fix (stationary.py): Matern52/32 used pt.sqrt(5.0), which autocasts the literal to float32 regardless of floatX. Typed via pt.constant(..., dtype=floatX) so it follows the backend; Matern now matches the reference to ~3e-15. Other kernel literals are float32-exact, so no other changes.
  • Env: drop gpjax (plus now-unused jax/jaxlib/optax).

Test plan

  • pytest tests/ — 265 passed
  • pre-commit run --all-files clean
  • scripts/run_mypy.py — 37/37 pass

bwengals added 2 commits June 11, 2026 00:37
Replace the GPJax reference comparisons with self-contained correctness
checks so the suite no longer breaks on GPJax API changes (e.g. the 0.15
gram-operator change that required pinning gpjax==0.13.6).

- Kernels: compare ExpQuad/Matern52/32 against closed-form analytic
  references in numpy. Matern uses a looser tolerance to absorb the
  float32 sqrt(5)/sqrt(3) constant pytensor evaluates in floatX.
- Gaussian likelihood: closed-form variational expectation.
- Bernoulli likelihood: independent numpy Gauss-Hermite quadrature.
- SVGP Bernoulli/Poisson ELBO: reuse the numpy+scipy reference machinery
  (factored into tests/_svgp_ref.py and shared with the StudentT/NegBinom
  tests), as the scipy-ref module already anticipated.
- Drop gpjax from conda_envs/environment-test.yaml.

Closes #36.
… env

- Matern52/32 used pt.sqrt(5.0)/pt.sqrt(3.0), which autocasts the literal to
  float32 regardless of config.floatX, injecting ~1e-7 error into an otherwise
  float64 kernel. Use pt.constant(..., dtype=config.floatX) so the constant
  follows the backend's precision; Matern now matches the analytic reference to
  ~3e-15 in float64. Tighten the kernel test tolerance accordingly.
- ptgp, and the test suite after the GPJax removal, never import jax/optax, so
  drop jax/jaxlib/optax from conda_envs/environment-test.yaml.
@bwengals bwengals requested a review from jessegrabowski June 12, 2026 08:46
@jessegrabowski jessegrabowski added tests Anything related to the test suite dependencies Adding/removing package dependencies noreleasenote Don't include this PR in new release notes labels Jun 12, 2026
@jessegrabowski jessegrabowski merged commit 50e22df into main Jun 12, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Adding/removing package dependencies noreleasenote Don't include this PR in new release notes tests Anything related to the test suite

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Remove GPJax dependency from the test suite

2 participants