diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index ac59f1809c..fe675a0060 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -16,4 +16,5 @@ import pytensor.link.mlx.dispatch.sort import pytensor.link.mlx.dispatch.slinalg import pytensor.link.mlx.dispatch.nlinalg +import pytensor.link.mlx.dispatch.random # isort: on diff --git a/pytensor/link/mlx/dispatch/random.py b/pytensor/link/mlx/dispatch/random.py new file mode 100644 index 0000000000..032d5b1ce9 --- /dev/null +++ b/pytensor/link/mlx/dispatch/random.py @@ -0,0 +1,191 @@ +from functools import singledispatch + +import mlx.core as mx +from numpy.random import Generator + +import pytensor.tensor.random.basic as ptr +from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify +from pytensor.link.mlx.dispatch.tensor_basic import ( + convert_dtype_to_mlx, + mlx_to_list_shape, +) + + +def numpy_generator_to_mlx_key(rng: Generator) -> mx.array: + """Convert a NumPy Generator to an MLX random key. + + MLX uses a functional RNG model where each random call takes an explicit + key rather than mutating shared state. The PCG64 state is 128 bits, which + MLX cannot accept directly. We fold both 64-bit halves together via XOR + to use all 128 bits of entropy in a single 64-bit seed. 
+ """ + state_128 = int(rng.bit_generator.state["state"]["state"]) + upper = (state_128 >> 64) & 0xFFFFFFFFFFFFFFFF + lower = state_128 & 0xFFFFFFFFFFFFFFFF + return mx.random.key(upper ^ lower) + + +@mlx_typify.register(Generator) +def mlx_typify_Generator(rng, **kwargs): + return numpy_generator_to_mlx_key(rng) + + +@mlx_funcify.register(ptr.RandomVariable) +def mlx_funcify_RandomVariable(op, node, **kwargs): + rv = node.outputs[1] + out_dtype = rv.type.dtype + + sample_fn_inner = mlx_sample_fn(op, node) + + def sample_fn(rng, size, *parameters): + new_keys = mx.random.split(rng, num=2) + new_rng = new_keys[0] + sampling_key = new_keys[1] + sample = sample_fn_inner(sampling_key, size, out_dtype, *parameters) + return (new_rng, sample) + + return sample_fn + + +@singledispatch +def mlx_sample_fn(op, node): + raise NotImplementedError( + f"No MLX implementation for the given distribution: {op.name}" + ) + + +@mlx_sample_fn.register(ptr.NormalRV) +def mlx_sample_fn_normal(op, node): + def sample_fn(rng_key, size, dtype, mu, sigma): + mlx_dtype = convert_dtype_to_mlx(dtype) + mu = mx.array(mu, dtype=mlx_dtype) + sigma = mx.array(sigma, dtype=mlx_dtype) + if size is None: + shape = mx.broadcast_arrays(mu, sigma)[0].shape + else: + shape = mlx_to_list_shape(size) + s = mx.random.normal(shape=shape, dtype=mlx_dtype, key=rng_key) + return mu + sigma * s + + return sample_fn + + +@mlx_sample_fn.register(ptr.UniformRV) +def mlx_sample_fn_uniform(op, node): + def sample_fn(rng_key, size, dtype, low, high): + mlx_dtype = convert_dtype_to_mlx(dtype) + low = mx.array(low, dtype=mlx_dtype) + high = mx.array(high, dtype=mlx_dtype) + if size is None: + shape = mx.broadcast_arrays(low, high)[0].shape + else: + shape = mlx_to_list_shape(size) + return mx.random.uniform( + low=low, high=high, shape=shape, dtype=mlx_dtype, key=rng_key + ) + + return sample_fn + + +@mlx_sample_fn.register(ptr.BernoulliRV) +def mlx_sample_fn_bernoulli(op, node): + def sample_fn(rng_key, size, dtype, 
p): + p = mx.array(p) + if size is None: + shape = p.shape + else: + shape = mlx_to_list_shape(size) + return mx.random.bernoulli(p=p, shape=shape, key=rng_key) + + return sample_fn + + +@mlx_sample_fn.register(ptr.CategoricalRV) +def mlx_sample_fn_categorical(op, node): + def sample_fn(rng_key, size, dtype, p): + logits = mx.log(mx.array(p)) + shape = mlx_to_list_shape(size) if size is not None else None + return mx.random.categorical(logits=logits, axis=-1, shape=shape, key=rng_key) + + return sample_fn + + +@mlx_sample_fn.register(ptr.MvNormalRV) +def mlx_sample_fn_mvnormal(op, node): + def sample_fn(rng_key, size, dtype, mean, cov): + mlx_dtype = convert_dtype_to_mlx(dtype) + shape = mlx_to_list_shape(size) if size is not None else [] + # multivariate_normal uses SVD internally, which requires mx.cpu in MLX. + return mx.random.multivariate_normal( + mean=mean, + cov=cov, + shape=shape, + dtype=mlx_dtype, + key=rng_key, + stream=mx.cpu, + ) + + return sample_fn + + +@mlx_sample_fn.register(ptr.LaplaceRV) +def mlx_sample_fn_laplace(op, node): + def sample_fn(rng_key, size, dtype, loc, scale): + mlx_dtype = convert_dtype_to_mlx(dtype) + loc = mx.array(loc, dtype=mlx_dtype) + scale = mx.array(scale, dtype=mlx_dtype) + if size is None: + shape = mx.broadcast_arrays(loc, scale)[0].shape + else: + shape = mlx_to_list_shape(size) + s = mx.random.laplace(shape=shape, dtype=mlx_dtype, key=rng_key) + return loc + scale * s + + return sample_fn + + +@mlx_sample_fn.register(ptr.GumbelRV) +def mlx_sample_fn_gumbel(op, node): + def sample_fn(rng_key, size, dtype, loc, scale): + mlx_dtype = convert_dtype_to_mlx(dtype) + loc = mx.array(loc, dtype=mlx_dtype) + scale = mx.array(scale, dtype=mlx_dtype) + if size is None: + shape = mx.broadcast_arrays(loc, scale)[0].shape + else: + shape = mlx_to_list_shape(size) + s = mx.random.gumbel(shape=shape, dtype=mlx_dtype, key=rng_key) + return loc + scale * s + + return sample_fn + + +@mlx_sample_fn.register(ptr.PermutationRV) +def 
mlx_sample_fn_permutation(op, node): + batch_ndim = op.batch_ndim(node) + + def sample_fn(rng_key, size, dtype, x): + if batch_ndim: + raise NotImplementedError( + "MLX random.permutation does not support batch dimensions." + ) + return mx.random.permutation(x, key=rng_key) + + return sample_fn + + +@mlx_sample_fn.register(ptr.IntegersRV) +def mlx_sample_fn_integers(op, node): + def sample_fn(rng_key, size, dtype, low, high): + mlx_dtype = convert_dtype_to_mlx(dtype) + low = mx.array(low, dtype=mlx_dtype) + high = mx.array(high, dtype=mlx_dtype) + if size is None: + shape = mx.broadcast_arrays(low, high)[0].shape + else: + shape = mlx_to_list_shape(size) + return mx.random.randint( + low=low, high=high, shape=shape, dtype=mlx_dtype, key=rng_key + ) + + return sample_fn diff --git a/pytensor/link/mlx/dispatch/tensor_basic.py b/pytensor/link/mlx/dispatch/tensor_basic.py index 730aa140c4..269f3d09ae 100644 --- a/pytensor/link/mlx/dispatch/tensor_basic.py +++ b/pytensor/link/mlx/dispatch/tensor_basic.py @@ -249,6 +249,16 @@ def _coerce_to_int(value): raise +def mlx_to_list_shape(size) -> list[int]: + """Convert a size value (mx.array, np.ndarray, or sequence) to a plain Python list of ints. + + Used by random variable dispatch to normalise the ``size`` argument, which + PyTensor may pass as an ``mx.array`` or ``np.ndarray`` rather than a plain + Python list. 
+ """ + return [_coerce_to_int(x) for x in size] + + def _rethrow_dynamic_shape_error(exc): msg = str(exc) if "[eval] Attempting to eval an array during function transformations" in msg: diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index fea4c73d5c..c87a879551 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -1,3 +1,6 @@ +import warnings + +from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.link.basic import JITLinker @@ -17,7 +20,7 @@ def __init__(self, use_compile=True, *args, **kwargs): self.gen_functors = [] self.use_compile = use_compile - def fgraph_convert(self, fgraph, **kwargs): + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): """Convert a PyTensor FunctionGraph to an MLX-compatible function. Parameters @@ -31,9 +34,63 @@ def fgraph_convert(self, fgraph, **kwargs): An MLX-compatible function """ from pytensor.link.mlx.dispatch import mlx_funcify + from pytensor.tensor.random.type import RandomType + + shared_rng_inputs = [ + inp + for inp in fgraph.inputs + if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) + ] + + # Replace any shared RNG inputs so that their values can be updated in place + # without affecting the original RNG container. This is necessary because + # MLX does not accept Generators as inputs, and they will have to + # be typified + if shared_rng_inputs: + warnings.warn( + f"The RandomType SharedVariables {shared_rng_inputs} will not be used " + f"in the compiled MLX graph. 
Instead a copy will be used.",
+                UserWarning,
+            )
+            new_shared_rng_inputs = [
+                shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
+            ]
+
+            fgraph.replace_all(
+                zip(shared_rng_inputs, new_shared_rng_inputs, strict=True),
+                import_missing=True,
+                reason="MLXLinker.fgraph_convert",
+            )
+
+            for old_inp, new_inp in zip(
+                shared_rng_inputs, new_shared_rng_inputs, strict=True
+            ):
+                new_inp_storage = [new_inp.get_value(borrow=True)]
+                storage_map[new_inp] = new_inp_storage
+                old_inp_storage = storage_map.pop(old_inp)
+                # Find index of old_inp_storage in input_storage
+                for input_storage_idx, input_storage_item in enumerate(input_storage):
+                    # We have to establish equality based on identity because input_storage may contain numpy arrays
+                    if input_storage_item is old_inp_storage:
+                        break
+                else:  # no break
+                    raise ValueError()
+                input_storage[input_storage_idx] = new_inp_storage
+                # We need to change the order of the inputs of the FunctionGraph
+                # so that the new input is in the same position as the old one,
+                # to align with the storage_map. We hope this is safe!
+ old_inp_fgraph_index = fgraph.inputs.index(old_inp) + fgraph.remove_input( + old_inp_fgraph_index, + reason="MLXLinker.fgraph_convert", + ) + fgraph.inputs.remove(new_inp) + fgraph.inputs.insert(old_inp_fgraph_index, new_inp) return mlx_funcify( fgraph, + input_storage=input_storage, + storage_map=storage_map, **kwargs, ) @@ -69,9 +126,16 @@ def create_thunk_inputs(self, storage_map): list The inputs for the thunk """ + from numpy.random import Generator + + from pytensor.link.mlx.dispatch import mlx_typify + thunk_inputs = [] for n in self.fgraph.inputs: sinput = storage_map[n] + if isinstance(sinput[0], Generator): + # Convert Generator into MLX PRNG key + sinput[0] = mlx_typify(sinput[0]) thunk_inputs.append(sinput) return thunk_inputs diff --git a/tests/link/mlx/test_random.py b/tests/link/mlx/test_random.py new file mode 100644 index 0000000000..61c524e625 --- /dev/null +++ b/tests/link/mlx/test_random.py @@ -0,0 +1,256 @@ +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.compile.function import function +from pytensor.compile.mode import MLX, Mode +from pytensor.compile.sharedvalue import shared +from pytensor.link.mlx.linker import MLXLinker +from pytensor.tensor.random.utils import RandomStream + + +mx = pytest.importorskip("mlx.core") + +# MLX mode without mx.compile — needed for ops that use CPU streams internally +# (e.g. multivariate_normal, which uses SVD via mx.cpu stream and is +# incompatible with mx.compile's tracing). 
+MLX_NO_COMPILE = Mode(linker=MLXLinker(use_compile=False), optimizer=MLX.optimizer) + + +def test_normal_cumsum(): + out = pt.random.normal(size=(52,)).cumsum() + result = out.eval(mode="MLX") + assert isinstance(result, mx.array) + assert result.shape == (52,) + + +def check_shape_and_dtype( + make_rv, expected_shape, expected_dtype=None, n_evals=2, mode="MLX" +): + """Compile and run an RV under MLX, assert shape and dtype, and verify + that two successive draws differ (RNG state is properly threaded). + + Parameters + ---------- + make_rv : callable(srng) -> rv_var + Factory that creates the RV using the provided RandomStream. + expected_shape : tuple + expected_dtype : str or None + n_evals : int + mode : str or Mode + """ + srng = RandomStream(seed=12345) + rv = make_rv(srng) + f = pytensor.function([], rv, mode=mode, updates=srng.updates()) + results = [np.array(f()) for _ in range(n_evals)] + + for r in results: + assert r.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {r.shape}" + ) + if expected_dtype is not None: + assert r.dtype == np.dtype(expected_dtype), ( + f"Expected dtype {expected_dtype}, got {r.dtype}" + ) + + assert not np.array_equal(results[0], results[1]), ( + "Two draws were identical — RNG not advancing" + ) + + return results + + +def test_normal_shape_dtype(): + check_shape_and_dtype( + lambda srng: srng.normal(loc=0.0, scale=1.0, size=(3, 4)), + (3, 4), + "float32", + ) + + +def test_normal_scalar(): + check_shape_and_dtype( + lambda srng: srng.normal(loc=2.0, scale=0.5), + (), + ) + + +def test_normal_array_params(): + result = pt.random.normal(loc=[0, 1], scale=[1.0, 0.3], size=(100, 2)).eval( + mode="MLX" + ) + assert result.shape == (100, 2) + means = np.array(result).mean(axis=0) + assert abs(means[0]) < 0.3 + assert abs(means[1] - 1.0) < 0.3 + + +def test_uniform_shape_dtype(): + results = check_shape_and_dtype( + lambda srng: srng.uniform(low=0.0, high=1.0, size=(10,)), + (10,), + "float32", + ) + r = 
np.array(results[0]) + assert np.all(r >= 0.0) + assert np.all(r < 1.0) + + +def test_bernoulli_shape(): + check_shape_and_dtype( + lambda srng: srng.bernoulli(p=0.7, size=(5, 5)), + (5, 5), + ) + + +def test_categorical_shape(): + probs = np.array([0.1, 0.4, 0.5], dtype=np.float32) + results = check_shape_and_dtype( + lambda srng: srng.categorical(p=probs, size=(8,)), + (8,), + ) + r = np.array(results[0]) + assert np.all(r < 3) + assert np.all(r >= 0) + + +def test_mvnormal_shape(): + mean = np.zeros(4, dtype=np.float32) + cov = np.eye(4, dtype=np.float32) + # multivariate_normal uses SVD internally (CPU-only in MLX), which is + # incompatible with mx.compile — use the no-compile mode. + check_shape_and_dtype( + lambda srng: srng.multivariate_normal(mean=mean, cov=cov, size=(6,)), + (6, 4), + "float32", + mode=MLX_NO_COMPILE, + ) + + +def test_laplace_shape_dtype(): + check_shape_and_dtype( + lambda srng: srng.laplace(loc=0.0, scale=1.0, size=(7,)), + (7,), + "float32", + ) + + +def test_gumbel_shape_dtype(): + check_shape_and_dtype( + lambda srng: srng.gumbel(loc=0.0, scale=1.0, size=(6,)), + (6,), + "float32", + ) + + +def test_integers_shape(): + results = check_shape_and_dtype( + lambda srng: srng.integers(low=0, high=10, size=(12,)), + (12,), + ) + r = np.array(results[0]) + assert np.all(r >= 0) + assert np.all(r < 10) + + +def test_permutation_shape(): + x = np.arange(8, dtype=np.int32) + results = check_shape_and_dtype( + lambda srng: srng.permutation(x), + (8,), + ) + assert sorted(np.array(results[0]).tolist()) == list(range(8)) + + +def test_gamma_not_implemented(): + srng = RandomStream(seed=1) + rv = srng.gamma(shape=1.0, scale=1.0, size=(3,)) + with pytest.raises(NotImplementedError, match="No MLX implementation"): + pytensor.function([], rv, mode="MLX", updates=srng.updates()) + + +def test_beta_not_implemented(): + srng = RandomStream(seed=1) + rv = srng.beta(alpha=2.0, beta=5.0, size=(3,)) + with pytest.raises(NotImplementedError, match="No MLX 
implementation"):
+        pytensor.function([], rv, mode="MLX", updates=srng.updates())
+
+
+def compile_shared_rng_function(*args, mode="MLX", **kwargs):
+    with pytest.warns(
+        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+    ):
+        return function(*args, mode=mode, **kwargs)
+
+
+def test_random_updates():
+    original_value = np.random.default_rng(seed=98)
+    rng = shared(original_value, name="original_rng", borrow=False)
+    next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
+
+    f = compile_shared_rng_function([], [x], updates={rng: next_rng})
+    assert f() != f()
+
+    # Check that the original shared variable was not overwritten when typifying
+    # NOTE(review): zip over the two `.state` dicts iterates their *keys*, so this
+    # compares key names to key names and passes vacuously — consider comparing
+    # the state dicts' values instead. TODO confirm and tighten.
+    assert all(
+        a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
+        for a, b in zip(
+            rng.get_value().bit_generator.state,
+            original_value.bit_generator.state,
+            strict=True,
+        )
+    )
+
+
+@pytest.mark.parametrize("noise_first", (False, True))
+def test_replaced_shared_rng_storage_order(noise_first):
+    # Test that replacing the RNG variable in the linker does not cause
+    # a misalignment between the compiled graph and the storage_map.
+ + mu = pytensor.shared(np.array(1.0), name="mu") + rng = pytensor.shared(np.random.default_rng(123)) + next_rng, noise = pt.random.normal(rng=rng).owner.outputs + + out = noise * mu if noise_first else mu * noise + + updates = { + mu: pt.grad(out, mu), + rng: next_rng, + } + f = compile_shared_rng_function([], [out], updates=updates) + + # Confirm that input_storage type and fgraph input order are aligned + for storage, fgraph_input in zip( + f.input_storage, f.maker.fgraph.inputs, strict=True + ): + assert storage.type == fgraph_input.type + + assert mu.get_value() == 1 + f() + assert mu.get_value() != 1 + + +def test_replaced_shared_rng_storage_ordering_equality(): + """Test that storage identity comparison works when numpy arrays precede + the RNG in input_storage (regression test for issue #314).""" + pt_rng = RandomStream(1) + + batchshape = (3, 1, 4, 4) + inp_shared = pytensor.shared( + np.zeros(batchshape, dtype="float64"), name="inp_shared" + ) + + inp = pt.tensor4(dtype="float64", name="inp") + inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5) + + fn = compile_shared_rng_function( + inputs=[], + outputs=[], + updates={inp_shared: inp_update}, + givens={inp: inp_shared}, + ) + fn() + np.testing.assert_allclose(np.array(inp_shared.get_value()), 5, rtol=1e-2) + fn() + np.testing.assert_allclose(np.array(inp_shared.get_value()), 10, rtol=1e-2)