From aad541bf47c42b5fd066a1ff46ec69758c8fd57b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 15:50:23 +0200 Subject: [PATCH 1/4] Strip Elemwise inplace from fused indexed-write outputs The fused loop writes the elementwise result to the write buffer, never to the inplaced input, but the inner fgraph kept the inplace Elemwise: the Python-mode fallback (OpFromGraph.perform) would destroy that input without the outer destroy map declaring it, losing the ordering constraint for other readers of the destroyed buffer. The JIT path was unaffected (write buffers shadow the inplace pattern in make_outputs). Write-and-direct duplication now runs before the strip and preserves the inplace pattern, so an inplace on an output that stays materialized (the write consuming a duplicate) still survives the fusion. --- pytensor/tensor/rewriting/indexed_elemwise.py | 79 ++++++++++++------- tests/link/numba/test_indexed_elemwise.py | 66 ++++++++++++++++ 2 files changed, 118 insertions(+), 27 deletions(-) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..97ef064fcc 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -399,7 +399,11 @@ def _duplicate_multi_client_outputs(node, multi_client_outs): s_outputs.append(s_outputs[out_idx]) new_scalar_op = Composite(s_inputs, s_outputs) - new_node = Elemwise(new_scalar_op).make_node(*node.inputs) + # Duplicates are appended after the original outputs, so the node's + # inplace pattern carries over unchanged (duplicates get no entries). + new_node = Elemwise(new_scalar_op, dict(node.op.inplace_pattern)).make_node( + *node.inputs + ) return new_node, dup_map @staticmethod @@ -589,35 +593,13 @@ def apply(self, fgraph): worklist.append(node) continue - indexed_reads = {i for reads, _ in idx_groups.values() for i in reads} - - # If any inplace targets an indexed-read input, - # strip and re-run inplace with those inputs protected - if any( - inp_idx in indexed_reads for inp_idx in node.op.inplace_pattern.values() - ): - stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) - fgraph.replace_all( - zip(node.outputs, stripped_node.outputs), - reason="fuse_indexed_elemwise_strip_inplace", - ) - protected = frozenset(stripped_node.inputs[i] for i in indexed_reads) - # try_inplace_on_node does its own fgraph.replace_all internally, - # so the returned node is already in the fgraph - new_inplace_node = InplaceElemwiseOptimizer().try_inplace_on_node( - fgraph, - stripped_node, - reason="fuse_indexed_elemwise_inplace_read_buffers", - extra_protected_inputs=protected, - ) - worklist.append(new_inplace_node) - continue - # If any indexed-write output also has other consumers, # duplicate it via Composite so the write replaces the duplicate # while the original stays available for non-write consumers. # We still avoid one extra write loop, - # even if we can't skip the output materialization altogether + # even if we can't skip the output materialization altogether. + # This runs before the strip-inplace pass below, so an inplace on the + # materialized original survives the fusion. def _has_non_write_clients(out_idx): update = write_targets[out_idx] for c, _ in fgraph.clients[node.outputs[out_idx]]: @@ -658,12 +640,55 @@ def _has_non_write_clients(out_idx): worklist.append(new_node) continue + indexed_reads = {i for reads, _ in idx_groups.values() for i in reads} + + # If any inplace targets an indexed-read input, or claims an indexed-write + # output (the loop writes the result to the write buffer instead, so the + # input destruction would happen only in the Python-mode fallback, + # undeclared by the outer destroy map), strip and re-run inplace with + # those inputs protected and outputs excluded. The duplication above ran + # first, so write-target outputs here are sole-client. + if any( + inp_idx in indexed_reads for inp_idx in node.op.inplace_pattern.values() + ) or any(out_idx in write_targets for out_idx in node.op.inplace_pattern): + stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) + fgraph.replace_all( + zip(node.outputs, stripped_node.outputs), + reason="fuse_indexed_elemwise_strip_inplace", + ) + optimizer = InplaceElemwiseOptimizer() + protected = optimizer._get_protected_inputs(fgraph) + protected.update(stripped_node.inputs[i] for i in indexed_reads) + # Candidates are plain-Elemwise (output, input) pairs; exclude + # outputs the fusion is about to consume as indexed writes + candidate_pairs = [ + pair + for pair in optimizer.filter_candidate_pairs( + fgraph, stripped_node, protected + ) + if pair[0][0] not in write_targets + ] + # try_inplace_on_node does its own fgraph.replace_all internally, + # so the returned node is already in the fgraph + new_inplace_node = optimizer.try_inplace_on_node( + fgraph, + stripped_node, + candidate_pairs=candidate_pairs, + reason="fuse_indexed_elemwise_inplace_read_buffers", + ) + worklist.append(new_inplace_node) + continue + idx_vars = [idx for idx, _axis in idx_groups] + # The strip-inplace pass above guarantees that indexed-write outputs + # carry no inplace + assert not any( + out_idx in write_targets for out_idx in node.op.inplace_pattern + ) fgraph_destroy_map = { out_idx: [inp_idx] for out_idx, inp_idx in node.op.inplace_pattern.items() - if out_idx not in write_targets } # Fgraph inputs: substitute indexed sources back to their diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 22ffbe5927..e7971e68bb 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -5,6 +5,7 @@ import pytensor.tensor as pt from pytensor import Mode, function, get_mode +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -360,6 +361,71 @@ def test_inc_subtensor(self): fn(xv, yv, tv.copy()), fn_u(xv, yv, tv.copy()), rtol=1e-10 ) + def test_write_of_inplace_elemwise(self): + """Inplace must not survive on a write-target output. + + The loop writes the elementwise result to the write buffer, never to the + inplaced input, so a surviving inner inplace entry would destroy the dot + intermediate undeclared in the Python-mode fallback. + """ + rng = np.random.default_rng(9) + idx = rng.integers(8, size=30).astype(np.int64) + x = pt.matrix("x", shape=(30, 4)) + y = pt.matrix("y", shape=(4, 4)) + z = pt.matrix("z", shape=(30, 4)) + t = pt.matrix("t", shape=(8, 4)) + out = t[idx].inc((x @ y) * z) + fn, fn_u = fused_and_unfused([x, y, z, t], out) + assert_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + ] + assert not any( + n.op.inplace_pattern + for n in node.op.fgraph.toposort() + if isinstance(n.op, Elemwise) + ) + xv, yv, zv, tv = ( + rng.normal(size=(30, 4)), + rng.normal(size=(4, 4)), + rng.normal(size=(30, 4)), + rng.normal(size=(8, 4)), + ) + np.testing.assert_allclose( + fn(xv, yv, zv, tv.copy()), fn_u(xv, yv, zv, tv.copy()), rtol=1e-10 + ) + + def test_write_with_direct_use_keeps_inplace(self): + """Inplace survives on an output that is both written and used directly. + + The write-and-direct duplication keeps the original output materialized + (the write consumes a duplicate), so the inplace claimed on the dot + intermediate stays valid and must not be stripped. + """ + rng = np.random.default_rng(10) + idx = rng.integers(8, size=30).astype(np.int64) + x = pt.matrix("x", shape=(30, 4)) + y = pt.matrix("y", shape=(4, 4)) + z = pt.matrix("z", shape=(30, 4)) + t = pt.matrix("t", shape=(8, 4)) + w = (x @ y) * z + fn, fn_u = fused_and_unfused([x, y, z, t], [t[idx].inc(w), w]) + assert_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + ] + # Two destroy entries: the write buffer, and the dot intermediate kept + # inplace by the materialized output + assert len(node.op.destroy_map) == 2 + xv, yv, zv, tv = ( + rng.normal(size=(30, 4)), + rng.normal(size=(4, 4)), + rng.normal(size=(30, 4)), + rng.normal(size=(8, 4)), + ) + for res, res_u in zip(fn(xv, yv, zv, tv.copy()), fn_u(xv, yv, zv, tv.copy())): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + def test_set_subtensor(self): rng = np.random.default_rng(42) idx = rng.integers(85, size=919).astype(np.int64) From 3faa3ba549dbf78c813625117711f0be64911855 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 15:50:57 +0200 Subject: [PATCH 2/4] Extend IndexedElemwise to also fuse reductions --- pytensor/link/numba/dispatch/blockwise.py | 2 + pytensor/link/numba/dispatch/elemwise.py | 257 ++++++++++++++-- pytensor/link/numba/dispatch/random.py | 2 + .../link/numba/dispatch/vectorize_codegen.py | 102 ++++--- ...{indexed_elemwise.py => fused_elemwise.py} | 246 +++++++++++++--- pytensor/tensor/rewriting/numba.py | 2 +- tests/benchmarks/test_gather_fusion.py | 18 +- ...xed_elemwise.py => test_fused_elemwise.py} | 277 +++++++++++++++++- tests/tensor/rewriting/test_elemwise.py | 6 +- tests/tensor/rewriting/test_math.py | 20 +- tests/tensor/rewriting/test_subtensor.py | 7 +- tests/tensor/rewriting/test_uncanonicalize.py | 8 +- 12 files changed, 807 insertions(+), 140 deletions(-) rename pytensor/tensor/rewriting/{indexed_elemwise.py => fused_elemwise.py} (77%) rename tests/link/numba/{test_indexed_elemwise.py => test_fused_elemwise.py} (72%) diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index af4aa4aee3..49fc4b43f6 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -13,6 +13,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -96,6 +97,7 @@ def impl(*inputs_and_core_shapes): NO_SIZE, NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return impl diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1383ee1df1..06a3a7f419 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -29,6 +29,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -50,7 +51,7 @@ from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise @singledispatch @@ -140,6 +141,38 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr): ] +# Scalar reduce ops that ``accumulate_into_slice`` (and thus fused reductions / +# indexed inc) supports, paired with the in-place form used on the output slice. +# Augmented assignment is used where Numba supports it directly on 0-d arrays +# (``out`` is the core output slice, which is 0-d for scalar cores); Maximum and +# Minimum lack an augmented operator so they use a ufunc + ellipsis-set instead. +_SLICE_ACCUMULATE = { + Add: "{out} += {inner}", + Mul: "{out} *= {inner}", + AND: "{out} &= {inner}", + OR: "{out} |= {inner}", + XOR: "{out} ^= {inner}", + Maximum: "{out}[...] = np.maximum({out}, {inner})", + Minimum: "{out}[...] = np.minimum({out}, {inner})", +} + + +def accumulate_into_slice(scalar_op, out: str, inner: str) -> list[str]: + """In-place accumulation lines for a (possibly 0-d) output slice ``out``. + + Unlike ``scalar_in_place_fn`` (which indexes ``res[idx]`` and breaks on 0-d + cores), this operates on the whole slice so it is valid for both scalar + (0-d) and array cores. Used by both fused reductions and indexed ``inc``. + """ + try: + template = _SLICE_ACCUMULATE[type(scalar_op)] + except KeyError: + raise NotImplementedError( + f"No fused-reduction accumulation for scalar op {scalar_op}" + ) + return [template.format(out=out, inner=inner)] + + @intrinsic def _address_as_void_pointer(typingctx, src): """Returns a void pointer from a given memory address.""" @@ -695,6 +728,7 @@ def impl(*inputs): NO_SIZE, NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return impl @@ -715,8 +749,88 @@ def impl(*inputs): return elemwise, elemwise_key -@register_funcify_and_cache_key(IndexedElemwise) -def numba_funcify_IndexedElemwise(op, node, **kwargs): +def _reduce_identity(identity, acc_dtype): + """Coerce a reduction identity to a concrete numpy scalar in ``acc_dtype``. + + Non-finite identities (``±inf`` for max/min) become the dtype's bounds for + integer accumulators, mirroring ``create_multiaxis_reducer``. + """ + acc_dtype = np.dtype(acc_dtype) + if acc_dtype.kind in "ui" and not np.isfinite(identity): + identity = ( + np.iinfo(acc_dtype).max + if np.isposinf(identity) + else np.iinfo(acc_dtype).min + ) + return acc_dtype.type(identity) + + +def _build_reduce_impl_src(nout, post_specs): + """Build the ``@overload`` impl source for a fused op with reduction outputs. + + Calls ``_vectorized`` (which produces a keepdims, size-1-on-reduced-axes + buffer in ``acc_dtype``) then, per reduction output, squeezes the reduced + axes back out and casts to the true output dtype. ``post_specs[i]`` is + ``None`` for a passthrough output, else ``(kept_axes, out_dtype, cast_needed)``. + """ + code: list[str | CODE_TOKEN] = [ + "def fused_elemwise_impl(*outer_inputs):", + CODE_TOKEN.INDENT, + "raw = _vectorized(", + CODE_TOKEN.INDENT, + "core_op_fn,", + "input_bc_patterns_enc,", + "output_bc_patterns_enc,", + "output_dtypes_enc,", + "inplace_pattern_enc,", + "True,", + "(),", + "outer_inputs,", + "core_output_shapes,", + "NO_SIZE,", + "indexed_inputs_enc,", + "indexed_outputs_enc,", + "reduce_outputs_enc,", + CODE_TOKEN.DEDENT, + ")", + ] + + def src_access(i): + return "raw" if nout == 1 else f"raw[{i}]" + + out_syms = [] + for i in range(nout): + sym = f"o{i}" + out_syms.append(sym) + src = src_access(i) + spec = post_specs[i] + if spec is None: + code.append(f"{sym} = {src}") + continue + kept_axes, out_dtype, cast_needed = spec + np_dtype = "bool_" if out_dtype == "bool" else out_dtype + if not kept_axes: + # Full reduction → 0-d array of the single accumulated value. + code.append(f"{sym} = np.array({src}.ravel()[0], dtype=np.{np_dtype})") + else: + shape_expr = create_tuple_string( + tuple(f"{src}.shape[{k}]" for k in kept_axes) + ) + expr = f"{src}.reshape({shape_expr})" + if cast_needed: + expr = f"{expr}.astype(np.{np_dtype})" + code.append(f"{sym} = {expr}") + + if nout == 1: + code.append(f"return {out_syms[0]}") + else: + code.append(f"return {create_tuple_string(tuple(out_syms))}") + code.append(CODE_TOKEN.DEDENT) + return build_source_code(code) + + +@register_funcify_and_cache_key(FusedElemwise) +def numba_funcify_FusedElemwise(op, node, **kwargs): """Generate fused Elemwise Numba code with indexed reads and updates. Reads indexed_inputs/indexed_outputs specs stored on the Op by the @@ -739,6 +853,7 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): n_indices = len(indexed_inputs) nin_elemwise = len(elemwise_node.inputs) nout = len(elemwise_node.outputs) + reduced_outputs = op.reduced_outputs or ((None,) * nout) inc_outputs = frozenset( out_idx @@ -747,14 +862,58 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): for out_idx in entry[0] if entry[2] == "inc" ) + reduced_idxs = frozenset(i for i, r in enumerate(reduced_outputs) if r is not None) + assert not (inc_outputs & reduced_idxs), ( + "An output cannot be both an indexed write and a reduction" + ) + + # accum_fns bake per-output in-place accumulation into store_core_outputs: + # indexed `inc` writes accumulate with +=; reductions use their scalar op + # (Add/Mul/Maximum/...). Both go through accumulate_into_slice so they are + # valid on 0-d (scalar core) slices. Disjoint output indices. + accum_fns: dict = {} + for out_idx in inc_outputs: + accum_fns[out_idx] = lambda out, inner: [f"{out} += {inner}"] + for i in reduced_idxs: + accum_fns[i] = lambda out, inner, _op=reduced_outputs[i][0]: ( + accumulate_into_slice(_op, out, inner) + ) core_op_fn = store_core_outputs( - scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + scalar_op_fn, nin=nin_elemwise, nout=nout, accum_fns=accum_fns ) input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) - output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) - output_dtypes = tuple(out.type.dtype for out in node.outputs) + + # Reduction outputs keep their reduced axes as bc=True (size-1, keepdims) and + # are accumulated in acc_dtype; post_specs squeezes + casts them back to the + # true CAReduce output below. + output_bc_patterns_list = [] + output_dtypes_list = [] + reduce_identities = [] + post_specs: list = [] + has_reductions = bool(reduced_idxs) + for i, inner_out in enumerate(elemwise_node.outputs): + spec = reduced_outputs[i] + if spec is None: + output_bc_patterns_list.append(inner_out.type.broadcastable) + output_dtypes_list.append(node.outputs[i].type.dtype) + post_specs.append(None) + continue + _reduce_op, axes, identity, acc_dtype = spec + bc = list(inner_out.type.broadcastable) + for ax in axes: + bc[ax] = True + output_bc_patterns_list.append(tuple(bc)) + output_dtypes_list.append(str(np.dtype(acc_dtype))) + reduce_identities.append((i, _reduce_identity(identity, acc_dtype))) + kept_axes = tuple(d for d in range(len(bc)) if d not in axes) + out_dtype = node.outputs[i].type.dtype + cast_needed = np.dtype(acc_dtype) != np.dtype(out_dtype) + post_specs.append((kept_axes, out_dtype, cast_needed)) + + output_bc_patterns = tuple(output_bc_patterns_list) + output_dtypes = tuple(output_dtypes_list) inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) core_output_shapes = tuple(() for _ in range(nout)) @@ -768,52 +927,88 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): inplace_pattern_enc = encode_literals(inplace_pattern) indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable)) indexed_outputs_enc = encode_literals(indexed_outputs) + reduce_outputs_enc = encode_literals(tuple(reduce_identities)) + + def fused_elemwise_fn(*outer_inputs): + # Python-mode fallback (e.g. Numba's ``eval_python_only`` path, which + # runs the funcified function without JIT). Evaluate the inner fgraph + # faithfully, exactly like ``OpFromGraph.perform`` — so the fused op + # still produces valid results when executed directly in Python. The + # JIT path never reaches this body: the ``@overload`` below supplies the + # vectorized loop implementation. + res = op.fn(*outer_inputs) + return res[0] if len(res) == 1 else tuple(res) + + if not has_reductions: + + @overload(fused_elemwise_fn, jit_options=_jit_options) + def ov_fused_elemwise_fn(*outer_inputs): + def impl(*outer_inputs): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + outer_inputs, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + reduce_outputs_enc, + ) - def indexed_elemwise_fn(*outer_inputs): - raise NotImplementedError( - "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." + return impl + else: + impl_src = _build_reduce_impl_src(nout, post_specs) + impl_fn = compile_numba_function_src( + impl_src, + "fused_elemwise_impl", + { + **globals(), + "core_op_fn": core_op_fn, + "input_bc_patterns_enc": input_bc_patterns_enc, + "output_bc_patterns_enc": output_bc_patterns_enc, + "output_dtypes_enc": output_dtypes_enc, + "inplace_pattern_enc": inplace_pattern_enc, + "core_output_shapes": core_output_shapes, + "indexed_inputs_enc": indexed_inputs_enc, + "indexed_outputs_enc": indexed_outputs_enc, + "reduce_outputs_enc": reduce_outputs_enc, + }, ) - @overload(indexed_elemwise_fn, jit_options=_jit_options) - def ov_indexed_elemwise_fn(*outer_inputs): - def impl(*outer_inputs): - return _vectorized( - core_op_fn, - input_bc_patterns_enc, - output_bc_patterns_enc, - output_dtypes_enc, - inplace_pattern_enc, - True, # allow_core_scalar - (), # constant_inputs - outer_inputs, - core_output_shapes, - NO_SIZE, - indexed_inputs_enc, - indexed_outputs_enc, - ) - - return impl + @overload(fused_elemwise_fn, jit_options=_jit_options) + def ov_fused_elemwise_fn(*outer_inputs): + return impl_fn - cache_version = 2 + cache_version = 3 if scalar_cache_key is None: key = None else: + reduced_key = tuple( + (type(r[0]).__name__, r[1], str(np.dtype(r[3]))) if r is not None else None + for r in reduced_outputs + ) key = str( ( type(op), - "IndexedElemwise", + "FusedElemwise", cache_version, inplace_pattern, input_bc_patterns, indexed_inputs, idx_broadcastable, indexed_outputs, + reduced_key, scalar_cache_key, ) ) key = sha256(key.encode()).hexdigest() - return indexed_elemwise_fn, key + return fused_elemwise_fn, key @register_funcify_and_cache_key(CAReduce) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 28ade0bf3a..2f7f8b3700 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -23,6 +23,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -490,6 +491,7 @@ def impl(core_shape, rng, size, *dist_params): else numba_ndarray.to_fixed_tuple(size, size_len), NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return rng, draws diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index dfd12ed1b3..23fe491186 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -3,7 +3,6 @@ import base64 import pickle from collections.abc import Callable, Sequence -from textwrap import indent from typing import Any import numba @@ -17,6 +16,7 @@ from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.string_codegen import CODE_TOKEN, build_source_code def encode_literals(literals: Sequence) -> str: @@ -24,7 +24,11 @@ def encode_literals(literals: Sequence) -> str: def store_core_outputs( - core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() + core_op_fn: Callable, + nin: int, + nout: int, + accum_fns: dict[int, Callable[[str, str], Sequence[str | CODE_TOKEN]]] + | None = None, ) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @@ -35,14 +39,20 @@ def store_core_outputs( def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) o0[...] = to0 # direct outputs - o1 += to1 # inc outputs (in-place add works for 0d and Nd) + o1[...] += to1 # accumulating outputs (reduce / indexed-inc) ... - ``inc_outputs`` lists output indices that use ``+=`` instead of ``=``. + ``accum_fns`` maps an output index to a callable ``(out_sym, inner_sym) -> + lines`` producing the in-place accumulation code for that output (e.g. + ``["o1[...] += t1"]`` for a sum reduction, or the multi-line conditional for + a max reduction). Outputs absent from ``accum_fns`` are stored with ``=``. + Both reductions and indexed ``inc`` writes go through this mechanism. """ if getattr(core_op_fn, "handles_out", False): return core_op_fn + accum_fns = accum_fns or {} + inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] @@ -50,19 +60,22 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): inp_signature = ", ".join(inputs) out_signature = ", ".join(outputs) inner_out_signature = ", ".join(inner_outputs) - store_outputs = "\n".join( - f"{output} += {inner_output}" - if i in inc_outputs - else f"{output}[...] = {inner_output}" - for i, (output, inner_output) in enumerate( - zip(outputs, inner_outputs, strict=True) - ) - ) - func_src = f""" -def store_core_outputs({inp_signature}, {out_signature}): - {inner_out_signature} = core_op_fn({inp_signature}) -{indent(store_outputs, " " * 4)} -""" + + code: list[str | CODE_TOKEN] = [ + f"def store_core_outputs({inp_signature}, {out_signature}):", + CODE_TOKEN.INDENT, + f"{inner_out_signature} = core_op_fn({inp_signature})", + ] + for i, (output, inner_output) in enumerate( + zip(outputs, inner_outputs, strict=True) + ): + if i in accum_fns: + code.extend(accum_fns[i](output, inner_output)) + else: + code.append(f"{output}[...] = {inner_output}") + code.append(CODE_TOKEN.DEDENT) + + func_src = build_source_code(code) global_env = {"core_op_fn": core_op_fn} func = compile_numba_function_src( @@ -249,6 +262,7 @@ def _codegen_return_outputs( NO_INDEXED_INPUTS = encode_literals(((), ())) NO_INDEXED_OUTPUTS = encode_literals(()) +NO_REDUCE_OUTPUTS = encode_literals(()) NO_SIZE = None @@ -355,11 +369,17 @@ def make_outputs( input_types: tuple[Any, ...], output_core_shapes: tuple, update_outputs: dict | None = None, + reduce_identities: dict | None = None, ) -> tuple[list[ir.Value], list[types.Array]]: """Allocate output arrays for vectorized loop. ``update_outputs`` maps ``{output_idx: (array, array_type)}`` for outputs that reuse an indexed-write target buffer instead of being freshly allocated. + + ``reduce_identities`` maps ``{output_idx: identity_value}`` for reduction + outputs. Such outputs are freshly allocated (size 1 on the reduced axes, via + their ``bc=True`` pattern) and pre-filled with the reduction identity so the + accumulating store in the loop reduces into them correctly. """ output_arrays = [] output_arry_types = [] @@ -389,6 +409,22 @@ def make_outputs( ] shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) + if reduce_identities is not None and i in reduce_identities: + # Pre-fill the freshly allocated (C-contiguous) buffer with the + # reduction identity. A flat scan over every element is valid + # regardless of which axes are reduced, and seeds each kept-axis + # accumulator cell (size-1 reduced axes included). + nitems = ir.IntType(64)(1) + for dim_len in shape: + nitems = builder.mul(nitems, dim_len) + ident = ctx.get_constant(dtype, reduce_identities[i]) + # bool is an i1 value but stored as i8 in arrays; widen to the + # buffer's element type so the store types match. + elem_ty = array.data.type.pointee + if ident.type != elem_ty: + ident = builder.zext(ident, elem_ty) + with cgutils.for_range(builder, nitems) as loop: + builder.store(ident, builder.gep(array.data, [loop.index])) output_arrays.append(array) # If there is no inplace operation, we know that all output arrays @@ -423,8 +459,6 @@ def make_loop_call( ): safe = (False, False) - n_outputs = len(outputs) - # TODO I think this is better than the noalias attribute # for the input, but self_ref isn't supported in a released # llvmlite version yet @@ -449,21 +483,13 @@ def _wrap_negative_index(idx_val, dim_size, signed): wrapped = builder.add(idx_val, dim_size) return builder.select(is_neg, wrapped, idx_val) - # Setup loops and initialize accumulators for outputs - # This part corresponds to opening the loops + # Open one loop per iteration dimension. Reduction outputs need no special + # setup here: they carry ``bc=True`` on the reduced axes, so the write_idx + # logic below points every iteration over a reduced axis at memory index 0 + # (the same cell), and the accumulating store reduces into it. loop_stack = [] loops = [] - output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs - for dim, length in enumerate(iter_shape): - # Find outputs that only have accumulations left - for out in range(n_outputs): - if output_accumulator[out][0] is not None: - continue - if all(output_bc[out][dim:]): - value = outputs[out][0].type.pointee(0) - accu = cgutils.alloca_once_value(builder, value) - output_accumulator[out] = (accu, dim) - + for length in iter_shape: loop = cgutils.for_range(builder, length) loop_stack.append(loop) loops.append(loop.__enter__()) @@ -736,6 +762,7 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ): """Vectorized intrinsic with optional indirect indexing for reads and writes. @@ -751,7 +778,12 @@ def _vectorized( ``((out_0, out_1), mode)`` means that index updates outputs out_0 and out_1 with *mode* ``"set"`` or ``"inc"``. - For non-indexed calls, both are ``()``. + ``reduce_outputs`` lists ``(output_idx, identity)`` pairs for reduction + outputs. Such an output carries ``bc=True`` on its reduced axes; the buffer + is allocated size 1 there, pre-filled with ``identity``, and the per-iteration + store (baked into ``core_func`` via ``store_core_outputs``) accumulates into it. + + For non-indexed/non-reducing calls, these are ``()``. """ arg_types = [ core_func, @@ -766,12 +798,14 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ] input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns") output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns") output_dtypes = _decode_literal(output_dtypes, "output_dtypes") inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern") + reduce_identities = dict(_decode_literal(reduce_outputs, "reduce_outputs")) indexed_inputs, idx_broadcastable = _decode_literal( indexed_inputs, "indexed_inputs" ) @@ -919,6 +953,7 @@ def codegen(ctx, builder, sig, args): size, _, _, + _, ] = args constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) @@ -1053,6 +1088,7 @@ def codegen(ctx, builder, sig, args): source_input_types, output_core_shapes, update_outputs=update_outputs_dict, + reduce_identities=reduce_identities, ) core_signature = typingctx.resolve_function_type( diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/fused_elemwise.py similarity index 77% rename from pytensor/tensor/rewriting/indexed_elemwise.py rename to pytensor/tensor/rewriting/fused_elemwise.py index 97ef064fcc..f97f3e8c28 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/fused_elemwise.py @@ -1,9 +1,9 @@ -"""Fuse indexed reads and updates into Elemwise iteration loops. +"""Fuse indexed reads/writes and reductions into Elemwise iteration loops. -Introduces ``IndexedElemwise``, an ``OpFromGraph`` that wraps -``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` subgraphs -so the Numba backend can generate a single loop with indirect indexing, -eliminating materialised intermediate arrays. +Introduces ``FusedElemwise``, an ``OpFromGraph`` that wraps +``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` / ``CAReduce`` +subgraphs so the Numba backend can generate a single loop with indirect +indexing and inline accumulation, eliminating materialised intermediate arrays. """ from pytensor.compile import optdb @@ -14,8 +14,8 @@ from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.utils import InconsistencyError from pytensor.printing import op_debug_information -from pytensor.scalar.basic import Composite -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.scalar.basic import AND, OR, XOR, Add, Composite, Maximum, Minimum, Mul +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( @@ -27,6 +27,12 @@ from pytensor.tensor.variable import TensorVariable +# CAReduce scalar ops whose reduction the Numba backend can fuse into the loop +# (those for which the codegen has an in-place accumulation: see +# ``accumulate_into_slice`` in the numba elemwise dispatch). +_REDUCE_SCALAR_OPS = (Add, Mul, Maximum, Minimum, AND, OR, XOR) + + def _view_root(view_i, var): """Follow the destroy-handler view chain to the underlying buffer. @@ -102,7 +108,7 @@ def undo_take_dimshuffle_for_fusion(fgraph, node): The ``local_replace_AdvancedSubtensor`` specialize rewrite converts ``x[:, idx]`` into ``x.T[idx].T`` (axis-swap + AdvancedSubtensor1 + axis-swap). This rewrite undoes that when the result feeds a single - Elemwise, so ``FuseIndexedElemwise`` can absorb the indexing directly + Elemwise, so ``FuseElemwise`` can absorb the indexing directly on the correct axis. See also ``undo_take_reshape_for_fusion`` which handles the analogous @@ -135,7 +141,7 @@ def undo_take_reshape_for_fusion(fgraph, node): ``transform_take`` rewrites ``x[mat_idx]`` (ND integer index) into ``AdvancedSubtensor1(x, mat_idx.ravel()).reshape(mat_idx.shape + ...)``, possibly with DimShuffle axis-swaps for non-zero axes. This rewrite - undoes that so ``FuseIndexedElemwise`` can absorb the ND index directly. + undoes that so ``FuseElemwise`` can absorb the ND index directly. """ [reshape_out] = node.outputs @@ -169,10 +175,12 @@ def undo_take_reshape_for_fusion(fgraph, node): return [new_out] -indexed_elemwise_optdb = SequenceDB() +fused_elemwise_optdb = SequenceDB() optdb.register( + # Predates the FusedElemwise rename; kept so existing .including/.excluding + # calls targeting this name keep working. "fuse_indexed_into_elemwise", - indexed_elemwise_optdb, + fused_elemwise_optdb, "numba", # symbolic_op_recognition is excluded from OpFromGraph inner-graph # compilation, preventing recursive fusion. @@ -182,14 +190,14 @@ def undo_take_reshape_for_fusion(fgraph, node): position=100, ) -indexed_elemwise_optdb.register( +fused_elemwise_optdb.register( "undo_take_dimshuffle_for_fusion", dfs_rewriter(undo_take_dimshuffle_for_fusion), "numba", position=0, ) -indexed_elemwise_optdb.register( +fused_elemwise_optdb.register( "undo_take_reshape_for_fusion", dfs_rewriter(undo_take_reshape_for_fusion), "numba", @@ -197,7 +205,7 @@ def undo_take_reshape_for_fusion(fgraph, node): ) -class IndexedElemwise(OpFromGraph): +class FusedElemwise(OpFromGraph): """Fuse indexed reads and updates into a single Elemwise iteration loop. Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) and @@ -250,28 +258,66 @@ class IndexedElemwise(OpFromGraph): Examples:: tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")] + + reduced_outputs : tuple of ((scalar_op, axes, identity, acc_dtype) | None) + One entry per (inner Elemwise / outer op) output position. + ``None`` if the output is not a reduction. + Otherwise the output is the result of reducing the inner Elemwise output + with ``CAReduce(scalar_op)`` over ``axes``: + + - ``scalar_op``: the commutative/associative binary scalar op of the + reduction (e.g. ``add`` for sum, ``mul`` for prod, ``maximum`` for max). + - ``axes``: tuple of reduced batch axes (in the inner Elemwise's dim space). + - ``identity``: the reduction identity, used to seed the accumulator buffer. + - ``acc_dtype``: dtype the accumulation is carried out in (the Numba loop + accumulates in this dtype, then casts to the output dtype). + + The inner fgraph still holds a faithful ``CAReduce(Elemwise(...))`` so + non-Numba backends evaluate correctly via ``OpFromGraph.perform``; only + the Numba backend reads this spec to fuse the reduction into the loop. + + Examples:: + + sum(exp(x)) → reduced_outputs=[(add, (0,), 0.0, "float64")] """ - def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): + def __init__( + self, + *args, + indexed_inputs=(), + indexed_outputs=(), + reduced_outputs=(), + **kwargs, + ): self.indexed_inputs = indexed_inputs self.indexed_outputs = indexed_outputs + self.reduced_outputs = reduced_outputs # A read buffer can occupy multiple input slots (e.g. read through # several indices); construct_nominal_fgraph dedupes those to one # nominal, leaving the extra slots as unused NominalVariables, which is # safe because reads don't destroy. Write targets always get their own - # fresh inner input (see FuseIndexedElemwise) so a destroyed buffer is + # fresh inner input (see FuseElemwise) so a destroyed buffer is # never deduped onto a read source. super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) def __str__(self): + elemwise_str = "Elemwise" for node in self.fgraph.apply_nodes: if isinstance(node.op, Elemwise): - return f"IndexedElemwise{{{node.op!s}}}" - return "IndexedElemwise" - - -@op_debug_information.register(IndexedElemwise) -def _op_debug_information_IndexedElemwise(op, node): + elemwise_str = str(node.op) + break + reductions = [ + f"{type(spec[0]).__name__.lower()}@{spec[1]}" + for spec in self.reduced_outputs + if spec is not None + ] + if reductions: + return f"FusedElemwise{{{elemwise_str}, reduce[{', '.join(reductions)}]}}" + return f"FusedElemwise{{{elemwise_str}}}" + + +@op_debug_information.register(FusedElemwise) +def _op_debug_information_FusedElemwise(op, node): info = {} n_idx = len(op.indexed_inputs) @@ -317,10 +363,20 @@ def _op_debug_information_IndexedElemwise(op, node): f"indexed {mode} ({buf_label}, {idx_label})" ) + # Annotate reduced outputs + for out_idx, spec in enumerate(op.reduced_outputs): + if spec is None or out_idx >= len(node.outputs): + continue + scalar_op, axes, _identity, acc_dtype = spec + info[node.outputs[out_idx]] = ( + f"reduced[{type(scalar_op).__name__.lower()}] " + f"over axes {axes} acc={acc_dtype}" + ) + return {node: info} -class FuseIndexedElemwise(GraphRewriter): +class FuseElemwise(GraphRewriter): """Fuse indexed reads and indexed updates into Elemwise loops. Absorbs single-client ``AdvancedSubtensor1`` on inputs (indexed reads) @@ -578,7 +634,25 @@ def apply(self, fgraph): idx_groups[idx_axis_pair][1].append(out_idx) write_targets[out_idx] = client_node - if not idx_groups: + # Find reductions to fuse: an Elemwise output whose sole client is an + # eligible CAReduce. Outputs that are write targets are excluded (an + # output can't be both an indexed write and a reduction). Outputs + # with extra (non-reduce) clients are handled by duplication below. + reduced_outputs = {} # out_idx -> car_node + for out_idx, out in enumerate(node.outputs): + if out_idx in write_targets: + continue + car_clients = [ + c + for c, _ in fgraph.clients[out] + if isinstance(c.op, CAReduce) + and isinstance(c.op.scalar_op, _REDUCE_SCALAR_OPS) + ] + if len(car_clients) != 1: + continue + reduced_outputs[out_idx] = car_clients[0] + + if not idx_groups and not reduced_outputs: continue if must_transpose_write_axes: @@ -588,18 +662,53 @@ def apply(self, fgraph): assert replacements fgraph.replace_all( replacements, - reason="fuse_indexed_elemwise_move_write_axes", + reason="fuse_elemwise_move_write_axes", ) worklist.append(node) continue + # If a reduced output also feeds non-reduce consumers, duplicate it via + # Composite so the reduction consumes the duplicate while the original + # stays materialised for the other consumers. We still fuse the + # reduction loop, even if the full output must also be produced. + # This runs before the strip-inplace pass below, so any reduced output + # reaching it is sole-client (truly de-materialized) and an inplace on a + # materialized original survives the fusion. + def _has_non_reduce_clients(out_idx): + car_node = reduced_outputs[out_idx] + return any( + c is not car_node for c, _ in fgraph.clients[node.outputs[out_idx]] + ) + + if reduce_and_direct_use_outs := { + out_idx + for out_idx in reduced_outputs + if _has_non_reduce_clients(out_idx) + }: + new_node, dup_map = self._duplicate_multi_client_outputs( + node, reduce_and_direct_use_outs + ) + replacements = list( + zip(node.outputs, new_node.outputs[: len(node.outputs)]) + ) + for out_idx, dup_idx in dup_map.items(): + car_node = reduced_outputs[out_idx] + new_reduced = car_node.op(new_node.outputs[dup_idx]) + replacements.append((car_node.outputs[0], new_reduced)) + fgraph.replace_all( + replacements, + reason="fuse_reduce_and_direct_outputs", + ) + worklist.append(new_node) + continue + # If any indexed-write output also has other consumers, # duplicate it via Composite so the write replaces the duplicate # while the original stays available for non-write consumers. # We still avoid one extra write loop, # even if we can't skip the output materialization altogether. - # This runs before the strip-inplace pass below, so an inplace on the - # materialized original survives the fusion. + # Like the reduce duplication above, this runs before the strip-inplace + # pass, so an inplace on the materialized original survives the fusion. def _has_non_write_clients(out_idx): update = write_targets[out_idx] for c, _ in fgraph.clients[node.outputs[out_idx]]: @@ -635,26 +744,31 @@ def _has_non_write_clients(out_idx): replacements.append((update_node.outputs[0], new_update_out)) fgraph.replace_all( replacements, - reason="fuse_indexed_elemwise_write_and_direct_outputs", + reason="fuse_elemwise_write_and_direct_outputs", ) worklist.append(new_node) continue indexed_reads = {i for reads, _ in idx_groups.values() for i in reads} - # If any inplace targets an indexed-read input, or claims an indexed-write - # output (the loop writes the result to the write buffer instead, so the - # input destruction would happen only in the Python-mode fallback, - # undeclared by the outer destroy map), strip and re-run inplace with - # those inputs protected and outputs excluded. The duplication above ran - # first, so write-target outputs here are sole-client. + # If any inplace targets an indexed-read input, or claims an output that is + # not materialized as a plain output — a fused reduction (reusing the input + # buffer as the reduce accumulator would skip the identity init and + # accumulate on top of stale input values) or an indexed write (the loop + # writes to the write buffer instead, so the input destruction would happen + # only in the Python-mode fallback, undeclared by the outer destroy map) — + # strip and re-run inplace with those inputs protected and outputs excluded. + # The duplications above ran first, so such outputs here are sole-client. if any( inp_idx in indexed_reads for inp_idx in node.op.inplace_pattern.values() - ) or any(out_idx in write_targets for out_idx in node.op.inplace_pattern): + ) or any( + out_idx in reduced_outputs or out_idx in write_targets + for out_idx in node.op.inplace_pattern + ): stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) fgraph.replace_all( zip(node.outputs, stripped_node.outputs), - reason="fuse_indexed_elemwise_strip_inplace", + reason="fuse_elemwise_strip_inplace", ) optimizer = InplaceElemwiseOptimizer() protected = optimizer._get_protected_inputs(fgraph) @@ -666,7 +780,8 @@ def _has_non_write_clients(out_idx): for pair in optimizer.filter_candidate_pairs( fgraph, stripped_node, protected ) - if pair[0][0] not in write_targets + if pair[0][0] not in reduced_outputs + and pair[0][0] not in write_targets ] # try_inplace_on_node does its own fgraph.replace_all internally, # so the returned node is already in the fgraph @@ -674,17 +789,18 @@ def _has_non_write_clients(out_idx): fgraph, stripped_node, candidate_pairs=candidate_pairs, - reason="fuse_indexed_elemwise_inplace_read_buffers", + reason="fuse_elemwise_inplace_read_buffers", ) worklist.append(new_inplace_node) continue idx_vars = [idx for idx, _axis in idx_groups] - # The strip-inplace pass above guarantees that indexed-write outputs - # carry no inplace + # The strip-inplace pass above guarantees that reduced and indexed-write + # outputs carry no inplace assert not any( - out_idx in write_targets for out_idx in node.op.inplace_pattern + out_idx in reduced_outputs or out_idx in write_targets + for out_idx in node.op.inplace_pattern ) fgraph_destroy_map = { out_idx: [inp_idx] @@ -740,6 +856,41 @@ def _has_non_write_clients(out_idx): fgraph_outputs[out_idx] = write_out fgraph_destroy_map[out_idx] = [target_pos] + # Inner fgraph reduced outputs: wrap the Elemwise output in the real + # CAReduce so non-Numba backends compute it faithfully via perform. + # reduced_spec carries the (scalar_op, axes, identity, acc_dtype) the + # Numba backend reads to fuse the reduction into the loop instead. + reduced_spec_by_idx = {} + for out_idx, car_node in sorted(reduced_outputs.items()): + car_op = car_node.op + ndim = node.outputs[out_idx].type.ndim + axes = ( + tuple(range(ndim)) + if car_op.axis is None + else tuple(sorted(car_op.axis)) + ) + acc_dtype = ( + car_op.acc_dtype + if car_op.acc_dtype is not None + else car_node.outputs[0].type.dtype + ) + reduced_spec_by_idx[out_idx] = ( + car_op.scalar_op, + axes, + car_op.scalar_op.identity, + acc_dtype, + ) + fgraph_outputs[out_idx] = car_op(node.outputs[out_idx]) + + reduced_spec = ( + tuple( + reduced_spec_by_idx.get(out_idx) + for out_idx in range(len(node.outputs)) + ) + if reduced_outputs + else () + ) + # indexed_inputs_spec: ((read_positions, axis) | None, ...) # indexed_outputs_spec: ((write_positions, axis, "inc"|"set") | None, ...) indexed_inputs_spec = tuple( @@ -762,12 +913,13 @@ def _has_non_write_clients(out_idx): val = outer_write_targets.get(i, inp) outer_inputs.append(val.copy() if i in copy_positions else val) - new_outs = IndexedElemwise( + new_outs = FusedElemwise( fgraph_inputs, fgraph_outputs, destroy_map=fgraph_destroy_map, indexed_inputs=indexed_inputs_spec, indexed_outputs=indexed_outputs_spec, + reduced_outputs=reduced_spec, )(*outer_inputs, return_list=True) replacements = [] @@ -776,6 +928,10 @@ def _has_non_write_clients(out_idx): replacements.append( (write_targets[out_idx].outputs[0], new_outs[out_idx]) ) + elif out_idx in reduced_outputs: + replacements.append( + (reduced_outputs[out_idx].outputs[0], new_outs[out_idx]) + ) else: replacements.append((node.outputs[out_idx], new_outs[out_idx])) @@ -791,9 +947,9 @@ def _has_non_write_clients(out_idx): continue -indexed_elemwise_optdb.register( - "fuse_indexed_elemwise", - FuseIndexedElemwise(), +fused_elemwise_optdb.register( + "fuse_elemwise", + FuseElemwise(), "numba", position=1, ) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index ecb3435030..d2c41aeef3 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -1,4 +1,4 @@ -import pytensor.tensor.rewriting.indexed_elemwise # noqa: F401 +import pytensor.tensor.rewriting.fused_elemwise # noqa: F401 from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter diff --git a/tests/benchmarks/test_gather_fusion.py b/tests/benchmarks/test_gather_fusion.py index 0c634e354d..46d9446343 100644 --- a/tests/benchmarks/test_gather_fusion.py +++ b/tests/benchmarks/test_gather_fusion.py @@ -12,7 +12,7 @@ import pytensor.tensor as pt from pytensor import config from pytensor.compile.mode import get_mode -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 @@ -50,11 +50,11 @@ def read_benchmark_setup(request): ) assert any( - isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() - ), "IndexedElemwise not found in fused graph" + isinstance(n.op, FusedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "FusedElemwise not found in fused graph" assert not any( - isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() - ), "IndexedElemwise found in unfused graph" + isinstance(n.op, FusedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "FusedElemwise found in unfused graph" rng = np.random.default_rng(1) vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] @@ -117,15 +117,15 @@ def write_benchmark_setup(request): ) assert any( - isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() - ), "IndexedElemwise not found in fused graph" + isinstance(n.op, FusedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "FusedElemwise not found in fused graph" assert not any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn_fused.maker.fgraph.toposort() ), "AdvancedIncSubtensor1 still present in fused graph" assert not any( - isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() - ), "IndexedElemwise found in unfused graph" + isinstance(n.op, FusedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "FusedElemwise found in unfused graph" rng = np.random.default_rng(1) vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_fused_elemwise.py similarity index 72% rename from tests/link/numba/test_indexed_elemwise.py rename to tests/link/numba/test_fused_elemwise.py index e7971e68bb..74204e3200 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_fused_elemwise.py @@ -1,4 +1,4 @@ -"""Tests for IndexedElemwise fusion (indexed reads and updates in Elemwise loops).""" +"""Tests for FusedElemwise (indexed reads/writes and reductions fused into Elemwise loops).""" import numpy as np import pytest @@ -6,7 +6,7 @@ import pytensor.tensor as pt from pytensor import Mode, function, get_mode from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, AdvancedSubtensor, @@ -27,9 +27,9 @@ def fused_and_unfused(inputs, output): def assert_fused(fn): - """Assert that the compiled graph contains an IndexedElemwise node.""" - assert any(isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort()), ( - "IndexedElemwise not found in fused graph" + """Assert that the compiled graph contains a FusedElemwise node.""" + assert any(isinstance(n.op, FusedElemwise) for n in fn.maker.fgraph.toposort()), ( + "FusedElemwise not found in fused graph" ) @@ -252,7 +252,7 @@ def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): fn, fn_u = fused_and_unfused([x, y, target], out) # Write not fused — the Elemwise loop dim is the non-indexed axis, # not the indexed axis. Read fusion may still create an - # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. + # FusedElemwise, but the AdvancedIncSubtensor1 must remain outside. assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() ) @@ -378,7 +378,7 @@ def test_write_of_inplace_elemwise(self): fn, fn_u = fused_and_unfused([x, y, z, t], out) assert_fused(fn) [node] = [ - n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) ] assert not any( n.op.inplace_pattern @@ -412,7 +412,7 @@ def test_write_with_direct_use_keeps_inplace(self): fn, fn_u = fused_and_unfused([x, y, z, t], [t[idx].inc(w), w]) assert_fused(fn) [node] = [ - n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) ] # Two destroy entries: the write buffer, and the dot intermediate kept # inplace by the materialized output @@ -532,7 +532,7 @@ def test_write_target_aliases_read_source(self, read_idx, write_idx): write_idx = np.array(write_idx, dtype=np.int64) out = b[write_idx].set(b[read_idx] * 2.0) fn, fn_u = fused_and_unfused([x], out) - # The read fuses into an IndexedElemwise; the aliasing write stays external. + # The read fuses into a FusedElemwise; the aliasing write stays external. assert_fused(fn) assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() @@ -566,7 +566,7 @@ def test_non_inplace_aliasing_write_preserves_input(self): # The inner write must destroy exactly the input the op's destroy_map names. [node] = [ - n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) ] [(_out_idx, [destroyed_pos])] = node.op.destroy_map.items() [inner_write] = [ @@ -696,3 +696,260 @@ def test_loop_shape_regression(self): res = f(beta=test_beta, mask=test_mask) ref_res = ref_f(beta=test_beta, mask=test_mask) np.testing.assert_allclose(res, ref_res, strict=True) + + +def assert_reduce_fused(fn): + """Assert the graph contains a FusedElemwise with a fused reduction.""" + nodes = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise)] + assert nodes, "FusedElemwise not found in fused graph" + assert any(any(r is not None for r in n.op.reduced_outputs) for n in nodes), ( + "No fused reduction (reduced_outputs) found" + ) + + +class TestReductionFusion: + """Reductions (CAReduce) fused into the Elemwise loop, no indexing.""" + + @pytest.mark.parametrize("axis", [None, 0, 1, 2, (0, 2), (0, 1), (1, 2)], ids=str) + def test_sum_axes(self, axis): + rng = np.random.default_rng(0) + x = pt.tensor3("x") + y = pt.tensor3("y") + out = pt.sum(pt.exp(x) + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(3, 4, 5)), rng.normal(size=(3, 4, 5)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_prod(self, axis): + rng = np.random.default_rng(1) + x = pt.matrix("x") + out = pt.prod(pt.exp(x * 0.1), axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(4, 5)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-8) + + @pytest.mark.parametrize("reduce_fn", [pt.max, pt.min], ids=["max", "min"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_max_min(self, reduce_fn, axis): + rng = np.random.default_rng(2) + x = pt.matrix("x") + y = pt.matrix("y") + out = reduce_fn(x + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(6, 7)), rng.normal(size=(6, 7)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("reduce_fn", [pt.all, pt.any], ids=["all", "any"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_all_any(self, reduce_fn, axis): + rng = np.random.default_rng(3) + x = pt.matrix("x", dtype="bool") + y = pt.matrix("y", dtype="bool") + out = reduce_fn(x & y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv = rng.integers(0, 2, size=(4, 5)).astype(bool) + yv = rng.integers(0, 2, size=(4, 5)).astype(bool) + np.testing.assert_array_equal(fn(xv, yv), fn_u(xv, yv)) + + @pytest.mark.parametrize("dtype", ["int8", "int32", "uint8"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_acc_dtype_widening(self, dtype, axis): + """Sum of small int dtype accumulates in a wider acc_dtype.""" + rng = np.random.default_rng(4) + x = pt.matrix("x", dtype=dtype) + out = pt.sum(x + x, axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + info = np.iinfo(dtype) + xv = rng.integers(0, min(info.max // 2, 50), size=(40, 40)).astype(dtype) + np.testing.assert_array_equal(fn(xv), fn_u(xv)) + # Result must be the wide acc dtype, not overflow the input dtype + assert fn(xv).dtype == fn_u(xv).dtype + + def test_scalar_and_1d(self): + rng = np.random.default_rng(5) + x = pt.vector("x") + out = pt.sum(pt.exp(x)) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(17,)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_non_c_contiguous_input(self): + """Reduction over a transposed (non-C-contiguous) intermediate.""" + rng = np.random.default_rng(6) + x = pt.matrix("x") + out = pt.sum((x + 1.0).T, axis=0) + fn, fn_u = fused_and_unfused([x], out) + xv = rng.normal(size=(8, 5)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_reduce_of_inplace_elemwise(self): + """Inplace must not survive on an output fused as a reduction. + + The dot output is a destroyable intermediate the inplace pass claims for the + Mul before fusion. If the fused op kept that destroy entry, the reduce + accumulator would alias the input buffer and skip the identity init, + folding stale input values into the result. + """ + rng = np.random.default_rng(7) + x, y, z = pt.matrix("x"), pt.matrix("y"), pt.matrix("z") + out = ((x @ y) * z).sum() + fn, fn_u = fused_and_unfused([x, y, z], out) + assert_reduce_fused(fn) + for node in fn.maker.fgraph.toposort(): + if isinstance(node.op, FusedElemwise): + assert not any( + out_idx in node.op.destroy_map + for out_idx, r in enumerate(node.op.reduced_outputs) + if r is not None + ) + xv, yv, zv = (rng.normal(size=(4, 4)) for _ in range(3)) + np.testing.assert_allclose(fn(xv, yv, zv), fn_u(xv, yv, zv), rtol=1e-10) + + def test_reduce_with_direct_use_keeps_inplace(self): + """Inplace survives on an output that is both reduced and used directly. + + The reduce-and-direct duplication keeps the original output materialized + (the CAReduce consumes a duplicate), so the inplace claimed on the dot + intermediate stays valid and must not be stripped. + """ + rng = np.random.default_rng(8) + x, y, z = pt.matrix("x"), pt.matrix("y"), pt.matrix("z") + w = (x @ y) * z + fn, fn_u = fused_and_unfused([x, y, z], [w.sum(), w]) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + reduced_idxs = { + i for i, r in enumerate(node.op.reduced_outputs) if r is not None + } + # The materialized output keeps its destroy entry; the reduced one has none + assert node.op.destroy_map + assert not set(node.op.destroy_map) & reduced_idxs + xv, yv, zv = (rng.normal(size=(4, 4)) for _ in range(3)) + for res, res_u in zip(fn(xv, yv, zv), fn_u(xv, yv, zv)): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + + +class TestReductionWithIndexing: + """Reductions composed with indexed reads in a single fused loop.""" + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_gather(self, axis): + rng = np.random.default_rng(10) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + out = pt.sum(x[idx] + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, yv, idxv), fn_u(xv, yv, idxv), rtol=1e-10) + + def test_max_gather(self): + rng = np.random.default_rng(11) + x = pt.matrix("x") + idx = pt.lvector("idx") + out = pt.max(pt.exp(x[idx]), axis=0) + fn, fn_u = fused_and_unfused([x, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + + def test_gather_scatter_and_reduce_mix(self): + """Gather + elemwise + scatter + reduce all fuse into one loop. + + Sibling fusion merges the two elemwise expressions into one multi-output + Composite; FuseElemwise then absorbs the indexed read, the indexed + write, and the reduction into a single FusedElemwise. + """ + rng = np.random.default_rng(12) + x = pt.matrix("x") + y = pt.matrix("y") + t = pt.matrix("t") + idx = pt.lvector("idx") + scattered = t[idx].inc(x[idx] * y) + reduced = pt.sum(x[idx] + y) + fn, fn_u = fused_and_unfused([x, y, t, idx], [scattered, reduced]) + nodes = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert len(nodes) == 1 + [node] = nodes + assert any(spec is not None for spec in node.op.indexed_inputs) + assert any(spec is not None for spec in node.op.indexed_outputs) + assert any(r is not None for r in node.op.reduced_outputs) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + tv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 8, size=4) + for res, res_u in zip( + fn(xv, yv, tv.copy(), idxv), fn_u(xv, yv, tv.copy(), idxv) + ): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + + +class TestReductionMultiOutput: + """Reduction plus direct use of the same Elemwise output (duplication).""" + + def test_sum_and_direct(self): + rng = np.random.default_rng(20) + x = pt.matrix("x") + y = pt.matrix("y") + f = pt.exp(x) + y + out = [pt.sum(f, axis=0), f] + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(4, 5)), rng.normal(size=(4, 5)) + r, ru = fn(xv, yv), fn_u(xv, yv) + np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + def test_two_reductions_same_source(self): + """sum and max of the same elemwise output (two CAReduce clients).""" + rng = np.random.default_rng(21) + x = pt.matrix("x") + f = pt.exp(x) + out = [pt.sum(f, axis=0), pt.max(f, axis=0)] + fn, fn_u = fused_and_unfused([x], out) + xv = rng.normal(size=(4, 5)) + r, ru = fn(xv), fn_u(xv) + np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + +class TestReductionPythonMode: + """The fused op evaluates correctly outside JIT (OpFromGraph.perform).""" + + def test_perform_matches(self): + rng = np.random.default_rng(30) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + fn = function( + [x, y, idx], pt.sum(x[idx] + y, axis=1), mode=NUMBA_MODE, trust_input=True + ) + node = next( + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ) + # Re-apply the exact fused op and evaluate it via OpFromGraph.perform + fresh = [inp.type() for inp in node.inputs] + perform_fn = function( + fresh, node.op(*fresh, return_list=True), mode="FAST_COMPILE" + ) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose( + perform_fn(xv, yv, idxv)[0], np.sum(xv[idxv] + yv, axis=1), rtol=1e-10 + ) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7ac8f4c65b..dac8690abd 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -249,7 +249,11 @@ def _raise_on_opt_error(self): "add_mul_fusion", "inplace", ], - exclude=["cxx_only", "BlasOpt"], + # Exclude both careduce-fusion paths so reductions stay unfused here: + # cxx_only covers local_careduce_fusion (C backend); the indexed/reduce + # fusion is the Numba-specific equivalent. This class tests the generic + # Composite fusion, independent of the active backend. + exclude=["cxx_only", "BlasOpt", "fuse_indexed_into_elemwise"], ) mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 7042ddf598..a8fc733667 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2827,7 +2827,13 @@ class TestLocalSumProd: """Test sum/prod rewrites.""" def setup_method(self): - self.mode = get_default_mode().including("canonicalize", "specialize") + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize") + .excluding("fuse_indexed_into_elemwise") + ) def test_local_sum_prod_of_scalar_mul(self): # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and @@ -3336,7 +3342,9 @@ def test_local_prod_of_div(self): c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX) d_val = np.asarray(rng.standard_normal(), config.floatX) - default_mode = get_default_mode() + # Exclude the Numba reduction fusion so the outer reduction op stays + # visible in the toposort for the structural assertions below. + default_mode = get_default_mode().excluding("fuse_indexed_into_elemwise") # `FusionOptimizer` is included to make sure that `expected_outer_operator` # remains the same for all rewrite modes. mode_with_rewrite = default_mode.including( @@ -3393,8 +3401,12 @@ def test_local_prod_of_div(self): class TestLocalReduce: def setup_method(self): - self.mode = get_default_mode().including( - "canonicalize", "specialize", "uncanonicalize" + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize", "uncanonicalize") + .excluding("fuse_indexed_into_elemwise") ) def test_local_reduce_broadcast_all_0(self): diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index ecb671e3ce..98b7e71219 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -18,6 +18,7 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot, dot, exp, sqr +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.rewriting.subtensor import ( _slice_to_arange, local_add_of_sparse_write, @@ -490,11 +491,9 @@ def test_local_useless_subtensor_6(self, idx, res): else: # Arange-with-offset indices get rewritten to a Subtensor slice; # other advanced indices stay as AdvancedSubtensor1 (or get - # absorbed into IndexedElemwise by FuseIndexedElemwise). - from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise - + # absorbed into FusedElemwise by FuseElemwise). assert any( - isinstance(node.op, AdvancedSubtensor1 | Subtensor | IndexedElemwise) + isinstance(node.op, AdvancedSubtensor1 | Subtensor | FusedElemwise) for node in prog ) diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index 9d5011b6db..e842703d45 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -23,8 +23,12 @@ class TestMinMax: def setup_method(self): - self.mode = pytensor.compile.mode.get_default_mode().including( - "canonicalize", "fast_run" + # Exclude the Numba reduction fusion so the Max/Min CAReduce nodes stay + # visible in the toposort for the structural assertions below. + self.mode = ( + pytensor.compile.mode.get_default_mode() + .including("canonicalize", "fast_run") + .excluding("fuse_indexed_into_elemwise") ) def test_optimization_min(self): From 91b69c0264b64b8c05dd25ab790b76d827250edc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 17:06:10 +0200 Subject: [PATCH 3/4] Fuse multiple reductions of the same Elemwise output An output consumed by several eligible CAReduces previously disqualified itself entirely (the detection required exactly one reduce client). Peel one extra reduction per rewrite pass onto a duplicate output until each reduction has its own copy, so e.g. [sum(f), max(f), prod(f), f] becomes a single FusedElemwise with three fused reductions. --- pytensor/tensor/rewriting/fused_elemwise.py | 33 ++++++++++++++++++--- tests/link/numba/test_fused_elemwise.py | 24 ++++++++++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/rewriting/fused_elemwise.py b/pytensor/tensor/rewriting/fused_elemwise.py index f97f3e8c28..d718c830ea 100644 --- a/pytensor/tensor/rewriting/fused_elemwise.py +++ b/pytensor/tensor/rewriting/fused_elemwise.py @@ -433,10 +433,11 @@ def _extract_idx_axis_pairs(node, *, write=False): @staticmethod def _duplicate_multi_client_outputs(node, multi_client_outs): - """Add duplicate outputs for Elemwise results that have both write and non-write consumers. + """Add duplicate outputs for Elemwise results whose consumers must be split. - Returns ``(new_node, dup_map)`` where *dup_map* maps each original - output index to its duplicate position. + Used when an output is both written/reduced and consumed directly, or is + consumed by several reductions. Returns ``(new_node, dup_map)`` where + *dup_map* maps each original output index to its duplicate position. """ scalar_op = node.op.scalar_op if isinstance(scalar_op, Composite): @@ -639,6 +640,9 @@ def apply(self, fgraph): # output can't be both an indexed write and a reduction). Outputs # with extra (non-reduce) clients are handled by duplication below. reduced_outputs = {} # out_idx -> car_node + extra_reduction = ( + None # (out_idx, car_node) when reductions share an output + ) for out_idx, out in enumerate(node.outputs): if out_idx in write_targets: continue @@ -648,10 +652,31 @@ def apply(self, fgraph): if isinstance(c.op, CAReduce) and isinstance(c.op.scalar_op, _REDUCE_SCALAR_OPS) ] - if len(car_clients) != 1: + if not car_clients: continue + if len(car_clients) > 1: + extra_reduction = (out_idx, car_clients[1]) + break reduced_outputs[out_idx] = car_clients[0] + # An output consumed by several eligible reductions: peel one reduction + # per pass onto a duplicate output, until each reduction has its own copy + if extra_reduction is not None: + out_idx, car_node = extra_reduction + new_node, dup_map = self._duplicate_multi_client_outputs( + node, {out_idx} + ) + replacements = list( + zip(node.outputs, new_node.outputs[: len(node.outputs)]) + ) + new_reduced = car_node.op(new_node.outputs[dup_map[out_idx]]) + replacements.append((car_node.outputs[0], new_reduced)) + fgraph.replace_all( + replacements, reason="fuse_elemwise_split_multi_reductions" + ) + worklist.append(new_node) + continue + if not idx_groups and not reduced_outputs: continue diff --git a/tests/link/numba/test_fused_elemwise.py b/tests/link/numba/test_fused_elemwise.py index 74204e3200..d377c01e89 100644 --- a/tests/link/numba/test_fused_elemwise.py +++ b/tests/link/numba/test_fused_elemwise.py @@ -5,7 +5,7 @@ import pytensor.tensor as pt from pytensor import Mode, function, get_mode -from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.elemwise import CAReduce, Elemwise from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -922,11 +922,33 @@ def test_two_reductions_same_source(self): f = pt.exp(x) out = [pt.sum(f, axis=0), pt.max(f, axis=0)] fn, fn_u = fused_and_unfused([x], out) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert sum(r is not None for r in node.op.reduced_outputs) == 2 + assert not any(isinstance(n.op, CAReduce) for n in fn.maker.fgraph.toposort()) xv = rng.normal(size=(4, 5)) r, ru = fn(xv), fn_u(xv) np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + def test_three_reductions_and_direct_use(self): + """Three reductions plus a direct consumer of one shared output.""" + rng = np.random.default_rng(22) + x = pt.matrix("x") + y = pt.matrix("y") + f = x * y + out = [pt.sum(f), pt.max(f, axis=0), pt.prod(f, axis=1), f] + fn, fn_u = fused_and_unfused([x, y], out) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert sum(r is not None for r in node.op.reduced_outputs) == 3 + assert not any(isinstance(n.op, CAReduce) for n in fn.maker.fgraph.toposort()) + xv, yv = rng.normal(size=(4, 5)), rng.normal(size=(4, 5)) + for res, res_u in zip(fn(xv, yv), fn_u(xv, yv)): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + class TestReductionPythonMode: """The fused op evaluates correctly outside JIT (OpFromGraph.perform).""" From 5a5637c931e659eee2a8fc936d6f889d0d4ae534 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 17:09:18 +0200 Subject: [PATCH 4/4] Fuse reductions of bare indexed reads sum(x[idx]) had no Elemwise for FuseElemwise to anchor on, so the gather materialized and the reduction stayed external. A new pre-rewrite wraps such reductions in an identity Elemwise (covering the bare AdvancedSubtensor1, axis-swap DimShuffle and flattened-ND-index Reshape forms), letting gather, identity and reduction collapse into one fused loop. --- pytensor/tensor/rewriting/fused_elemwise.py | 55 ++++++++++++++++++++- tests/link/numba/test_fused_elemwise.py | 43 ++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/fused_elemwise.py b/pytensor/tensor/rewriting/fused_elemwise.py index d718c830ea..beafde64e2 100644 --- a/pytensor/tensor/rewriting/fused_elemwise.py +++ b/pytensor/tensor/rewriting/fused_elemwise.py @@ -14,7 +14,19 @@ from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.utils import InconsistencyError from pytensor.printing import op_debug_information -from pytensor.scalar.basic import AND, OR, XOR, Add, Composite, Maximum, Minimum, Mul +from pytensor.scalar.basic import ( + AND, + OR, + XOR, + Add, + Composite, + Maximum, + Minimum, + Mul, +) +from pytensor.scalar.basic import ( + identity as scalar_identity, +) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.shape import Reshape, shape_padright @@ -175,6 +187,40 @@ def undo_take_reshape_for_fusion(fgraph, node): return [new_out] +@node_rewriter([CAReduce]) +def wrap_reduced_gather_in_elemwise(fgraph, node): + """Wrap a reduction of a bare indexed read in an identity Elemwise. + + ``FuseElemwise`` anchors on Elemwise nodes, so ``sum(x[idx])`` — with no + elementwise computation between the gather and the reduction — would leave + the gather materialized. ``Elemwise(identity)`` gives the fusion an anchor: + gather, identity and reduction collapse into one fused loop (the identity is + a no-op inside the loop body). Runs before the ``undo_take`` rewrites, which + require the indexed read to feed an Elemwise. + """ + if not isinstance(node.op.scalar_op, _REDUCE_SCALAR_OPS): + return None + [inp] = node.inputs + if inp.owner is None or len(fgraph.clients[inp]) != 1: + return None + + owner_op = inp.owner.op + if isinstance(owner_op, AdvancedSubtensor): + is_gather = True + elif isinstance(owner_op, Reshape): + # The flattened-ND-index form undo_take_reshape_for_fusion normalizes + is_gather = ( + _unwrap_axis_swapped_subtensor1(fgraph, inp.owner.inputs[0]) is not None + ) + else: + # Bare AdvancedSubtensor1 or the axis-swap DimShuffle-wrapped form + is_gather = _unwrap_axis_swapped_subtensor1(fgraph, inp) is not None + if not is_gather: + return None + + return [node.op(Elemwise(scalar_identity)(inp))] + + fused_elemwise_optdb = SequenceDB() optdb.register( # Predates the FusedElemwise rename; kept so existing .including/.excluding @@ -190,6 +236,13 @@ def undo_take_reshape_for_fusion(fgraph, node): position=100, ) +fused_elemwise_optdb.register( + "wrap_reduced_gather_in_elemwise", + dfs_rewriter(wrap_reduced_gather_in_elemwise), + "numba", + position=-0.5, +) + fused_elemwise_optdb.register( "undo_take_dimshuffle_for_fusion", dfs_rewriter(undo_take_dimshuffle_for_fusion), diff --git a/tests/link/numba/test_fused_elemwise.py b/tests/link/numba/test_fused_elemwise.py index d377c01e89..4bf1e83a8d 100644 --- a/tests/link/numba/test_fused_elemwise.py +++ b/tests/link/numba/test_fused_elemwise.py @@ -866,6 +866,49 @@ def test_max_gather(self): idxv = rng.integers(0, 8, size=4) np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + @pytest.mark.parametrize( + "make_out", + [ + lambda x, idx: pt.sum(x[idx]), + lambda x, idx: pt.max(x[idx], axis=0), + lambda x, idx: pt.sum(x[:, idx]), + ], + ids=["sum_all", "max_axis0", "sum_axis1_gather"], + ) + def test_reduce_of_bare_gather(self, make_out): + """Reduction of an indexed read with no elemwise in between. + + wrap_reduced_gather_in_elemwise inserts an identity Elemwise so the + gather and the reduction still collapse into one fused loop. + """ + rng = np.random.default_rng(13) + x = pt.matrix("x") + idx = pt.lvector("idx") + fn, fn_u = fused_and_unfused([x, idx], make_out(x, idx)) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert any(spec is not None for spec in node.op.indexed_inputs) + xv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 5, size=4) + np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + + def test_reduce_of_bare_nd_gather(self): + """Reduction of an ND-index read (Reshape-flattened form).""" + rng = np.random.default_rng(14) + x = pt.matrix("x") + mat_idx = pt.lmatrix("mat_idx") + fn, fn_u = fused_and_unfused([x, mat_idx], pt.sum(x[mat_idx])) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert any(spec is not None for spec in node.op.indexed_inputs) + xv = rng.normal(size=(8, 5)) + mv = rng.integers(0, 8, size=(3, 2)) + np.testing.assert_allclose(fn(xv, mv), fn_u(xv, mv), rtol=1e-10) + def test_gather_scatter_and_reduce_mix(self): """Gather + elemwise + scatter + reduce all fuse into one loop.