Status: experimental (pre-alpha). API and semantics may change without notice. The package depends on
`jax._src` internals; the JAX version pin may tighten or widen as JAX evolves. Install with `pip install jet-array --pre`, since `pip` does not pick up pre-release versions by default.
Array-form Taylor-mode automatic differentiation in JAX.
`jet_array` propagates a truncated Taylor polynomial of order K through an
arbitrary JAX computation in a single forward pass, returning all K Taylor
coefficients of the output as one leading-axis array. This is the higher-order
analogue of `jax.jvp`: where `jvp` takes one tangent and returns one
derivative, `jet` takes a series of K coefficients and returns the full
order-K Taylor expansion.
The package extends `jax.experimental.jet`
(Bettencourt, Johnson, Duvenaud 2019) in two ways:

- Array storage. Series are stored as a single `jnp.ndarray` along axis 0 instead of a Python tuple. This makes the order axis a first-class JAX dimension that you can `jit`, `vmap`, and `scan` over.
- `effective_order` parameter. A runtime hint that lets a single jitted program execute with a dynamic truncation depth, useful when each element of a `vmap`-batched call needs a different number of coefficients.
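Concretely, where the tuple convention passes the coefficients as separate leaves, `jet_array` expects them stacked along a leading axis. A minimal sketch of the storage convention (names are illustrative):

```python
import jax.numpy as jnp

x = jnp.asarray(0.5)
s1, s2, s3 = 1.0, 0.0, 0.0  # per-order coefficients

# Tuple form (jax.experimental.jet): series = ([s1, s2, s3],)
# Array form (jet_array): one array of shape (K, *x.shape)
series_in = jnp.stack([jnp.full(x.shape, c) for c in (s1, s2, s3)])
```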
Taylor-mode (`jet`) gives you the first K derivatives of a scalar function
in O(K) work. Computing the same K derivatives by repeatedly nesting
`jax.grad` costs O(K²), since each additional nesting re-traces the whole
program. So Taylor-mode is the right tool when:

- You need many derivatives at one point, for example `K=20` to evaluate a Bell polynomial, an ODE Taylor integrator, or a high-order moment.
- The program you are differentiating contains the same primitives many times: Taylor-mode amortizes one trace, the nested-grad approach traces `K` times.
- You need the full Taylor polynomial (with factorial-divided coefficients) rather than the unscaled derivatives, for example to evaluate `f(x + h)` as a series in `h`.
For first or second derivatives, ordinary `jax.grad` / `jax.jacrev` is
faster; Taylor-mode pays off only when K ≥ ~3.
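As a concrete comparison, here is a sketch (ours, not from the package) that computes the same K derivatives both ways; the nested-grad loop re-traces `f` at every level, while `jet` traces it once:

```python
import math

import jax
import jax.numpy as jnp
from jet_array import jet

def f(x):
    return jnp.exp(jnp.sin(x))

x0 = jnp.asarray(0.5)
K = 6

# Nested grad: the k-th derivative needs k nested traces of f.
d, derivs_nested = f, []
for _ in range(K):
    d = jax.grad(d)
    derivs_nested.append(d(x0))

# Taylor-mode: one forward pass yields all K coefficients.
series_in = jnp.zeros(K).at[0].set(1.0)
_, coeffs = jet(f, (x0,), (series_in,))
derivs_jet = [coeffs[k] * math.factorial(k + 1) for k in range(K)]
```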
```bash
pip install jet-array --pre
```

`jet_array` uses `jax._src` internals and is currently tested on
JAX 0.8.x, 0.9.x, and 0.10.x. The pin is `jax>=0.8,<0.11`.
```python
jet_array.jet(fun, primals, series, effective_order=None) -> (primal_out, series_out)
```

- `fun`: a JAX-traceable callable.
- `primals`: a tuple of input primal values, one per positional argument of `fun`. Each must be a leaf (scalar or array, not a pytree).
- `series`: a tuple of arrays, one per primal. `series[i][k-1]` is the k-th Taylor coefficient of the i-th input along the path you are expanding around. The trailing dimensions of `series[i]` must match `primals[i]`. The order `K` of the expansion is the leading axis length; it must be the same for all primals.
- `effective_order`: optional. See below.
- Returns: `(primal_out, series_out)`, where `primal_out = fun(*primals)` and `series_out[k-1]` is the k-th Taylor coefficient of the output, in the same convention as the input.
`series[k-1]` is the Taylor coefficient: if `f` is being expanded along
x(t) = x₀ + s₁·t + s₂·t² + …, then

f(x(t)) = f(x₀) + Σ_{k≥1} series_out[k-1] · t^k.

When you set `series_in = (1, 0, 0, …, 0)` and a single primal, this reduces
to expanding `f` directly around x₀, and the relationship to derivatives
is

series_out[k-1] = f^(k)(x₀) / k!.

To recover the unscaled k-th derivative, multiply by k!. This is the
convention used in *Evaluating Derivatives* (Griewank & Walther, §13).
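Since the coefficients are already factorial-divided, evaluating the expansion at a step `h` is a plain polynomial evaluation. A small sketch (the helper name is ours, not part of the API):

```python
import jax.numpy as jnp

def eval_taylor(primal, series, h):
    # Horner evaluation of primal + series[0]*h + series[1]*h**2 + ...
    acc = jnp.zeros_like(primal)
    for c in series[::-1]:
        acc = c + acc * h
    return primal + acc * h
```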
`jax.experimental.jet` exposes both conventions through its
`factorial_scaled` keyword argument (defaults to `True`, returning
derivative coefficients f^(k)(x₀)). Calling
`jax.experimental.jet.jet(..., factorial_scaled=False)` returns the same
Taylor coefficients as `jet_array`. The equivalence is tested
coefficient-by-coefficient in `tests/test_against_jax_jet.py`.
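A sketch of that cross-check, assuming the `factorial_scaled` keyword behaves as described above (note that `jax.experimental.jet` takes each input series as a Python list rather than a stacked array):

```python
import jax.numpy as jnp
from jax.experimental import jet as jax_jet
from jet_array import jet

def f(x):
    return jnp.exp(jnp.sin(x))

x0 = jnp.asarray(0.5)
K = 6

_, ours = jet(f, (x0,), (jnp.zeros(K).at[0].set(1.0),))
_, theirs = jax_jet.jet(f, (x0,), ([1.0] + [0.0] * (K - 1),),
                        factorial_scaled=False)
assert all(jnp.allclose(ours[k], theirs[k]) for k in range(K))
```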
```python
import math

import jax
import jax.numpy as jnp
from jet_array import jet

jax.config.update("jax_enable_x64", True)

def f(x):
    return jnp.exp(jnp.sin(x))

x0 = jnp.asarray(0.5)
K = 8  # expansion order
series_in = jnp.zeros(K).at[0].set(1.0)  # direction: x(t) = x0 + t

primal, series = jet(f, (x0,), (series_in,))

# series[k-1] = f^(k)(x0) / k!
# Multiply by k! to get the unscaled k-th derivative.
for k in range(1, K + 1):
    print(f"f^({k})(x0) = {float(series[k - 1]) * math.factorial(k):.6f}")
```

```python
def g(x, y):
    return jnp.exp(x * y)

x0 = jnp.asarray(0.5)
y0 = jnp.asarray(0.3)

# Expand along the path (x0 + t, y0): output series gives ∂^k g/∂x^k / k!
sx = jnp.zeros(4).at[0].set(1.0)
sy = jnp.zeros(4)  # y stays constant

primal, series = jet(g, (x0, y0), (sx, sy))
# series[0] = ∂g/∂x        = y₀ · exp(x₀ y₀)
# series[1] = ∂²g/∂x² / 2! = y₀² / 2 · exp(x₀ y₀)

# Diagonal direction (x0 + t, y0 + t): mixed derivatives appear.
sy_diag = jnp.zeros(4).at[0].set(1.0)
_, series_diag = jet(g, (x0, y0), (sx, sy_diag))
# series_diag[k-1] = (1/k!) · sum over multi-indices |α|=k of ∂^α g · α-coefficients
```

`jet` is fully traceable. The order K is part of the input shape, so a
single jit-compiled program handles any computation at that order:
```python
@jax.jit
def taylor(x0):
    series_in = jnp.zeros(8).at[0].set(1.0)
    return jet(f, (x0,), (series_in,))

primal, series = taylor(jnp.asarray(0.5))
```

If you call `taylor` with a different K, JAX retraces (because the input
shape changed). To avoid that, see `effective_order` below.
```python
xs = jnp.linspace(0.0, 1.0, 100)
series_in = jnp.zeros(5).at[0].set(1.0)

primals, series = jax.vmap(
    lambda x: jet(f, (x,), (series_in,))
)(xs)
# primals.shape == (100,), series.shape == (100, 5)
```

`effective_order` is an integer JAX scalar that tells `jet` how many
coefficients of the output you actually intend to use. The output array
still has shape `(K, ...)`, the same as without the parameter, but `jet`
skips the work of computing high-order entries inside primitives whose
Taylor rules support the hint (currently `exp`, `expm1`, `log`, `log1p`,
`pow`, `logistic`, `tanh`, `erf_inv`, `div`).

The value of entries beyond `effective_order` is unspecified; treat them
as garbage. The point is to avoid recompilation when the order varies
across calls or across a `vmap`-batched dimension, since the array shape
is fixed.
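Because the tail is unspecified, it is safest to mask it off before any reduction over the order axis. A minimal sketch (the helper name is ours):

```python
import jax.numpy as jnp

def mask_tail(series, k_dyn):
    # Zero the entries at or beyond k_dyn so that downstream sums,
    # norms, etc. cannot pick up the unspecified garbage values.
    idx = jnp.arange(series.shape[0])
    idx = idx.reshape((-1,) + (1,) * (series.ndim - 1))
    return jnp.where(idx < k_dyn, series, 0.0)
```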
The expansion order K is part of the input shape, so a function
`jet`'d at one K re-traces and re-compiles for any other K. If the
order you actually need varies across calls (different time steps in
an integrator, different inputs in a batch, different convergence
checks in an iterative method), you would otherwise face two bad
options:

- Pad every call to the maximum possible order. This wastes compute on coefficients you don't use.
- Re-trace per order. This is fatal when the order changes frequently, since every change pays the full compilation cost.

`effective_order` resolves both: compile once at K = the maximum possible order and pay per-call work proportional to that call's actual order.
```python
import jax
import jax.numpy as jnp
from jet_array import jet

K_MAX = 16  # static array size

@jax.jit
def taylor(x, k_dyn):
    """Taylor series of exp(sin(x)) at x, computed up to order k_dyn,
    with a fixed array size of K_MAX so a single XLA program handles
    every order."""
    series_in = jnp.zeros(K_MAX).at[0].set(1.0)
    return jet(lambda x: jnp.exp(jnp.sin(x)), (x,), (series_in,),
               effective_order=k_dyn)

# Both calls hit the same compiled program; the second does more work.
p1, s1 = taylor(jnp.asarray(0.5), jnp.array(4))   # uses s1[:4]
p2, s2 = taylor(jnp.asarray(0.5), jnp.array(12))  # uses s2[:12]
```

For batches where each element wants its own order, pass an array of
`effective_order` values and combine with `jax.vmap`:
```python
xs = jnp.linspace(0.0, 1.0, 100)
ks = jnp.array([3, 5, 8, 12] * 25)  # one per batch element

primals, series = jax.vmap(taylor)(xs, ks)
# series[i, :ks[i]] is the answer for element i.
```

- Slice the result to `series_out[:effective_order]`. Entries above the threshold are unspecified.
- `effective_order` works through `jax.jit` and `jnp.*` automatically. No need to rewrite your function in `lax.*` style.
- The savings come from short-circuiting the O(K²) convolutions in nonlinear rules. Linear ops (`add`, `broadcast`, `reshape`, …) are already O(K) and unaffected.
- The win is largest when the dominant cost is in the convolution-style loops that the hint short-circuits. For shallow programs the overhead of the conditional may exceed the savings.
`jax.experimental.jet` stores the series as a Python tuple. That works for
small K known at trace time but causes friction otherwise:

- The order axis is invisible to JAX: you cannot `vmap` or `scan` over it.
- A tuple of length `K` produces `K` separate jaxpr equations per primitive, so trace time grows linearly in `K` even when the computation is constant per coefficient.
- `effective_order` makes no sense for a Python tuple.

Storing series along axis 0 of a single array makes all of these go away.
The cost is some boilerplate inside primitive rules to convolve along that
axis, which `jet_array` provides.
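For example, because the order axis is an ordinary array axis, partial sums of the Taylor polynomial can be accumulated with `jax.lax.scan`. A sketch (ours), reusing the quickstart's `f`:

```python
import jax
import jax.numpy as jnp
from jet_array import jet

def f(x):
    return jnp.exp(jnp.sin(x))

K = 10
series_in = jnp.zeros(K).at[0].set(1.0)
_, series = jet(f, (jnp.asarray(0.5),), (series_in,))

# Scan over the order axis: partial_sums[k] is the order-(k+1) partial
# sum of the expansion at h = 0.1, excluding the primal term.
terms = series * 0.1 ** jnp.arange(1, K + 1)

def step(acc, term):
    acc = acc + term
    return acc, acc

_, partial_sums = jax.lax.scan(step, jnp.asarray(0.0), terms)
```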
Custom Taylor rules are provided for the JAX primitives that arise in
typical scientific code: arithmetic and broadcasting, `exp`, `expm1`,
`log`, `log1p`, `sin`, `cos`, `sinh`, `cosh`, `tanh`, `logistic`, `erf`,
`erf_inv`, `pow`, `square`, `sqrt`, `div`, `dynamic_slice`,
`dynamic_update_slice`, and `cumsum`/`cumprod`. Linear and zero-derivative
primitives are handled generically. Primitives without a dedicated rule
fall back to the standard convolution propagator.
The full correctness suite checks every supported primitive against
`jax.experimental.jet` at multiple expansion points and orders up to 20.
- Pytree primal inputs are not supported (each primal must be a leaf).
- The package depends on `jax._src` internals; currently tested on JAX 0.8.x, 0.9.x, and 0.10.x.
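Until pytree primals are supported, one workaround is to flatten the tree yourself and call `jet` on the leaves. A hypothetical sketch (the wrapper is ours, not part of the package; assumes `fun` returns a single array):

```python
import jax
from jet_array import jet

def jet_pytree(fun, primal_tree, series_tree):
    # Hypothetical wrapper: flatten the pytree into leaves, run jet at
    # the leaf level, and rebuild the structure inside the callee.
    leaves, treedef = jax.tree_util.tree_flatten(primal_tree)
    series_leaves = treedef.flatten_up_to(series_tree)

    def leaf_fun(*leaf_args):
        return fun(jax.tree_util.tree_unflatten(treedef, list(leaf_args)))

    return jet(leaf_fun, tuple(leaves), tuple(series_leaves))
```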
If you use `jet-array` in academic work, please cite the repository:

```bibtex
@software{jet_array,
  title = {jet-array: array-form Taylor-mode automatic differentiation in JAX},
  author = {Yang, Cambridge},
  year = {2026},
  url = {https://github.com/thisiscam/jet-array},
}
```

The underlying Taylor-mode algorithm is from:

```bibtex
@inproceedings{bettencourt2019taylor,
  title = {Taylor-mode automatic differentiation for higher-order derivatives in JAX},
  author = {Bettencourt, Jesse and Johnson, Matthew J. and Duvenaud, David},
  booktitle = {NeurIPS Program Transformations Workshop},
  year = {2019},
}
```

Apache-2.0. Portions derived from JAX (Apache-2.0, Copyright 2020 The JAX Authors); copyright headers preserved per Apache-2.0 §4.