
_DistributionTerm.expand and .to_event return _ArrayTerm instead of Distribution, breaking standard NumPyro vectorization idiom #666

@datvo06


Splitting out from #321's comment.
_DistributionTerm.expand is @defop-annotated to return jax.Array rather than Distribution. As a result, in a method chain like Normal(mu_term, tau_term).expand([J]).to_event(1), the result of .expand([J]) is an _ArrayTerm, and the subsequent .to_event(1) raises:

AttributeError: '_ArrayTerm' object has no attribute 'to_event'

_DistributionTerm.to_event has the same @defop return-annotation issue, so even if .expand were fixed, the next .to_event call would land in the same hole.
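To make the failure mode concrete, here is a toy model of the mechanism described above. It is not effectful's implementation; it captures only the one property this issue relies on: calling a @defop-style operation builds a lazy term whose class is chosen from the operation's declared return annotation. Every name in it (defop_sketch, ArrayTerm, DistributionTerm, Array, Distribution, expand_buggy, expand_fixed, to_event_op) is a hypothetical stand-in. With the annotation as reported (-> Array), .expand yields a term with no .to_event. With the proposed annotation (-> Distribution), the chain continues. The issue notes that the real _DistributionTerm.to_event needs the same annotation change; here to_event_op is already annotated that way.

```python
from typing import get_type_hints


class Array:
    """Hypothetical stand-in for jax.Array."""


class Distribution:
    """Hypothetical stand-in for numpyro's Distribution."""


class ArrayTerm:
    """Lazy array-valued term: it has no distribution methods."""

    def __init__(self, op_name, args):
        self.op_name, self.args = op_name, args


class DistributionTerm:
    """Lazy distribution-valued term: .to_event builds another term."""

    def __init__(self, op_name, args):
        self.op_name, self.args = op_name, args

    def to_event(self, n):
        return to_event_op(self, n)


# The only rule this toy models: the declared return type picks the term class.
TERM_FOR_TYPE = {Array: ArrayTerm, Distribution: DistributionTerm}


def defop_sketch(fn):
    """Hypothetical defop stand-in: calls build a lazy term and never run fn."""
    term_cls = TERM_FOR_TYPE[get_type_hints(fn)["return"]]

    def op(*args):
        return term_cls(fn.__name__, args)

    return op


@defop_sketch
def normal(loc, scale) -> Distribution: ...


@defop_sketch
def to_event_op(d, n) -> Distribution: ...


@defop_sketch
def expand_buggy(d, batch_shape) -> Array: ...  # annotation as reported in the issue


@defop_sketch
def expand_fixed(d, batch_shape) -> Distribution: ...  # proposed annotation


d = normal("mu", 1.0)        # DistributionTerm
bad = expand_buggy(d, [3])   # ArrayTerm: chaining .to_event(1) would raise AttributeError
good = expand_fixed(d, [3])  # DistributionTerm
chained = good.to_event(1)   # DistributionTerm, so the chain can continue
```

The point of the toy is that nothing about the runtime value matters: the term class is fixed by the annotation alone, which is why correcting the two return annotations should be enough to restore the chain.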

Here is a self-contained reproduction script:

import jax, jax.numpy as jnp, jax.random as jr
import numpyro
from effectful.handlers.numpyro import Normal
from effectful.ops.syntax import defop

mu = defop(jax.Array, name="mu")

def model():
    # mu() is a free-variable Term, so Normal(mu(), 1.0) is a _DistributionTerm.
    # Chaining .expand([3]).to_event(1) then hits the @defop return-type bug:
    # .expand returns _ArrayTerm, so .to_event lands on the wrong term type.
    numpyro.sample("theta", Normal(mu(), 1.0).expand([3]).to_event(1))

mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(jr.PRNGKey(0))
# AttributeError: '_ArrayTerm' object has no attribute 'to_event'

The intermediate .expand([3]) result is an _ArrayTerm (because _DistributionTerm.expand's @defop return type is jax.Array), so chaining .to_event(1) lands on the wrong term type. Trace excerpt:

File ".../effectful/handlers/numpyro.py", line ...
    def expand(self, ...) -> jax.Array:   # <-- should be Distribution
        ...

This blocks the textbook NumPyro vectorised hierarchical-model idiom for any model whose loc/scale arrive as effectful terms. The workaround is the per-element loop:

for j in range(J):
    theta_j = numpyro.sample(f"theta_{j}", Normal(mu_term, tau_term))

This works, but it adds O(J) Python-level overhead and prevents JAX-level vectorisation. The .expand([J]).to_event(1) form is the natural one to write.
