Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/mlx/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
import pytensor.link.mlx.dispatch.sort
import pytensor.link.mlx.dispatch.slinalg
import pytensor.link.mlx.dispatch.nlinalg
import pytensor.link.mlx.dispatch.random
# isort: on
191 changes: 191 additions & 0 deletions pytensor/link/mlx/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import warnings
from functools import singledispatch

import mlx.core as mx
from numpy.random import Generator

import pytensor.tensor.random.basic as ptr
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
from pytensor.link.mlx.dispatch.tensor_basic import (
    convert_dtype_to_mlx,
    mlx_to_list_shape,
)


def numpy_generator_to_mlx_key(rng: Generator) -> mx.array:
    """Derive an MLX random key from a NumPy ``Generator``.

    MLX's RNG is functional: every sampling call consumes an explicit key
    instead of mutating shared state. A PCG64 generator carries 128 bits of
    state but ``mx.random.key`` accepts only a 64-bit seed, so the two
    64-bit halves are XOR-folded together to retain entropy from both.
    """
    full_state = int(rng.bit_generator.state["state"]["state"])
    mask64 = (1 << 64) - 1
    hi, lo = (full_state >> 64) & mask64, full_state & mask64
    return mx.random.key(hi ^ lo)


@mlx_typify.register(Generator)
def mlx_typify_Generator(rng, **kwargs):
    """Typify a NumPy ``Generator`` as an MLX PRNG key."""
    key = numpy_generator_to_mlx_key(rng)
    return key


@mlx_funcify.register(ptr.RandomVariable)
def mlx_funcify_RandomVariable(op, node, **kwargs):
    """Build an MLX sampling function for a PyTensor ``RandomVariable`` node.

    The returned callable follows PyTensor's RNG contract: given the current
    key and the op's parameters it returns ``(next_key, draw)``, splitting
    the key so successive calls yield fresh, independent samples.
    """
    out_dtype = node.outputs[1].type.dtype
    inner_sampler = mlx_sample_fn(op, node)

    def sample_fn(rng, size, *parameters):
        # Split once: keys[0] carries the RNG state forward, keys[1] is
        # consumed by this draw.
        keys = mx.random.split(rng, num=2)
        draw = inner_sampler(keys[1], size, out_dtype, *parameters)
        return (keys[0], draw)

    return sample_fn


@singledispatch
def mlx_sample_fn(op, node):
    """Fallback for distributions lacking an MLX sampler implementation."""
    msg = f"No MLX implementation for the given distribution: {op.name}"
    raise NotImplementedError(msg)


@mlx_sample_fn.register(ptr.NormalRV)
def mlx_sample_fn_normal(op, node):
    """Sample from a normal distribution via the standard-normal transform."""

    def sample_fn(rng_key, size, dtype, mu, sigma):
        target_dtype = convert_dtype_to_mlx(dtype)
        loc = mx.array(mu, dtype=target_dtype)
        scale = mx.array(sigma, dtype=target_dtype)
        # With no explicit size, draw one sample per element of the
        # broadcast parameter shape.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(loc, scale)[0].shape
        z = mx.random.normal(shape=shape, dtype=target_dtype, key=rng_key)
        return loc + scale * z

    return sample_fn


@mlx_sample_fn.register(ptr.UniformRV)
def mlx_sample_fn_uniform(op, node):
    """Sample from a uniform distribution on ``[low, high)``."""

    def sample_fn(rng_key, size, dtype, low, high):
        target_dtype = convert_dtype_to_mlx(dtype)
        lo = mx.array(low, dtype=target_dtype)
        hi = mx.array(high, dtype=target_dtype)
        # Fall back to the broadcast parameter shape when no size is given.
        shape = (
            mlx_to_list_shape(size)
            if size is not None
            else mx.broadcast_arrays(lo, hi)[0].shape
        )
        return mx.random.uniform(
            low=lo, high=hi, shape=shape, dtype=target_dtype, key=rng_key
        )

    return sample_fn


