Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytensor.link.numba.dispatch.compile_ops
import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.extra_ops
import pytensor.link.numba.dispatch.join_inplace
import pytensor.link.numba.dispatch.nlinalg
import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.scan
Expand Down
49 changes: 49 additions & 0 deletions pytensor/link/numba/dispatch/join_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np

from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.tensor.rewriting.join_inplace import WriteJoin, WriteSplit


@register_funcify_default_op_cache_key(WriteSplit)
def numba_funcify_WriteSplit(op, node, **kwargs):
    """Generate a Numba implementation of ``WriteSplit``.

    Builds source code for a function that slices ``buffer`` into
    ``op.n_splits`` contiguous views along ``op.axis``; the extent of the
    i-th view along that axis is given by the 0-d size input ``s{i}``
    (hence the ``.item()`` call in the generated body).

    Returns the njit-compiled generated function.
    """
    n_splits = op.n_splits
    axis = op.axis
    ndim = node.inputs[0].type.ndim

    # Keep a running ``offset`` in the generated body instead of re-summing
    # all previous sizes in every slice expression: the generated source
    # stays O(n_splits) instead of O(n_splits**2).
    body_lines = ["    offset = 0"]
    for i in range(n_splits):
        idx = ", ".join(
            f"offset:offset + sz_{i}" if d == axis else ":" for d in range(ndim)
        )
        body_lines.append(f"    sz_{i} = s{i}.item()")
        body_lines.append(f"    out_{i} = buffer[{idx}]")
        body_lines.append(f"    offset += sz_{i}")

    size_params = "".join(f", s{i}" for i in range(n_splits))
    # One trailing comma per element: a single output still forms a tuple
    # (``(out_0,)``) and zero outputs yield ``()`` rather than the invalid
    # ``(,)`` the previous formatting produced.
    return_items = "".join(f"out_{i}, " for i in range(n_splits))

    func_src = (
        f"def write_split(buffer{size_params}):\n"
        + "\n".join(body_lines)
        + f"\n    return ({return_items})\n"
    )
    fn = compile_numba_function_src(func_src, "write_split", {"np": np})
    return numba_basic.numba_njit(fn)


@register_funcify_default_op_cache_key(WriteJoin)
def numba_funcify_WriteJoin(op, node, **kwargs):
    """Generate a Numba implementation of ``WriteJoin``.

    Every input after the first is only an ordering dependency: the
    generated function accepts them all but simply returns ``buffer``
    unchanged.
    """
    n_dependencies = len(node.inputs) - 1

    # Build the parameter list with ``buffer`` first, then one ``dep{i}``
    # per ordering dependency (possibly none).
    params = ["buffer", *(f"dep{i}" for i in range(n_dependencies))]
    signature = ", ".join(params)

    src = "\ndef write_join(" + signature + "):\n    return buffer\n"
    write_join = compile_numba_function_src(src, "write_join")
    return numba_basic.numba_njit(write_join)
7 changes: 6 additions & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,12 @@ def _get_underlying_scalar_constant_value(
for i in v.owner.inputs
]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
try:
v.owner.op.perform(v.owner, const, ret)
except Exception:
# Elemwise.perform may not work in Python mode
# (e.g. fused Composites with >32 operands)
raise NotScalarConstantError(v)
return np.asarray(ret[0][0].copy())
elif isinstance(op, Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant):
Expand Down
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.join_inplace
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
Expand Down
21 changes: 19 additions & 2 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,18 +956,35 @@ def local_useless_composite_outputs(fgraph, node):
comp_fgraph = FunctionGraph(
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
)
# Inputs that are inplace targets must be kept even if unused in the scalar graph
destroyed_input_idxs = set()
for in_idxs in node.op.inplace_pattern.values():
if isinstance(in_idxs, int):
destroyed_input_idxs.add(in_idxs)
else:
destroyed_input_idxs.update(in_idxs)

used_inputs_idxs = [
i
for i, i_intern in enumerate(comp_fgraph.inputs)
if comp_fgraph.clients[i_intern]
if comp_fgraph.clients[i_intern] or i in destroyed_input_idxs
]
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
node.outputs
):
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
# Remap inplace_pattern indices to the new input positions
old_to_new = {old: new for new, old in enumerate(used_inputs_idxs)}
new_inplace_pattern = {
out_idx: old_to_new[in_idx]
for out_idx, in_idx in node.op.inplace_pattern.items()
if in_idx in old_to_new
}
c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
e = Elemwise(scalar_op=c, inplace_pattern=new_inplace_pattern)(
*used_inputs, return_list=True
)
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))


Expand Down
Loading
Loading