From 6444df9aa9a1f0e0315af30dc45839e8f0e7df9c Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Fri, 17 Apr 2026 09:43:58 -0400 Subject: [PATCH 1/4] starting a new branch for proposing internal use of join/split dims instead of reshape to reduce rounds of rewriting --- pytensor/tensor/math.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..6b82061f22 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2762,6 +2762,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable: None means all axes (like numpy). """ from pytensor.ifelse import ifelse + from pytensor.tensor.reshape import join_dims x = as_tensor_variable(x) x_ndim = x.type.ndim @@ -2776,7 +2777,9 @@ def median(x: TensorLike, axis=None) -> TensorVariable: # Put axis at the end and unravel them x_raveled = x.transpose(*non_axis, *axis) if len(axis) > 1: - x_raveled = x_raveled.reshape((*non_axis_shape, -1)) + x_raveled = join_dims( + x_raveled, start_axis=len(non_axis_shape), n_axes=len(axis) + ) raveled_size = x_raveled.shape[-1] k = raveled_size // 2 From cc652c38f38975f1e84b3c6dd0497b3ff1a829ea Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Mon, 20 Apr 2026 11:02:30 -0400 Subject: [PATCH 2/4] attempt to add joindims to shape and dispatch --- pytensor/link/jax/dispatch/shape.py | 15 ++++++++++++++- pytensor/link/numba/dispatch/shape.py | 16 +++++++++++++++- pytensor/tensor/shape.py | 11 +++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index d7c1d0bcbd..331bec8060 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -4,7 +4,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape # JoinDims, SplitDims from pytensor.tensor.type import TensorType @@ -104,3 +104,16 @@ def specifyshape(x, *shape): return x return specifyshape + + +@jax_funcify.register(JoinDims) +def jax_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + def join_dims(x): + if n_axes == 0: + expanded_x = jnp.expand_dims(x, axis=start_axis) + return expanded_x + if n_axes == 1: + return x + return join_dims \ No newline at end of file diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index b6a5533809..e47cfd9cbe 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -7,7 +7,7 @@ from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key from pytensor.link.utils import compile_function_src from pytensor.tensor import NoneConst -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape # JoinDims, SplitDims @register_funcify_default_op_cache_key(Shape) @@ -70,9 +70,23 @@ def reshape(x, shape): @numba_basic.numba_njit def reshape(x, shape): # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + # the above issue is closed, what to do instead? return np.reshape( np.ascontiguousarray(np.asarray(x)), numba_ndarray.to_fixed_tuple(shape, ndim), ) return reshape + +@register_funcify_default_op_cache_key(JoinDims) +def numba_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + @numba_basic.numba_njit + def join_dims(x): + if n_axes == 0: + expanded_x = np.expand_dims(x, axis=start_axis) + return expanded_x + if n_axes == 1: + return x + return join_dims \ No newline at end of file diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..c835dbd695 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -1017,3 +1017,14 @@ def specify_broadcastable(x, *axes): axes = normalize_axis_tuple(axes, x.type.ndim) shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)] return specify_shape(x, shape_info) + + +Class JoinDims(COp): + + def __init__ + + def c_code_cache_version(self): + return (10,) + + def c_code(self, node, name, inputs, outputs, sub): + From e99798674028ea0183c19a541735f11d1761c96f Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Tue, 21 Apr 2026 11:36:37 -0400 Subject: [PATCH 3/4] revised progress on joindim splitdim backend implementation --- pytensor/link/jax/dispatch/shape.py | 28 ++++++++++++++++++++------- pytensor/link/numba/dispatch/shape.py | 27 +++++++++++++++++++------- pytensor/tensor/reshape.py | 8 +++++++- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index 331bec8060..6c4656fbc4 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -4,7 +4,11 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape # JoinDims, SplitDims +from pytensor.tensor.reshape import ( # is the placement of joindims/splitdims under reshape instead of shape intentional? + JoinDims, + SplitDims, +) +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.type import TensorType @@ -110,10 +114,20 @@ def specifyshape(x, *shape): def jax_funcify_JoinDims(op, node, **kwargs): start_axis = op.start_axis n_axes = op.n_axes + def join_dims(x): - if n_axes == 0: - expanded_x = jnp.expand_dims(x, axis=start_axis) - return expanded_x - if n_axes == 1: - return x - return join_dims \ No newline at end of file + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return jnp.reshape(x, output_shape) + + return join_dims + + +@jax_funcify.register(SplitDims) +def jax_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return jnp.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index e47cfd9cbe..d6a1753a19 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -7,7 +7,8 @@ from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key from pytensor.link.utils import compile_function_src from pytensor.tensor import NoneConst -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape # JoinDims, SplitDims +from pytensor.tensor.reshape import JoinDims, SplitDims +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @register_funcify_default_op_cache_key(Shape) @@ -78,15 +79,27 @@ def reshape(x, shape): return reshape + @register_funcify_default_op_cache_key(JoinDims) def numba_funcify_JoinDims(op, node, **kwargs): start_axis = op.start_axis n_axes = op.n_axes + @numba_basic.numba_njit def join_dims(x): - if n_axes == 0: - expanded_x = np.expand_dims(x, axis=start_axis) - return expanded_x - if n_axes == 1: - return x - return join_dims \ No newline at end of file + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return np.reshape(x, output_shape) + + return join_dims + + +@register_funcify_default_op_cache_key(SplitDims) +def numba_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + @numba_basic.numba_njit + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return np.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index b25308fccc..57adfe1e99 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -152,7 +152,7 @@ def join_dims( return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value] -class SplitDims(Op): +class SplitDims(COp): __props__ = ("axis",) view_map = {0: [0]} @@ -217,6 +217,12 @@ def pullback(self, inputs, outputs, output_grads): disconnected_type(), ] + def pushforward() # is this needed? + + def c_code_cache_version(self): + return (10,) + + def c_code(): @_vectorize_node.register(SplitDims) def _vectorize_splitdims(op, node, x, shape): From 70714866b187543db50ba70c817a179d63a8b593 Mon Sep 17 00:00:00 2001 From: mbaldourw Date: Wed, 22 Apr 2026 09:36:40 -0400 Subject: [PATCH 4/4] creating mirrored structure --- pytensor/link/jax/dispatch/__init__.py | 1 + pytensor/link/jax/dispatch/reshape.py | 30 ++++++++++++++++++++++++ pytensor/link/jax/dispatch/shape.py | 27 --------------------- pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/reshape.py | 30 ++++++++++++++++++++++++ pytensor/link/numba/dispatch/shape.py | 26 -------------------- tests/link/jax/test_reshape.py | 0 tests/link/numba/test_reshape.py | 0 8 files changed, 62 insertions(+), 53 deletions(-) create mode 100644 pytensor/link/jax/dispatch/reshape.py create mode 100644 pytensor/link/numba/dispatch/reshape.py create mode 100644 tests/link/jax/test_reshape.py create mode 100644 tests/link/numba/test_reshape.py diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index cfa9fcd848..67da601690 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -11,6 +11,7 @@ import pytensor.link.jax.dispatch.math import pytensor.link.jax.dispatch.linalg import pytensor.link.jax.dispatch.random +import pytensor.link.jax.dispatch.reshape import pytensor.link.jax.dispatch.scalar import pytensor.link.jax.dispatch.scan import pytensor.link.jax.dispatch.shape diff --git a/pytensor/link/jax/dispatch/reshape.py b/pytensor/link/jax/dispatch/reshape.py new file mode 100644 index 0000000000..740e2f9b63 --- /dev/null +++ b/pytensor/link/jax/dispatch/reshape.py @@ -0,0 +1,30 @@ +import jax.numpy as jnp + +from pytensor.link.jax.dispatch.basic import jax_funcify +from pytensor.tensor.reshape import ( + JoinDims, + SplitDims, +) + + +@jax_funcify.register(JoinDims) +def jax_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + + def join_dims(x): + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return jnp.reshape(x, output_shape) + + return join_dims + + +@jax_funcify.register(SplitDims) +def jax_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return jnp.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index 6c4656fbc4..d7c1d0bcbd 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -4,10 +4,6 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.reshape import ( # is the placement of joindims/splitdims under reshape instead of shape intentional? - JoinDims, - SplitDims, -) from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.type import TensorType @@ -108,26 +104,3 @@ def specifyshape(x, *shape): return x return specifyshape - - -@jax_funcify.register(JoinDims) -def jax_funcify_JoinDims(op, node, **kwargs): - start_axis = op.start_axis - n_axes = op.n_axes - - def join_dims(x): - output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) - return jnp.reshape(x, output_shape) - - return join_dims - - -@jax_funcify.register(SplitDims) -def jax_funcify_SplitDims(op, node, **kwargs): - axis = op.axis - - def split_dims(x, shape): - output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) - return jnp.reshape(x, output_shape) - - return split_dims diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index c86c65aa7d..6bc31496f5 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -8,6 +8,7 @@ import pytensor.link.numba.dispatch.extra_ops import pytensor.link.numba.dispatch.linalg import pytensor.link.numba.dispatch.random +import pytensor.link.numba.dispatch.reshape import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar import pytensor.link.numba.dispatch.shape diff --git a/pytensor/link/numba/dispatch/reshape.py b/pytensor/link/numba/dispatch/reshape.py new file mode 100644 index 0000000000..2c43c9cbe5 --- /dev/null +++ b/pytensor/link/numba/dispatch/reshape.py @@ -0,0 +1,30 @@ +import numpy as np + +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key +from pytensor.tensor.reshape import JoinDims, SplitDims + + +@register_funcify_default_op_cache_key(JoinDims) +def numba_funcify_JoinDims(op, node, **kwargs): + start_axis = op.start_axis + n_axes = op.n_axes + + @numba_basic.numba_njit + def join_dims(x): + output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) + return np.reshape(x, output_shape) + + return join_dims + + +@register_funcify_default_op_cache_key(SplitDims) +def numba_funcify_SplitDims(op, node, **kwargs): + axis = op.axis + + @numba_basic.numba_njit + def split_dims(x, shape): + output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) + return np.reshape(x, output_shape) + + return split_dims diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index d6a1753a19..4cbbe24754 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -7,7 +7,6 @@ from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key from pytensor.link.utils import compile_function_src from pytensor.tensor import NoneConst -from pytensor.tensor.reshape import JoinDims, SplitDims from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -78,28 +77,3 @@ def reshape(x, shape): ) return reshape - - -@register_funcify_default_op_cache_key(JoinDims) -def numba_funcify_JoinDims(op, node, **kwargs): - start_axis = op.start_axis - n_axes = op.n_axes - - @numba_basic.numba_njit - def join_dims(x): - output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]) - return np.reshape(x, output_shape) - - return join_dims - - -@register_funcify_default_op_cache_key(SplitDims) -def numba_funcify_SplitDims(op, node, **kwargs): - axis = op.axis - - @numba_basic.numba_njit - def split_dims(x, shape): - output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :]) - return np.reshape(x, output_shape) - - return split_dims diff --git a/tests/link/jax/test_reshape.py b/tests/link/jax/test_reshape.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/numba/test_reshape.py b/tests/link/numba/test_reshape.py new file mode 100644 index 0000000000..e69de29bb2