@mlx_sample_fn.register(ptr.BernoulliRV)
def mlx_sample_fn_bernoulli(op, node):
    """Sample Bernoulli draws with success probability ``p``."""

    def sample_fn(rng_key, size, dtype, p):
        probs = mx.array(p)
        # Without an explicit size, produce one draw per element of ``p``.
        shape = probs.shape if size is None else mlx_to_list_shape(size)
        return mx.random.bernoulli(p=probs, shape=shape, key=rng_key)

    return sample_fn


@mlx_sample_fn.register(ptr.CategoricalRV)
def mlx_sample_fn_categorical(op, node):
    """Sample category indices from probabilities ``p`` over the last axis."""

    def sample_fn(rng_key, size, dtype, p):
        # mx.random.categorical expects (unnormalised) log-probabilities.
        log_probs = mx.log(mx.array(p))
        if size is None:
            shape = None
        else:
            shape = mlx_to_list_shape(size)
        return mx.random.categorical(
            logits=log_probs, axis=-1, shape=shape, key=rng_key
        )

    return sample_fn


@mlx_sample_fn.register(ptr.MvNormalRV)
def mlx_sample_fn_mvnormal(op, node):
    """Sample from a multivariate normal distribution.

    ``mx.random.multivariate_normal`` always factorises ``cov`` via SVD, so
    a different decomposition method requested on the op (e.g. Cholesky)
    cannot be honoured. Rather than silently ignoring the request, a
    ``UserWarning`` is issued once, at dispatch time.
    """
    # NOTE(review): presumably the op stores the requested decomposition in
    # ``op.method`` (as the numba dispatch does) — getattr guards against
    # ops that lack the attribute.
    method = getattr(op, "method", "svd")
    if method != "svd":
        warnings.warn(
            f"MLX multivariate normal sampling only supports the 'svd' "
            f"decomposition; requested method '{method}' will be ignored.",
            UserWarning,
        )

    def sample_fn(rng_key, size, dtype, mean, cov):
        mlx_dtype = convert_dtype_to_mlx(dtype)
        shape = mlx_to_list_shape(size) if size is not None else []
        # multivariate_normal uses SVD internally, which requires mx.cpu in MLX.
        return mx.random.multivariate_normal(
            mean=mean,
            cov=cov,
            shape=shape,
            dtype=mlx_dtype,
            key=rng_key,
            stream=mx.cpu,
        )

    return sample_fn


@mlx_sample_fn.register(ptr.LaplaceRV)
def mlx_sample_fn_laplace(op, node):
    """Sample from a Laplace distribution via the standard-Laplace transform."""

    def sample_fn(rng_key, size, dtype, loc, scale):
        target_dtype = convert_dtype_to_mlx(dtype)
        location = mx.array(loc, dtype=target_dtype)
        spread = mx.array(scale, dtype=target_dtype)
        # Fall back to the broadcast parameter shape when no size is given.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(location, spread)[0].shape
        z = mx.random.laplace(shape=shape, dtype=target_dtype, key=rng_key)
        return location + spread * z

    return sample_fn


@mlx_sample_fn.register(ptr.GumbelRV)
def mlx_sample_fn_gumbel(op, node):
    """Sample from a Gumbel distribution via the standard-Gumbel transform."""

    def sample_fn(rng_key, size, dtype, loc, scale):
        target_dtype = convert_dtype_to_mlx(dtype)
        location = mx.array(loc, dtype=target_dtype)
        spread = mx.array(scale, dtype=target_dtype)
        # Fall back to the broadcast parameter shape when no size is given.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(location, spread)[0].shape
        z = mx.random.gumbel(shape=shape, dtype=target_dtype, key=rng_key)
        return location + spread * z

    return sample_fn


@mlx_sample_fn.register(ptr.PermutationRV)
def mlx_sample_fn_permutation(op, node):
    """Sample a random permutation of ``x``.

    Raises
    ------
    NotImplementedError
        If the node has batch dimensions: MLX's ``random.permutation``
        cannot permute batched inputs independently. Raised at dispatch
        time so compilation fails immediately rather than on the first draw.
    """
    if op.batch_ndim(node):
        raise NotImplementedError(
            "MLX random.permutation does not support batch dimensions."
        )

    def sample_fn(rng_key, size, dtype, x):
        return mx.random.permutation(x, key=rng_key)

    return sample_fn


