Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions pytensor/tensor/random/rewriting/numba.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.traversal import applys_between
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.basic import cast
from pytensor.tensor.random.basic import (
InvGammaRV,
LogNormalRV,
MvNormalRV,
NormalRV,
)
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.numba import simplify_core_shape_graphs
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.variable import TensorConstant


_RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64 = frozenset(
(NormalRV, LogNormalRV, InvGammaRV, MvNormalRV)
)


@node_rewriter([RandomVariable])
Expand Down Expand Up @@ -101,3 +115,73 @@ def introduce_explicit_core_shape_rv(fgraph, node):
"numba",
position=100,
)


@node_rewriter([RandomVariable])
def cast_rv_float_params_to_float64(fgraph, node):
"""Cast non-float64 floating-point RandomVariable parameters to float64.

For a few distributions, numba's sampler runs markedly faster -- or only compiles
Comment thread
ricardoV94 marked this conversation as resolved.
at all -- when its float parameters are float64 rather than float32 / integer. The
declared ``dtype`` (and ``floatX``) is only a final cast on the output, never the
sampling precision, so upcasting the parameters once here is a performance fix, not
a semantic one: the draw is still cast to the declared dtype on store, exactly as
``RandomVariable.perform`` does (it hands parameters straight to NumPy and narrows
only the output via ``np.asarray(..., dtype=self.dtype)``).

Which RVs actually benefit is not predictable from "the generator samples in
float64" -- it depends on the specific numba implementation and was measured per
distribution; see ``_RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64`` for the opt-in list. For
``normal``, the ``int8`` literals of ``normal(0, 1)`` becoming float64 is a ~3x win.

The target is float64, *not* the output dtype: under ``floatX="float32"`` casting to
the output dtype would leave the parameter float32 and still mismatched, so it would
not help. This is why the rewrite is Numba-specific -- JAX can sample natively in
float32, where forcing float64 parameters would pessimize.
"""
op = node.op
if type(op) not in _RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64:
# Opt-in narrowing. The rewrite is registered on the ``RandomVariable`` base
# class (one cheap isinstance per graph node) and the exact-type set membership
# here only runs for the few actual RV nodes, bailing on all but the listed ones.
return None

if config.warn_float64 != "ignore":
# The user asked to be warned about / forbidden float64; don't introduce it.
return None

dist_params = op.dist_params(node)
# Every parameter of these (continuous) RVs is real-valued, so upcast any that is not
# already float64 -- ``normal(0, 1)``'s int8 literals included.
cast_idxs = [
i for i, param in enumerate(dist_params) if param.type.dtype != "float64"
]
if not cast_idxs:
return None

new_params = list(dist_params)
for i in cast_idxs:
param = dist_params[i]
if isinstance(param, TensorConstant):
# Fold the cast into the constant so the common ``normal(0, 1)`` case ends
# up with float64 literals rather than a leftover Cast node.
new_params[i] = constant(param.data.astype("float64"))
else:
new_params[i] = cast(param, "float64")

new_outputs = op.make_node(
op.rng_param(node), op.size_param(node), *new_params
).outputs
copy_stack_trace(node.outputs, new_outputs)
return new_outputs


optdb.register(
cast_rv_float_params_to_float64.__name__,
dfs_rewriter(cast_rv_float_params_to_float64),
"numba",
# After stabilize (1.5), before specialize (2): a one-shot sweep that inserts the
# casts, leaving specialize's equilibrium (local_cast_cast + constant folding) to
# clean them up, and running before the core-shape wrap at position 100.
position=1.9,
)
20 changes: 20 additions & 0 deletions tests/benchmarks/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import pytest

from pytensor import config, function, shared


@pytest.mark.parametrize("floatX", ("float64", "float32"))
def test_normal_rv_benchmark_numba(floatX, benchmark):
# Drawing standard normals through numba. The ``0, 1`` literals are typed int8, and
# numba's rng.normal samples in float64; without cast_rv_float_params_to_float64
# upcasting the parameters once, each of the >100k draws pays a per-element cast and
# the function runs ~3x slower. This benchmark tracks that hot path.
with config.change_flags(floatX=floatX):
rng = shared(np.random.default_rng(0))
next_rng, draws = rng.normal(0, 1, size=(2160, 50))
fn = function(
[], draws, updates={rng: next_rng}, mode="NUMBA", trust_input=True
)
fn() # compile / warm up before timing
benchmark(fn)
60 changes: 60 additions & 0 deletions tests/link/numba/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import scipy.stats as stats

