diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index cfa9fcd848..67da601690 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -11,6 +11,7 @@ import pytensor.link.jax.dispatch.math import pytensor.link.jax.dispatch.linalg import pytensor.link.jax.dispatch.random +import pytensor.link.jax.dispatch.reshape import pytensor.link.jax.dispatch.scalar import pytensor.link.jax.dispatch.scan import pytensor.link.jax.dispatch.shape diff --git a/pytensor/link/jax/dispatch/reshape.py b/pytensor/link/jax/dispatch/reshape.py new file mode 100644 index 0000000000..740e2f9b63 --- /dev/null +++ b/pytensor/link/jax/dispatch/reshape.py @@ -0,0 +1,30 @@ +import jax.numpy as jnp + +from pytensor.link.jax.dispatch.basic import jax_funcify +from pytensor.tensor.reshape import ( + JoinDims, + SplitDims, +) + + +@jax_funcify.register(JoinDims) +def jax_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + + def join_dims(x): + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return jnp.reshape(x, output_shape) + + return join_dims + + +@jax_funcify.register(SplitDims) +def jax_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return jnp.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index c86c65aa7d..6bc31496f5 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -8,6 +8,7 @@ import pytensor.link.numba.dispatch.extra_ops import pytensor.link.numba.dispatch.linalg import pytensor.link.numba.dispatch.random +import pytensor.link.numba.dispatch.reshape import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar import pytensor.link.numba.dispatch.shape diff --git a/pytensor/link/numba/dispatch/reshape.py b/pytensor/link/numba/dispatch/reshape.py new file mode 100644 index 0000000000..2c43c9cbe5 --- /dev/null +++ b/pytensor/link/numba/dispatch/reshape.py @@ -0,0 +1,30 @@ +import numpy as np + +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key +from pytensor.tensor.reshape import JoinDims, SplitDims + + +@register_funcify_default_op_cache_key(JoinDims) +def numba_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + + @numba_basic.numba_njit + def join_dims(x): + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return np.reshape(x, output_shape) + + return join_dims + + +@register_funcify_default_op_cache_key(SplitDims) +def numba_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + @numba_basic.numba_njit + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return np.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index b6a5533809..4cbbe24754 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -70,6 +70,7 @@ def reshape(x, shape): @numba_basic.numba_njit def reshape(x, shape): # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + # the above issue is closed, what to do instead? return np.reshape( np.ascontiguousarray(np.asarray(x)), numba_ndarray.to_fixed_tuple(shape, ndim), diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..6b82061f22 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2762,6 +2762,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable: None means all axes (like numpy). """ from pytensor.ifelse import ifelse + from pytensor.tensor.reshape import join_dims x = as_tensor_variable(x) x_ndim = x.type.ndim @@ -2776,7 +2777,9 @@ def median(x: TensorLike, axis=None) -> TensorVariable: # Put axis at the end and unravel them x_raveled = x.transpose(*non_axis, *axis) if len(axis) > 1: - x_raveled = x_raveled.reshape((*non_axis_shape, -1)) + x_raveled = join_dims( + x_raveled, start_axis=len(non_axis_shape), n_axes=len(axis) + ) raveled_size = x_raveled.shape[-1] k = raveled_size // 2 diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index b25308fccc..57adfe1e99 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -152,7 +152,7 @@ def join_dims( return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value] -class SplitDims(Op): +class SplitDims(COp): __props__ = ("axis",) view_map = {0: [0]} @@ -217,6 +217,12 @@ def pullback(self, inputs, outputs, output_grads): disconnected_type(), ] + def pushforward() # is this needed? + + def c_code_cache_version(self): + return (10,) + + def c_code(): @_vectorize_node.register(SplitDims) def _vectorize_splitdims(op, node, x, shape): diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..c835dbd695 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -1017,3 +1017,14 @@ def specify_broadcastable(x, *axes): axes = normalize_axis_tuple(axes, x.type.ndim) shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)] return specify_shape(x, shape_info) + + +Class JoinDims(COp): + + def __init__ + + def c_code_cache_version(self): + return (10,) + + def c_code(self, node, name, inputs, outputs, sub): + diff --git a/tests/link/jax/test_reshape.py b/tests/link/jax/test_reshape.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/numba/test_reshape.py b/tests/link/numba/test_reshape.py new file mode 100644 index 0000000000..e69de29bb2