@mlx_sample_fn.register(ptr.IntegersRV)
def mlx_sample_fn_integers(op, node):
    """Sample integers uniformly from ``[low, high)``."""

    def sample_fn(rng_key, size, dtype, low, high):
        target_dtype = convert_dtype_to_mlx(dtype)
        lo = mx.array(low, dtype=target_dtype)
        hi = mx.array(high, dtype=target_dtype)
        # Fall back to the broadcast parameter shape when no size is given.
        if size is not None:
            shape = mlx_to_list_shape(size)
        else:
            shape = mx.broadcast_arrays(lo, hi)[0].shape
        return mx.random.randint(
            low=lo, high=hi, shape=shape, dtype=target_dtype, key=rng_key
        )

    return sample_fn
10 changes: 10 additions & 0 deletions pytensor/link/mlx/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ def _coerce_to_int(value):
raise


def mlx_to_list_shape(size) -> list[int]:
    """Normalise ``size`` into a plain Python list of ints.

    Random-variable dispatch may receive ``size`` as an ``mx.array``,
    ``np.ndarray``, or other sequence; downstream MLX shape arguments want
    concrete Python integers.
    """
    return list(map(_coerce_to_int, size))


def _rethrow_dynamic_shape_error(exc):
msg = str(exc)
if "[eval] Attempting to eval an array during function transformations" in msg:
Expand Down
66 changes: 65 additions & 1 deletion pytensor/link/mlx/linker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import warnings

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.link.basic import JITLinker


Expand All @@ -17,7 +20,7 @@ def __init__(self, use_compile=True, *args, **kwargs):
self.gen_functors = []
self.use_compile = use_compile

def fgraph_convert(self, fgraph, **kwargs):
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
"""Convert a PyTensor FunctionGraph to an MLX-compatible function.

Parameters
Expand All @@ -31,9 +34,63 @@ def fgraph_convert(self, fgraph, **kwargs):
An MLX-compatible function
"""
from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.random.type import RandomType

shared_rng_inputs = [
inp
for inp in fgraph.inputs
if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
]

# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# MLX does not accept Generators as inputs, and they will have to
# be typified
if shared_rng_inputs:
warnings.warn(
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
f"in the compiled MLX graph. Instead a copy will be used.",
UserWarning,
)
new_shared_rng_inputs = [
shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
]

fgraph.replace_all(
zip(shared_rng_inputs, new_shared_rng_inputs, strict=True),
import_missing=True,
reason="MLXLinker.fgraph_convert",
)

for old_inp, new_inp in zip(
shared_rng_inputs, new_shared_rng_inputs, strict=True
):
new_inp_storage = [new_inp.get_value(borrow=True)]
storage_map[new_inp] = new_inp_storage
old_inp_storage = storage_map.pop(old_inp)
# Find index of old_inp_storage in input_storage
for input_storage_idx, input_storage_item in enumerate(input_storage):
# We have to establish equality based on identity because input_storage may contain numpy arrays
if input_storage_item is old_inp_storage:
break
else: # no break
raise ValueError()
input_storage[input_storage_idx] = new_inp_storage
# We need to change the order of the inputs of the FunctionGraph
# so that the new input is in the same position as to old one,
# to align with the storage_map. We hope this is safe!
old_inp_fgraph_index = fgraph.inputs.index(old_inp)
fgraph.remove_input(
old_inp_fgraph_index,
reason="MLXLinker.fgraph_convert",
)
fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgraph_index, new_inp)

return mlx_funcify(
fgraph,
input_storage=input_storage,
storage_map=storage_map,
**kwargs,
)

Expand Down Expand Up @@ -69,9 +126,16 @@ def create_thunk_inputs(self, storage_map):
list
The inputs for the thunk
"""
from numpy.random import Generator

from pytensor.link.mlx.dispatch import mlx_typify

thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], Generator):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to do the same dance jax linker does with shared Generator variables

# Convert Generator into MLX PRNG key
sinput[0] = mlx_typify(sinput[0])
thunk_inputs.append(sinput)

return thunk_inputs
Loading
Loading