jax==0.8.2 breaks is_jax_array and subsequently device #368

@amacati

Description

JAX just released version 0.8.2 (https://github.com/jax-ml/jax/releases/tag/jax-v0.8.2), and with it is_jax_array no longer recognizes arrays inside jit-compiled functions, which in turn breaks device for them. It's probably caused by this change:

jax's Tracer no longer inherits from jax.Array at runtime. However, jax.Array now uses a custom metaclass such isinstance(x, Array) is true if an object x represents a traced Array. Only some Tracers represent Arrays, so it is not correct for Tracer to inherit from Array.
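
To illustrate why this change could matter, here is a minimal pure-Python sketch of the mechanism. The classes `ArrayMeta`, `Array`, and `Tracer` below are toy stand-ins, not JAX's actual implementation, and the `represents_array` flag is invented for illustration. The sketch assumes that a check based on the object's type (`issubclass(type(x), Array)`) is roughly what a type-based helper does. Whether array_api_compat's is_jax_array works exactly this way is an assumption.

```python
class ArrayMeta(type):
    """Toy metaclass: decides isinstance() per object, not by inheritance."""

    def __instancecheck__(cls, obj):
        # Real instances still count; so does any object flagged as array-like.
        return type.__instancecheck__(cls, obj) or getattr(
            obj, "represents_array", False
        )


class Array(metaclass=ArrayMeta):
    """Toy stand-in for jax.Array."""


class Tracer:
    """Toy stand-in for a tracer. It does NOT inherit from Array."""

    def __init__(self, represents_array):
        self.represents_array = represents_array


def instance_check(x):
    # Goes through ArrayMeta.__instancecheck__.
    return isinstance(x, Array)


def type_check(x):
    # Looks only at the class hierarchy, so the metaclass hook is bypassed.
    return issubclass(type(x), Array)


traced_array = Tracer(represents_array=True)
traced_other = Tracer(represents_array=False)

print(instance_check(traced_array))  # True
print(instance_check(traced_other))  # False
print(type_check(traced_array))      # False
```

If that assumption holds, any helper that inspects `type(x)` rather than calling `isinstance(x, jax.Array)` would stop recognizing tracers once `Tracer` no longer inherits from `Array`.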

Example

import array_api_compat as xpc
import jax
import jax.numpy as jnp


@jax.jit
def fn(x):
    print(xpc.is_jax_array(x))  # False with jax==0.8.2: x is a tracer inside jit
    return jnp.zeros(x.shape, device=xpc.device(x))


x = jnp.array([1.0, 2.0, 3.0])
print(xpc.is_jax_array(x))  # True
y = fn(x)

which yields

Traceback (most recent call last):
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/jax/_src/core.py", line 1071, in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'device'. Did you mean: 'to_device'?

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 28, in <module>
    y = fn(x)
  File "/home/mschuck/repos/proto/proto/rotation_grad.py", line 24, in fn
    return jnp.zeros(x.shape, device=xpc.device(x))
                                     ~~~~~~~~~~^^^
  File "/home/mschuck/repos/proto/.venv/lib/python3.14/site-packages/array_api_compat/common/_helpers.py", line 764, in device
    return x.device  # pyright: ignore
           ^^^^^^^^
AttributeError: DynamicJaxprTracer has no attribute device. Did you mean: 'devices'?
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
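
As a possible stopgap in user code, an instance-based check might still work under jit, since the release notes quoted above say isinstance(x, Array) is true for traced arrays. The helper name `is_jax_array_compat` is hypothetical and not part of array_api_compat. This is only a sketch of the idea, not a proposed fix for the library:

```python
import jax
import jax.numpy as jnp


def is_jax_array_compat(x):
    """Hypothetical helper: instance-based check instead of a type-based one."""
    return isinstance(x, jax.Array)


seen = []  # records what the check returns while fn is being traced


@jax.jit
def fn(x):
    seen.append(is_jax_array_compat(x))
    # Avoid querying the device of a tracer; zeros_like follows x instead.
    return jnp.zeros_like(x)


x = jnp.array([1.0, 2.0, 3.0])
y = fn(x)
print(is_jax_array_compat(x), seen)
```

Using `jnp.zeros_like(x)` sidesteps the `device` lookup from the example entirely, which may be acceptable when the goal is only to allocate an array alongside `x`.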

Relevant Issues

jax-ml/jax#26000
scipy/scipy#22246
