From 7a6610c158b5a375586b7eb07b3e2e5ac8330c0c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 14 Jun 2026 00:41:01 +0200 Subject: [PATCH] Numba: cast float RV parameters to float64 to speed up sampling --- pytensor/tensor/random/rewriting/numba.py | 84 +++++++++++++++++++++ tests/benchmarks/test_random.py | 20 +++++ tests/link/numba/test_random.py | 60 +++++++++++++++ tests/tensor/random/rewriting/test_basic.py | 15 ++-- 4 files changed, 173 insertions(+), 6 deletions(-) create mode 100644 tests/benchmarks/test_random.py diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 0458929ee1..52d7314b2c 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -1,11 +1,25 @@ from pytensor.compile import optdb +from pytensor.configdefaults import config from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.graph.traversal import applys_between from pytensor.tensor import as_tensor, constant +from pytensor.tensor.basic import cast +from pytensor.tensor.random.basic import ( + InvGammaRV, + LogNormalRV, + MvNormalRV, + NormalRV, +) from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.rewriting.numba import simplify_core_shape_graphs from pytensor.tensor.rewriting.shape import ShapeFeature +from pytensor.tensor.variable import TensorConstant + + +_RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64 = frozenset( + (NormalRV, LogNormalRV, InvGammaRV, MvNormalRV) +) @node_rewriter([RandomVariable]) @@ -101,3 +115,73 @@ def introduce_explicit_core_shape_rv(fgraph, node): "numba", position=100, ) + + +@node_rewriter([RandomVariable]) +def cast_rv_float_params_to_float64(fgraph, node): + """Cast non-float64 floating-point RandomVariable parameters to float64. + + For a few distributions, numba's sampler runs markedly faster -- or only compiles + at all -- when its float parameters are float64 rather than float32 / integer. The + declared ``dtype`` (and ``floatX``) is only a final cast on the output, never the + sampling precision, so upcasting the parameters once here is a performance fix, not + a semantic one: the draw is still cast to the declared dtype on store, exactly as + ``RandomVariable.perform`` does (it hands parameters straight to NumPy and narrows + only the output via ``np.asarray(..., dtype=self.dtype)``). + + Which RVs actually benefit is not predictable from "the generator samples in + float64" -- it depends on the specific numba implementation and was measured per + distribution; see ``_RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64`` for the opt-in list. For + ``normal``, the ``int8`` literals of ``normal(0, 1)`` becoming float64 is a ~3x win. + + The target is float64, *not* the output dtype: under ``floatX="float32"`` casting to + the output dtype would leave the parameter float32 and still mismatched, so it would + not help. This is why the rewrite is Numba-specific -- JAX can sample natively in + float32, where forcing float64 parameters would pessimize. + """ + op = node.op + if type(op) not in _RVS_TO_CAST_FLOAT_PARAMS_TO_FLOAT64: + # Opt-in narrowing. The rewrite is registered on the ``RandomVariable`` base + # class (one cheap isinstance per graph node) and the exact-type set membership + # here only runs for the few actual RV nodes, bailing on all but the listed ones. + return None + + if config.warn_float64 != "ignore": + # The user asked to be warned about / forbidden float64; don't introduce it. + return None + + dist_params = op.dist_params(node) + # Every parameter of these (continuous) RVs is real-valued, so upcast any that is not + # already float64 -- ``normal(0, 1)``'s int8 literals included. + cast_idxs = [ + i for i, param in enumerate(dist_params) if param.type.dtype != "float64" + ] + if not cast_idxs: + return None + + new_params = list(dist_params) + for i in cast_idxs: + param = dist_params[i] + if isinstance(param, TensorConstant): + # Fold the cast into the constant so the common ``normal(0, 1)`` case ends + # up with float64 literals rather than a leftover Cast node. + new_params[i] = constant(param.data.astype("float64")) + else: + new_params[i] = cast(param, "float64") + + new_outputs = op.make_node( + op.rng_param(node), op.size_param(node), *new_params + ).outputs + copy_stack_trace(node.outputs, new_outputs) + return new_outputs + + +optdb.register( + cast_rv_float_params_to_float64.__name__, + dfs_rewriter(cast_rv_float_params_to_float64), + "numba", + # After stabilize (1.5), before specialize (2): a one-shot sweep that inserts the + # casts, leaving specialize's equilibrium (local_cast_cast + constant folding) to + # clean them up, and running before the core-shape wrap at position 100. + position=1.9, +) diff --git a/tests/benchmarks/test_random.py b/tests/benchmarks/test_random.py new file mode 100644 index 0000000000..02f305b06b --- /dev/null +++ b/tests/benchmarks/test_random.py @@ -0,0 +1,20 @@ +import numpy as np +import pytest + +from pytensor import config, function, shared + + +@pytest.mark.parametrize("floatX", ("float64", "float32")) +def test_normal_rv_benchmark_numba(floatX, benchmark): + # Drawing standard normals through numba. The ``0, 1`` literals are typed int8, and + # numba's rng.normal samples in float64; without cast_rv_float_params_to_float64 + # upcasting the parameters once, each of the >100k draws pays a per-element cast and + # the function runs ~3x slower. This benchmark tracks that hot path. + with config.change_flags(floatX=floatX): + rng = shared(np.random.default_rng(0)) + next_rng, draws = rng.normal(0, 1, size=(2160, 50)) + fn = function( + [], draws, updates={rng: next_rng}, mode="NUMBA", trust_input=True + ) + fn() # compile / warm up before timing + benchmark(fn) diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 6dd5cdbbe7..9eca66a6d2 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -5,6 +5,7 @@ import pytest import scipy.stats as stats +import pytensor import pytensor.tensor as pt import pytensor.tensor.random.basic as ptr from pytensor import shared @@ -771,3 +772,62 @@ def rng_fn(self, rng, value, size=None): assert large_sample.shape == (1000,) np.testing.assert_allclose(large_sample.mean(), np.pi, rtol=1e-2) np.testing.assert_allclose(large_sample.std(), 1, rtol=1e-2) + + +def _compiled_rv_dist_param_dtypes(out): + """Dtypes of the distribution parameters of the single RV in the compiled graph. + + The numba backend wraps RVs in a ``RandomVariableWithCoreShape`` whose inputs are + ``(core_shape, rng, size, *dist_params)``, so the parameters are ``inputs[3:]``. + """ + fn = function([], out, mode=numba_mode) + [node] = [ + n + for n in fn.maker.fgraph.toposort() + if isinstance(n.op, RandomVariableWithCoreShape) + ] + return [inp.type.dtype for inp in node.inputs[3:]] + + +@pytest.mark.parametrize("floatX", ["float64", "float32"]) +def test_rv_float_params_cast_to_float64(floatX): + # cast_rv_float_params_to_float64 upcasts the float params of opt-in RVs to float64 + # (always float64, not the output dtype -- float32 would stay mismatched) and leaves + # every other RV untouched. + with pytensor.config.change_flags(floatX=floatX): + rng = shared(np.random.default_rng(123)) + + # Opt-in: the int8 literals 0, 1 become float64 regardless of floatX. + assert _compiled_rv_dist_param_dtypes( + pt.random.normal(0, 1, size=(5,), rng=rng) + ) == ["float64", "float64"] + + # Opt-in, and required: numba's np.dot rejects a float32 covariance, so MvNormal + # only compiles once mean/cov are float64. + assert _compiled_rv_dist_param_dtypes( + pt.random.multivariate_normal( + np.zeros(3, dtype="float32"), np.eye(3, dtype="float32"), rng=rng + ) + ) == ["float64", "float64"] + + # Not opt-in (no measurable numba speedup): the float32 probability is left as-is, + # as is the integer count ``n``. + assert _compiled_rv_dist_param_dtypes( + pt.random.binomial(n=np.int64(10), p=np.float32(0.3), size=(5,), rng=rng) + ) == ["int64", "float32"] + + # Not opt-in and data-following: upcasting Permutation's array would change the + # output dtype and pointlessly widen the sampled data. + assert _compiled_rv_dist_param_dtypes( + pt.random.permutation(np.arange(5, dtype="float32"), rng=rng) + ) == ["float32"] + + +def test_rv_float_params_cast_respects_warn_float64(): + # When the user asked to be warned/raised on float64, the rewrite must not silently + # introduce it; the int8 params are left as-is. + with pytensor.config.change_flags(floatX="float32", warn_float64="raise"): + rng = shared(np.random.default_rng(123)) + assert _compiled_rv_dist_param_dtypes( + pt.random.normal(0, 1, size=(5,), rng=rng) + ) == ["int8", "int8"] diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 68315219a8..e1cb3b91c1 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -1016,9 +1016,12 @@ def test_sidestep_unused_rng_consumer_with_duplicate_node(): def test_unused_rng(): rng = random_generator_type("rng") - next_rng, x = rng.normal([0], [1], size=3) - next_rng, _y = next_rng.normal(x.ones_like(), [1]) - final_rng, z = next_rng.normal(1, 2) + # float64 parameters throughout, so the NUMBA float64-cast rewrite is a no-op and + # the strict graph comparison below need not account for inserted casts. Scalar + # Python floats autocast to float32, so use np.float64 for the scalar loc/scale. + next_rng, x = rng.normal([0.0], [1.0], size=3) + next_rng, _y = next_rng.normal(x.ones_like(), [1.0]) + final_rng, z = next_rng.normal(np.float64(1.0), np.float64(2.0)) fn = function([rng], [final_rng, z], mode=get_default_mode().excluding("inplace")) @@ -1050,9 +1053,9 @@ def test_unused_rng(): if config.mode != "FAST_COMPILE": # Strict graph comparison (ones_like gets constant-folded outside FAST_COMPILE) rng.tag.used = False # Avoid reuse warnings - next_rng, _x = rng.normal([0], [1], size=3) - next_rng, _y = next_rng.normal([1.0, 1.0, 1.0], [1]) - final_rng, z = next_rng.normal(1, 2) + next_rng, _x = rng.normal([0.0], [1.0], size=3) + next_rng, _y = next_rng.normal([1.0, 1.0, 1.0], [1.0]) + final_rng, z = next_rng.normal(np.float64(1.0), np.float64(2.0)) expected = [final_rng, z]