Splitting out from #321's comment.
`_DistributionTerm.expand` is `@defop`-annotated to return `jax.Array` rather than `Distribution`. So a method chain like `Normal(mu_term, tau_term).expand([J]).to_event(1)` falls through to `_ArrayTerm` after `.expand([J])`, and the subsequent `.to_event(1)` then raises:

```
AttributeError: '_ArrayTerm' object has no attribute 'to_event'
```

`_DistributionTerm.to_event` has the same `@defop` return-annotation issue, so even if `.expand` were fixed, the next `.to_event` call would land in the same hole.
Here is a self-contained reproduction script:
```python
import jax
import jax.random as jr
import numpyro
from effectful.handlers.numpyro import Normal
from effectful.ops.syntax import defop

mu = defop(jax.Array, name="mu")


def model():
    # mu() is a free-variable Term, so Normal(mu(), 1.0) is a _DistributionTerm.
    # Chaining .expand([3]).to_event(1) then hits the @defop return-type bug:
    # .expand returns _ArrayTerm, so .to_event lands on the wrong term type.
    numpyro.sample("theta", Normal(mu(), 1.0).expand([3]).to_event(1))


numpyro.infer.MCMC(
    numpyro.infer.NUTS(model),
    num_warmup=10,
    num_samples=10,
).run(jr.PRNGKey(0))
# AttributeError: '_ArrayTerm' object has no attribute 'to_event'
```
The intermediate `.expand([3])` result is an `_ArrayTerm` (because `_DistributionTerm.expand`'s `@defop` return type is `jax.Array`), so chaining `.to_event(1)` lands on the wrong term type. Trace excerpt:

```
File ".../effectful/handlers/numpyro.py", line ...
    def expand(self, ...) -> jax.Array:  # <-- should be Distribution
    ...
```
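To make the failure mode concrete, here is a toy model of annotation-driven term construction in plain Python. Everything in it is hypothetical: `Array`, `Distribution`, `Term`, `ArrayTerm`, the two `*DistributionTerm` classes and the `deferred` decorator are stand-ins invented for illustration, not effectful's actual internals. The sketch assumes, as described above, that the class of an unevaluated result is picked from the operation's return annotation. Under that assumption, annotating `expand` as returning an array yields a term with no `to_event`, while annotating it as returning a distribution keeps the chain alive.

```python
import functools


class Array:
    """Hypothetical stand-in for jax.Array."""


class Distribution:
    """Hypothetical stand-in for numpyro's Distribution base class."""


class Term:
    """Hypothetical unevaluated call: an operation name plus its arguments."""

    def __init__(self, op_name, *args):
        self.op_name, self.args = op_name, args


class ArrayTerm(Term):
    """Array-like term: deliberately has no .expand or .to_event."""


def deferred(method):
    """Hypothetical stand-in for @defop on a method: calling it builds a term
    whose class is chosen from the method's return annotation."""
    ret = method.__annotations__["return"]

    @functools.wraps(method)
    def build(self, *args):
        return TERM_CLASS[ret](method.__name__, self, *args)

    return build


class BuggyDistributionTerm(Term):
    @deferred
    def expand(self, shape) -> Array: ...  # mirrors the reported annotation

    @deferred
    def to_event(self, n) -> Array: ...  # same issue on .to_event


class FixedDistributionTerm(Term):
    @deferred
    def expand(self, shape) -> Distribution: ...

    @deferred
    def to_event(self, n) -> Distribution: ...


# Return annotation -> term class used for the unevaluated result.
TERM_CLASS = {Array: ArrayTerm, Distribution: FixedDistributionTerm}

buggy = BuggyDistributionTerm("Normal", "mu", 1.0)
try:
    buggy.expand([3]).to_event(1)
except AttributeError as e:
    print(e)  # 'ArrayTerm' object has no attribute 'to_event'

fixed = FixedDistributionTerm("Normal", "mu", 1.0)
print(type(fixed.expand([3]).to_event(1)).__name__)  # FixedDistributionTerm
```

In the toy, no method body is wrong; only the annotation that selects the result's class is, which matches the `# <-- should be Distribution` note in the trace excerpt.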
This blocks the textbook NumPyro vectorised hierarchical-model idiom for any model whose loc/scale arrive as effectful terms. The workaround is a per-element loop:

```python
for j in range(J):
    theta_j = numpyro.sample(f"theta_{j}", Normal(mu_term, tau_term))
```

This works, but it adds O(J) Python-level overhead (one sample site per element) and prevents JAX-level vectorisation. The `.expand([J]).to_event(1)` form would be the natural one to write.