Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytensor.link.jax.dispatch.math
import pytensor.link.jax.dispatch.linalg
import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.reshape
import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.shape
Expand Down
30 changes: 30 additions & 0 deletions pytensor/link/jax/dispatch/reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.reshape import (
JoinDims,
SplitDims,
)


@jax_funcify.register(JoinDims)
def jax_funcify_JoinDims(op, node, **kwargs):
start_axis = op.start_axis
n_axes = op.n_axes

def join_dims(x):
output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :])
return jnp.reshape(x, output_shape)

return join_dims


@jax_funcify.register(SplitDims)
def jax_funcify_SplitDims(op, node, **kwargs):
axis = op.axis

def split_dims(x, shape):
output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :])
return jnp.reshape(x, output_shape)

return split_dims
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytensor.link.numba.dispatch.extra_ops
import pytensor.link.numba.dispatch.linalg
import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.reshape
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.shape
Expand Down
30 changes: 30 additions & 0 deletions pytensor/link/numba/dispatch/reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.tensor.reshape import JoinDims, SplitDims


@register_funcify_default_op_cache_key(JoinDims)
def numba_funcify_JoinDims(op, node, **kwargs):
start_axis = op.start_axis
n_axes = op.n_axes

@numba_basic.numba_njit
def join_dims(x):
output_shape = (*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :])
return np.reshape(x, output_shape)

return join_dims


@register_funcify_default_op_cache_key(SplitDims)
def numba_funcify_SplitDims(op, node, **kwargs):
axis = op.axis

@numba_basic.numba_njit
def split_dims(x, shape):
output_shape = (*x.shape[:axis], *shape, *x.shape[axis + 1 :])
return np.reshape(x, output_shape)

return split_dims
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def reshape(x, shape):
@numba_basic.numba_njit
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
# the above issue is closed, what to do instead?
return np.reshape(
np.ascontiguousarray(np.asarray(x)),
numba_ndarray.to_fixed_tuple(shape, ndim),
Expand Down
5 changes: 4 additions & 1 deletion pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,6 +2762,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
None means all axes (like numpy).
"""
from pytensor.ifelse import ifelse
from pytensor.tensor.reshape import join_dims

x = as_tensor_variable(x)
x_ndim = x.type.ndim
Expand All @@ -2776,7 +2777,9 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
# Put axis at the end and unravel them
x_raveled = x.transpose(*non_axis, *axis)
if len(axis) > 1:
x_raveled = x_raveled.reshape((*non_axis_shape, -1))
x_raveled = join_dims(
x_raveled, start_axis=len(non_axis_shape), n_axes=len(axis)
)
raveled_size = x_raveled.shape[-1]
k = raveled_size // 2

Expand Down
8 changes: 7 additions & 1 deletion pytensor/tensor/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def join_dims(
return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value]


class SplitDims(Op):
class SplitDims(COp):
__props__ = ("axis",)
view_map = {0: [0]}

Expand Down Expand Up @@ -217,6 +217,12 @@ def pullback(self, inputs, outputs, output_grads):
disconnected_type(),
]

def pushforward() # is this needed?

def c_code_cache_version(self):
return (10,)

def c_code():

@_vectorize_node.register(SplitDims)
def _vectorize_splitdims(op, node, x, shape):
Expand Down
11 changes: 11 additions & 0 deletions pytensor/tensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,14 @@ def specify_broadcastable(x, *axes):
axes = normalize_axis_tuple(axes, x.type.ndim)
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
return specify_shape(x, shape_info)


Class JoinDims(COp):

def __init__

def c_code_cache_version(self):
return (10,)

def c_code(self, node, name, inputs, outputs, sub):

Empty file added tests/link/jax/test_reshape.py
Empty file.
Empty file.
Loading