import pytensor
import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import shared
Expand Down Expand Up @@ -771,3 +772,62 @@ def rng_fn(self, rng, value, size=None):
assert large_sample.shape == (1000,)
np.testing.assert_allclose(large_sample.mean(), np.pi, rtol=1e-2)
np.testing.assert_allclose(large_sample.std(), 1, rtol=1e-2)


def _compiled_rv_dist_param_dtypes(out):
"""Dtypes of the distribution parameters of the single RV in the compiled graph.

The numba backend wraps RVs in a ``RandomVariableWithCoreShape`` whose inputs are
``(core_shape, rng, size, *dist_params)``, so the parameters are ``inputs[3:]``.
"""
fn = function([], out, mode=numba_mode)
[node] = [
n
for n in fn.maker.fgraph.toposort()
if isinstance(n.op, RandomVariableWithCoreShape)
]
return [inp.type.dtype for inp in node.inputs[3:]]


@pytest.mark.parametrize("floatX", ["float64", "float32"])
def test_rv_float_params_cast_to_float64(floatX):
# cast_rv_float_params_to_float64 upcasts the float params of opt-in RVs to float64
# (always float64, not the output dtype -- float32 would stay mismatched) and leaves
# every other RV untouched.
with pytensor.config.change_flags(floatX=floatX):
rng = shared(np.random.default_rng(123))

# Opt-in: the int8 literals 0, 1 become float64 regardless of floatX.
assert _compiled_rv_dist_param_dtypes(
pt.random.normal(0, 1, size=(5,), rng=rng)
) == ["float64", "float64"]

# Opt-in, and required: numba's np.dot rejects a float32 covariance, so MvNormal
# only compiles once mean/cov are float64.
assert _compiled_rv_dist_param_dtypes(
pt.random.multivariate_normal(
np.zeros(3, dtype="float32"), np.eye(3, dtype="float32"), rng=rng
)
) == ["float64", "float64"]

# Not opt-in (no measurable numba speedup): the float32 probability is left as-is,
# as is the integer count ``n``.
assert _compiled_rv_dist_param_dtypes(
pt.random.binomial(n=np.int64(10), p=np.float32(0.3), size=(5,), rng=rng)
) == ["int64", "float32"]

# Not opt-in and data-following: upcasting Permutation's array would change the
# output dtype and pointlessly widen the sampled data.
assert _compiled_rv_dist_param_dtypes(
pt.random.permutation(np.arange(5, dtype="float32"), rng=rng)
) == ["float32"]


def test_rv_float_params_cast_respects_warn_float64():
# When the user asked to be warned/raised on float64, the rewrite must not silently
# introduce it; the int8 params are left as-is.
with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
rng = shared(np.random.default_rng(123))
assert _compiled_rv_dist_param_dtypes(
pt.random.normal(0, 1, size=(5,), rng=rng)
) == ["int8", "int8"]
15 changes: 9 additions & 6 deletions tests/tensor/random/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,9 +1016,12 @@ def test_sidestep_unused_rng_consumer_with_duplicate_node():

def test_unused_rng():
rng = random_generator_type("rng")
next_rng, x = rng.normal([0], [1], size=3)
next_rng, _y = next_rng.normal(x.ones_like(), [1])
final_rng, z = next_rng.normal(1, 2)
# float64 parameters throughout, so the NUMBA float64-cast rewrite is a no-op and
# the strict graph comparison below need not account for inserted casts. Scalar
# Python floats autocast to float32, so use np.float64 for the scalar loc/scale.
next_rng, x = rng.normal([0.0], [1.0], size=3)
next_rng, _y = next_rng.normal(x.ones_like(), [1.0])
final_rng, z = next_rng.normal(np.float64(1.0), np.float64(2.0))

fn = function([rng], [final_rng, z], mode=get_default_mode().excluding("inplace"))

Expand Down Expand Up @@ -1050,9 +1053,9 @@ def test_unused_rng():
if config.mode != "FAST_COMPILE":
# Strict graph comparison (ones_like gets constant-folded outside FAST_COMPILE)
rng.tag.used = False # Avoid reuse warnings
next_rng, _x = rng.normal([0], [1], size=3)
next_rng, _y = next_rng.normal([1.0, 1.0, 1.0], [1])
final_rng, z = next_rng.normal(1, 2)
next_rng, _x = rng.normal([0.0], [1.0], size=3)
next_rng, _y = next_rng.normal([1.0, 1.0, 1.0], [1.0])
final_rng, z = next_rng.normal(np.float64(1.0), np.float64(2.0))

expected = [final_rng, z]

Expand Down
Loading