diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index af4aa4aee3..49fc4b43f6 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -13,6 +13,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -96,6 +97,7 @@ def impl(*inputs_and_core_shapes): NO_SIZE, NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return impl diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1383ee1df1..06a3a7f419 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -29,6 +29,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -50,7 +51,7 @@ from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise @singledispatch @@ -140,6 +141,38 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr): ] +# Scalar reduce ops that ``accumulate_into_slice`` (and thus fused reductions / +# indexed inc) supports, paired with the in-place form used on the output slice. +# Augmented assignment is used where Numba supports it directly on 0-d arrays +# (``out`` is the core output slice, which is 0-d for scalar cores); Maximum and +# Minimum lack an augmented operator so they use a ufunc + ellipsis-set instead. +_SLICE_ACCUMULATE = { + Add: "{out} += {inner}", + Mul: "{out} *= {inner}", + AND: "{out} &= {inner}", + OR: "{out} |= {inner}", + XOR: "{out} ^= {inner}", + Maximum: "{out}[...] 
def _reduce_identity(identity, acc_dtype):
    """Coerce a reduction identity to a concrete numpy scalar in ``acc_dtype``.

    Non-finite identities (``±inf``, used by max/min reductions) cannot be
    represented by integer or boolean accumulators, so they are replaced by
    the extreme value of the accumulator dtype:

    - integer accumulators: ``+inf`` becomes ``iinfo.max`` and any other
      non-finite value becomes ``iinfo.min`` (mirroring
      ``create_multiaxis_reducer``);
    - boolean accumulators: ``+inf`` becomes ``True`` and ``-inf`` becomes
      ``False``. A plain ``np.bool_(-inf)`` cast would yield ``True``, which
      is the absorbing element of a boolean max rather than its identity.

    Parameters
    ----------
    identity
        The reduction identity as a Python or numpy scalar.
    acc_dtype
        Anything accepted by ``np.dtype``; the dtype of the accumulator.

    Returns
    -------
    numpy scalar
        ``identity`` expressed as a scalar of type ``acc_dtype``.
    """
    acc_dtype = np.dtype(acc_dtype)
    if not np.isfinite(identity):
        if acc_dtype.kind in "ui":
            bounds = np.iinfo(acc_dtype)
            identity = bounds.max if np.isposinf(identity) else bounds.min
        elif acc_dtype.kind == "b" and np.isinf(identity):
            # The largest boolean is True and the smallest is False.
            identity = bool(np.isposinf(identity))
    return acc_dtype.type(identity)
``post_specs[i]`` is + ``None`` for a passthrough output, else ``(kept_axes, out_dtype, cast_needed)``. + """ + code: list[str | CODE_TOKEN] = [ + "def fused_elemwise_impl(*outer_inputs):", + CODE_TOKEN.INDENT, + "raw = _vectorized(", + CODE_TOKEN.INDENT, + "core_op_fn,", + "input_bc_patterns_enc,", + "output_bc_patterns_enc,", + "output_dtypes_enc,", + "inplace_pattern_enc,", + "True,", + "(),", + "outer_inputs,", + "core_output_shapes,", + "NO_SIZE,", + "indexed_inputs_enc,", + "indexed_outputs_enc,", + "reduce_outputs_enc,", + CODE_TOKEN.DEDENT, + ")", + ] + + def src_access(i): + return "raw" if nout == 1 else f"raw[{i}]" + + out_syms = [] + for i in range(nout): + sym = f"o{i}" + out_syms.append(sym) + src = src_access(i) + spec = post_specs[i] + if spec is None: + code.append(f"{sym} = {src}") + continue + kept_axes, out_dtype, cast_needed = spec + np_dtype = "bool_" if out_dtype == "bool" else out_dtype + if not kept_axes: + # Full reduction → 0-d array of the single accumulated value. + code.append(f"{sym} = np.array({src}.ravel()[0], dtype=np.{np_dtype})") + else: + shape_expr = create_tuple_string( + tuple(f"{src}.shape[{k}]" for k in kept_axes) + ) + expr = f"{src}.reshape({shape_expr})" + if cast_needed: + expr = f"{expr}.astype(np.{np_dtype})" + code.append(f"{sym} = {expr}") + + if nout == 1: + code.append(f"return {out_syms[0]}") + else: + code.append(f"return {create_tuple_string(tuple(out_syms))}") + code.append(CODE_TOKEN.DEDENT) + return build_source_code(code) + + +@register_funcify_and_cache_key(FusedElemwise) +def numba_funcify_FusedElemwise(op, node, **kwargs): """Generate fused Elemwise Numba code with indexed reads and updates. 
Reads indexed_inputs/indexed_outputs specs stored on the Op by the @@ -739,6 +853,7 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): n_indices = len(indexed_inputs) nin_elemwise = len(elemwise_node.inputs) nout = len(elemwise_node.outputs) + reduced_outputs = op.reduced_outputs or ((None,) * nout) inc_outputs = frozenset( out_idx @@ -747,14 +862,58 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): for out_idx in entry[0] if entry[2] == "inc" ) + reduced_idxs = frozenset(i for i, r in enumerate(reduced_outputs) if r is not None) + assert not (inc_outputs & reduced_idxs), ( + "An output cannot be both an indexed write and a reduction" + ) + + # accum_fns bake per-output in-place accumulation into store_core_outputs: + # indexed `inc` writes accumulate with +=; reductions use their scalar op + # (Add/Mul/Maximum/...). Both go through accumulate_into_slice so they are + # valid on 0-d (scalar core) slices. Disjoint output indices. + accum_fns: dict = {} + for out_idx in inc_outputs: + accum_fns[out_idx] = lambda out, inner: [f"{out} += {inner}"] + for i in reduced_idxs: + accum_fns[i] = lambda out, inner, _op=reduced_outputs[i][0]: ( + accumulate_into_slice(_op, out, inner) + ) core_op_fn = store_core_outputs( - scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + scalar_op_fn, nin=nin_elemwise, nout=nout, accum_fns=accum_fns ) input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) - output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) - output_dtypes = tuple(out.type.dtype for out in node.outputs) + + # Reduction outputs keep their reduced axes as bc=True (size-1, keepdims) and + # are accumulated in acc_dtype; post_specs squeezes + casts them back to the + # true CAReduce output below. 
+ output_bc_patterns_list = [] + output_dtypes_list = [] + reduce_identities = [] + post_specs: list = [] + has_reductions = bool(reduced_idxs) + for i, inner_out in enumerate(elemwise_node.outputs): + spec = reduced_outputs[i] + if spec is None: + output_bc_patterns_list.append(inner_out.type.broadcastable) + output_dtypes_list.append(node.outputs[i].type.dtype) + post_specs.append(None) + continue + _reduce_op, axes, identity, acc_dtype = spec + bc = list(inner_out.type.broadcastable) + for ax in axes: + bc[ax] = True + output_bc_patterns_list.append(tuple(bc)) + output_dtypes_list.append(str(np.dtype(acc_dtype))) + reduce_identities.append((i, _reduce_identity(identity, acc_dtype))) + kept_axes = tuple(d for d in range(len(bc)) if d not in axes) + out_dtype = node.outputs[i].type.dtype + cast_needed = np.dtype(acc_dtype) != np.dtype(out_dtype) + post_specs.append((kept_axes, out_dtype, cast_needed)) + + output_bc_patterns = tuple(output_bc_patterns_list) + output_dtypes = tuple(output_dtypes_list) inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) core_output_shapes = tuple(() for _ in range(nout)) @@ -768,52 +927,88 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): inplace_pattern_enc = encode_literals(inplace_pattern) indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable)) indexed_outputs_enc = encode_literals(indexed_outputs) + reduce_outputs_enc = encode_literals(tuple(reduce_identities)) + + def fused_elemwise_fn(*outer_inputs): + # Python-mode fallback (e.g. Numba's ``eval_python_only`` path, which + # runs the funcified function without JIT). Evaluate the inner fgraph + # faithfully, exactly like ``OpFromGraph.perform`` — so the fused op + # still produces valid results when executed directly in Python. The + # JIT path never reaches this body: the ``@overload`` below supplies the + # vectorized loop implementation. 
+ res = op.fn(*outer_inputs) + return res[0] if len(res) == 1 else tuple(res) + + if not has_reductions: + + @overload(fused_elemwise_fn, jit_options=_jit_options) + def ov_fused_elemwise_fn(*outer_inputs): + def impl(*outer_inputs): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + outer_inputs, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + reduce_outputs_enc, + ) - def indexed_elemwise_fn(*outer_inputs): - raise NotImplementedError( - "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." + return impl + else: + impl_src = _build_reduce_impl_src(nout, post_specs) + impl_fn = compile_numba_function_src( + impl_src, + "fused_elemwise_impl", + { + **globals(), + "core_op_fn": core_op_fn, + "input_bc_patterns_enc": input_bc_patterns_enc, + "output_bc_patterns_enc": output_bc_patterns_enc, + "output_dtypes_enc": output_dtypes_enc, + "inplace_pattern_enc": inplace_pattern_enc, + "core_output_shapes": core_output_shapes, + "indexed_inputs_enc": indexed_inputs_enc, + "indexed_outputs_enc": indexed_outputs_enc, + "reduce_outputs_enc": reduce_outputs_enc, + }, ) - @overload(indexed_elemwise_fn, jit_options=_jit_options) - def ov_indexed_elemwise_fn(*outer_inputs): - def impl(*outer_inputs): - return _vectorized( - core_op_fn, - input_bc_patterns_enc, - output_bc_patterns_enc, - output_dtypes_enc, - inplace_pattern_enc, - True, # allow_core_scalar - (), # constant_inputs - outer_inputs, - core_output_shapes, - NO_SIZE, - indexed_inputs_enc, - indexed_outputs_enc, - ) - - return impl + @overload(fused_elemwise_fn, jit_options=_jit_options) + def ov_fused_elemwise_fn(*outer_inputs): + return impl_fn - cache_version = 2 + cache_version = 3 if scalar_cache_key is None: key = None else: + reduced_key = tuple( + (type(r[0]).__name__, r[1], str(np.dtype(r[3]))) if r is not None else None + for r in 
reduced_outputs + ) key = str( ( type(op), - "IndexedElemwise", + "FusedElemwise", cache_version, inplace_pattern, input_bc_patterns, indexed_inputs, idx_broadcastable, indexed_outputs, + reduced_key, scalar_cache_key, ) ) key = sha256(key.encode()).hexdigest() - return indexed_elemwise_fn, key + return fused_elemwise_fn, key @register_funcify_and_cache_key(CAReduce) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 28ade0bf3a..2f7f8b3700 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -23,6 +23,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -490,6 +491,7 @@ def impl(core_shape, rng, size, *dist_params): else numba_ndarray.to_fixed_tuple(size, size_len), NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return rng, draws diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index dfd12ed1b3..23fe491186 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -3,7 +3,6 @@ import base64 import pickle from collections.abc import Callable, Sequence -from textwrap import indent from typing import Any import numba @@ -17,6 +16,7 @@ from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.string_codegen import CODE_TOKEN, build_source_code def encode_literals(literals: Sequence) -> str: @@ -24,7 +24,11 @@ def encode_literals(literals: Sequence) -> str: def store_core_outputs( - core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() + core_op_fn: Callable, + nin: int, + nout: int, + accum_fns: dict[int, Callable[[str, str], Sequence[str | CODE_TOKEN]]] + | None = None, ) -> Callable: """Create a 
Numba function that wraps a core function and stores its vectorized outputs. @@ -35,14 +39,20 @@ def store_core_outputs( def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) o0[...] = to0 # direct outputs - o1 += to1 # inc outputs (in-place add works for 0d and Nd) + o1 += to1 # accumulating outputs (reduce / indexed-inc) ... - ``inc_outputs`` lists output indices that use ``+=`` instead of ``=``. + ``accum_fns`` maps an output index to a callable ``(out_sym, inner_sym) -> + lines`` producing the in-place accumulation code for that output (e.g. + ``["o1 += to1"]`` for a sum reduction, or + ``["o1[...] = np.maximum(o1, to1)"]`` for a max reduction). + Outputs absent from ``accum_fns`` are stored with ``=``. Both reductions and indexed ``inc`` writes go through this mechanism. """ if getattr(core_op_fn, "handles_out", False): return core_op_fn + accum_fns = accum_fns or {} + inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] @@ -50,19 +60,22 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): inp_signature = ", ".join(inputs) out_signature = ", ".join(outputs) inner_out_signature = ", ".join(inner_outputs) - store_outputs = "\n".join( - f"{output} += {inner_output}" - if i in inc_outputs - else f"{output}[...] 
= {inner_output}" - for i, (output, inner_output) in enumerate( - zip(outputs, inner_outputs, strict=True) - ) - ) - func_src = f""" -def store_core_outputs({inp_signature}, {out_signature}): - {inner_out_signature} = core_op_fn({inp_signature}) -{indent(store_outputs, " " * 4)} -""" + + code: list[str | CODE_TOKEN] = [ + f"def store_core_outputs({inp_signature}, {out_signature}):", + CODE_TOKEN.INDENT, + f"{inner_out_signature} = core_op_fn({inp_signature})", + ] + for i, (output, inner_output) in enumerate( + zip(outputs, inner_outputs, strict=True) + ): + if i in accum_fns: + code.extend(accum_fns[i](output, inner_output)) + else: + code.append(f"{output}[...] = {inner_output}") + code.append(CODE_TOKEN.DEDENT) + + func_src = build_source_code(code) global_env = {"core_op_fn": core_op_fn} func = compile_numba_function_src( @@ -249,6 +262,7 @@ def _codegen_return_outputs( NO_INDEXED_INPUTS = encode_literals(((), ())) NO_INDEXED_OUTPUTS = encode_literals(()) +NO_REDUCE_OUTPUTS = encode_literals(()) NO_SIZE = None @@ -355,11 +369,17 @@ def make_outputs( input_types: tuple[Any, ...], output_core_shapes: tuple, update_outputs: dict | None = None, + reduce_identities: dict | None = None, ) -> tuple[list[ir.Value], list[types.Array]]: """Allocate output arrays for vectorized loop. ``update_outputs`` maps ``{output_idx: (array, array_type)}`` for outputs that reuse an indexed-write target buffer instead of being freshly allocated. + + ``reduce_identities`` maps ``{output_idx: identity_value}`` for reduction + outputs. Such outputs are freshly allocated (size 1 on the reduced axes, via + their ``bc=True`` pattern) and pre-filled with the reduction identity so the + accumulating store in the loop reduces into them correctly. 
""" output_arrays = [] output_arry_types = [] @@ -389,6 +409,22 @@ def make_outputs( ] shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) + if reduce_identities is not None and i in reduce_identities: + # Pre-fill the freshly allocated (C-contiguous) buffer with the + # reduction identity. A flat scan over every element is valid + # regardless of which axes are reduced, and seeds each kept-axis + # accumulator cell (size-1 reduced axes included). + nitems = ir.IntType(64)(1) + for dim_len in shape: + nitems = builder.mul(nitems, dim_len) + ident = ctx.get_constant(dtype, reduce_identities[i]) + # bool is an i1 value but stored as i8 in arrays; widen to the + # buffer's element type so the store types match. + elem_ty = array.data.type.pointee + if ident.type != elem_ty: + ident = builder.zext(ident, elem_ty) + with cgutils.for_range(builder, nitems) as loop: + builder.store(ident, builder.gep(array.data, [loop.index])) output_arrays.append(array) # If there is no inplace operation, we know that all output arrays @@ -423,8 +459,6 @@ def make_loop_call( ): safe = (False, False) - n_outputs = len(outputs) - # TODO I think this is better than the noalias attribute # for the input, but self_ref isn't supported in a released # llvmlite version yet @@ -449,21 +483,13 @@ def _wrap_negative_index(idx_val, dim_size, signed): wrapped = builder.add(idx_val, dim_size) return builder.select(is_neg, wrapped, idx_val) - # Setup loops and initialize accumulators for outputs - # This part corresponds to opening the loops + # Open one loop per iteration dimension. Reduction outputs need no special + # setup here: they carry ``bc=True`` on the reduced axes, so the write_idx + # logic below points every iteration over a reduced axis at memory index 0 + # (the same cell), and the accumulating store reduces into it. 
loop_stack = [] loops = [] - output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs - for dim, length in enumerate(iter_shape): - # Find outputs that only have accumulations left - for out in range(n_outputs): - if output_accumulator[out][0] is not None: - continue - if all(output_bc[out][dim:]): - value = outputs[out][0].type.pointee(0) - accu = cgutils.alloca_once_value(builder, value) - output_accumulator[out] = (accu, dim) - + for length in iter_shape: loop = cgutils.for_range(builder, length) loop_stack.append(loop) loops.append(loop.__enter__()) @@ -736,6 +762,7 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ): """Vectorized intrinsic with optional indirect indexing for reads and writes. @@ -751,7 +778,12 @@ def _vectorized( ``((out_0, out_1), mode)`` means that index updates outputs out_0 and out_1 with *mode* ``"set"`` or ``"inc"``. - For non-indexed calls, both are ``()``. + ``reduce_outputs`` lists ``(output_idx, identity)`` pairs for reduction + outputs. Such an output carries ``bc=True`` on its reduced axes; the buffer + is allocated size 1 there, pre-filled with ``identity``, and the per-iteration + store (baked into ``core_func`` via ``store_core_outputs``) accumulates into it. + + For non-indexed/non-reducing calls, these are ``()``. 
""" arg_types = [ core_func, @@ -766,12 +798,14 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ] input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns") output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns") output_dtypes = _decode_literal(output_dtypes, "output_dtypes") inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern") + reduce_identities = dict(_decode_literal(reduce_outputs, "reduce_outputs")) indexed_inputs, idx_broadcastable = _decode_literal( indexed_inputs, "indexed_inputs" ) @@ -919,6 +953,7 @@ def codegen(ctx, builder, sig, args): size, _, _, + _, ] = args constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) @@ -1053,6 +1088,7 @@ def codegen(ctx, builder, sig, args): source_input_types, output_core_shapes, update_outputs=update_outputs_dict, + reduce_identities=reduce_identities, ) core_signature = typingctx.resolve_function_type( diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/fused_elemwise.py similarity index 68% rename from pytensor/tensor/rewriting/indexed_elemwise.py rename to pytensor/tensor/rewriting/fused_elemwise.py index b3ff3828b3..beafde64e2 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/fused_elemwise.py @@ -1,9 +1,9 @@ -"""Fuse indexed reads and updates into Elemwise iteration loops. +"""Fuse indexed reads/writes and reductions into Elemwise iteration loops. -Introduces ``IndexedElemwise``, an ``OpFromGraph`` that wraps -``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` subgraphs -so the Numba backend can generate a single loop with indirect indexing, -eliminating materialised intermediate arrays. 
+Introduces ``FusedElemwise``, an ``OpFromGraph`` that wraps +``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` / ``CAReduce`` +subgraphs so the Numba backend can generate a single loop with indirect +indexing and inline accumulation, eliminating materialised intermediate arrays. """ from pytensor.compile import optdb @@ -14,8 +14,20 @@ from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.utils import InconsistencyError from pytensor.printing import op_debug_information -from pytensor.scalar.basic import Composite -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.scalar.basic import ( + AND, + OR, + XOR, + Add, + Composite, + Maximum, + Minimum, + Mul, +) +from pytensor.scalar.basic import ( + identity as scalar_identity, +) +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( @@ -27,6 +39,12 @@ from pytensor.tensor.variable import TensorVariable +# CAReduce scalar ops whose reduction the Numba backend can fuse into the loop +# (those for which the codegen has an in-place accumulation: see +# ``accumulate_into_slice`` in the numba elemwise dispatch). +_REDUCE_SCALAR_OPS = (Add, Mul, Maximum, Minimum, AND, OR, XOR) + + def _view_root(view_i, var): """Follow the destroy-handler view chain to the underlying buffer. @@ -102,7 +120,7 @@ def undo_take_dimshuffle_for_fusion(fgraph, node): The ``local_replace_AdvancedSubtensor`` specialize rewrite converts ``x[:, idx]`` into ``x.T[idx].T`` (axis-swap + AdvancedSubtensor1 + axis-swap). This rewrite undoes that when the result feeds a single - Elemwise, so ``FuseIndexedElemwise`` can absorb the indexing directly + Elemwise, so ``FuseElemwise`` can absorb the indexing directly on the correct axis. 
@node_rewriter([CAReduce])
def wrap_reduced_gather_in_elemwise(fgraph, node):
    """Insert an identity Elemwise between an indexed read and its reduction.

    The fusion pass anchors on Elemwise nodes, so a reduction applied directly
    to an indexed read (e.g. ``sum(x[idx])``) has nothing to anchor on and the
    gather would stay materialized. Wrapping the gathered variable in
    ``Elemwise(identity)`` supplies that anchor.

    The rewrite only fires when all of the following hold:

    - the reduction's scalar op is one of ``_REDUCE_SCALAR_OPS``;
    - the reduced variable has an owner and exactly one client;
    - the reduced variable is produced by an ``AdvancedSubtensor``, or by a
      ``Reshape`` whose first input is recognised by
      ``_unwrap_axis_swapped_subtensor1``, or is itself recognised by that
      helper.

    Returns a one-element list with the replacement output, or ``None`` when
    the rewrite does not apply.
    """
    if not isinstance(node.op.scalar_op, _REDUCE_SCALAR_OPS):
        return None

    (reduced_var,) = node.inputs
    producer = reduced_var.owner
    if producer is None:
        return None
    if len(fgraph.clients[reduced_var]) != 1:
        return None

    producer_op = producer.op
    if not isinstance(producer_op, AdvancedSubtensor):
        # For a Reshape, inspect what is being reshaped (the flattened
        # ND-index form); otherwise inspect the reduced variable itself
        # (bare AdvancedSubtensor1 or its axis-swapped DimShuffle form).
        if isinstance(producer_op, Reshape):
            candidate = producer.inputs[0]
        else:
            candidate = reduced_var
        if _unwrap_axis_swapped_subtensor1(fgraph, candidate) is None:
            return None

    identity_wrapped = Elemwise(scalar_identity)(reduced_var)
    return [node.op(identity_wrapped)]
Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) and @@ -250,28 +311,66 @@ class IndexedElemwise(OpFromGraph): Examples:: tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")] + + reduced_outputs : tuple of ((scalar_op, axes, identity, acc_dtype) | None) + One entry per (inner Elemwise / outer op) output position. + ``None`` if the output is not a reduction. + Otherwise the output is the result of reducing the inner Elemwise output + with ``CAReduce(scalar_op)`` over ``axes``: + + - ``scalar_op``: the commutative/associative binary scalar op of the + reduction (e.g. ``add`` for sum, ``mul`` for prod, ``maximum`` for max). + - ``axes``: tuple of reduced batch axes (in the inner Elemwise's dim space). + - ``identity``: the reduction identity, used to seed the accumulator buffer. + - ``acc_dtype``: dtype the accumulation is carried out in (the Numba loop + accumulates in this dtype, then casts to the output dtype). + + The inner fgraph still holds a faithful ``CAReduce(Elemwise(...))`` so + non-Numba backends evaluate correctly via ``OpFromGraph.perform``; only + the Numba backend reads this spec to fuse the reduction into the loop. + + Examples:: + + sum(exp(x)) → reduced_outputs=[(add, (0,), 0.0, "float64")] """ - def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): + def __init__( + self, + *args, + indexed_inputs=(), + indexed_outputs=(), + reduced_outputs=(), + **kwargs, + ): self.indexed_inputs = indexed_inputs self.indexed_outputs = indexed_outputs + self.reduced_outputs = reduced_outputs # A read buffer can occupy multiple input slots (e.g. read through # several indices); construct_nominal_fgraph dedupes those to one # nominal, leaving the extra slots as unused NominalVariables, which is # safe because reads don't destroy. 
Write targets always get their own - # fresh inner input (see FuseIndexedElemwise) so a destroyed buffer is + # fresh inner input (see FuseElemwise) so a destroyed buffer is # never deduped onto a read source. super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) def __str__(self): + elemwise_str = "Elemwise" for node in self.fgraph.apply_nodes: if isinstance(node.op, Elemwise): - return f"IndexedElemwise{{{node.op!s}}}" - return "IndexedElemwise" - - -@op_debug_information.register(IndexedElemwise) -def _op_debug_information_IndexedElemwise(op, node): + elemwise_str = str(node.op) + break + reductions = [ + f"{type(spec[0]).__name__.lower()}@{spec[1]}" + for spec in self.reduced_outputs + if spec is not None + ] + if reductions: + return f"FusedElemwise{{{elemwise_str}, reduce[{', '.join(reductions)}]}}" + return f"FusedElemwise{{{elemwise_str}}}" + + +@op_debug_information.register(FusedElemwise) +def _op_debug_information_FusedElemwise(op, node): info = {} n_idx = len(op.indexed_inputs) @@ -317,10 +416,20 @@ def _op_debug_information_IndexedElemwise(op, node): f"indexed {mode} ({buf_label}, {idx_label})" ) + # Annotate reduced outputs + for out_idx, spec in enumerate(op.reduced_outputs): + if spec is None or out_idx >= len(node.outputs): + continue + scalar_op, axes, _identity, acc_dtype = spec + info[node.outputs[out_idx]] = ( + f"reduced[{type(scalar_op).__name__.lower()}] " + f"over axes {axes} acc={acc_dtype}" + ) + return {node: info} -class FuseIndexedElemwise(GraphRewriter): +class FuseElemwise(GraphRewriter): """Fuse indexed reads and indexed updates into Elemwise loops. Absorbs single-client ``AdvancedSubtensor1`` on inputs (indexed reads) @@ -377,10 +486,11 @@ def _extract_idx_axis_pairs(node, *, write=False): @staticmethod def _duplicate_multi_client_outputs(node, multi_client_outs): - """Add duplicate outputs for Elemwise results that have both write and non-write consumers. 
+ """Add duplicate outputs for Elemwise results whose consumers must be split. - Returns ``(new_node, dup_map)`` where *dup_map* maps each original - output index to its duplicate position. + Used when an output is both written/reduced and consumed directly, or is + consumed by several reductions. Returns ``(new_node, dup_map)`` where + *dup_map* maps each original output index to its duplicate position. """ scalar_op = node.op.scalar_op if isinstance(scalar_op, Composite): @@ -399,7 +509,11 @@ def _duplicate_multi_client_outputs(node, multi_client_outs): s_outputs.append(s_outputs[out_idx]) new_scalar_op = Composite(s_inputs, s_outputs) - new_node = Elemwise(new_scalar_op).make_node(*node.inputs) + # Duplicates are appended after the original outputs, so the node's + # inplace pattern carries over unchanged (duplicates get no entries). + new_node = Elemwise(new_scalar_op, dict(node.op.inplace_pattern)).make_node( + *node.inputs + ) return new_node, dup_map @staticmethod @@ -574,7 +688,49 @@ def apply(self, fgraph): idx_groups[idx_axis_pair][1].append(out_idx) write_targets[out_idx] = client_node - if not idx_groups: + # Find reductions to fuse: an Elemwise output whose sole client is an + # eligible CAReduce. Outputs that are write targets are excluded (an + # output can't be both an indexed write and a reduction). Outputs + # with extra (non-reduce) clients are handled by duplication below. 
+ reduced_outputs = {} # out_idx -> car_node + extra_reduction = ( + None # (out_idx, car_node) when reductions share an output + ) + for out_idx, out in enumerate(node.outputs): + if out_idx in write_targets: + continue + car_clients = [ + c + for c, _ in fgraph.clients[out] + if isinstance(c.op, CAReduce) + and isinstance(c.op.scalar_op, _REDUCE_SCALAR_OPS) + ] + if not car_clients: + continue + if len(car_clients) > 1: + extra_reduction = (out_idx, car_clients[1]) + break + reduced_outputs[out_idx] = car_clients[0] + + # An output consumed by several eligible reductions: peel one reduction + # per pass onto a duplicate output, until each reduction has its own copy + if extra_reduction is not None: + out_idx, car_node = extra_reduction + new_node, dup_map = self._duplicate_multi_client_outputs( + node, {out_idx} + ) + replacements = list( + zip(node.outputs, new_node.outputs[: len(node.outputs)]) + ) + new_reduced = car_node.op(new_node.outputs[dup_map[out_idx]]) + replacements.append((car_node.outputs[0], new_reduced)) + fgraph.replace_all( + replacements, reason="fuse_elemwise_split_multi_reductions" + ) + worklist.append(new_node) + continue + + if not idx_groups and not reduced_outputs: continue if must_transpose_write_axes: @@ -584,40 +740,53 @@ def apply(self, fgraph): assert replacements fgraph.replace_all( replacements, - reason="fuse_indexed_elemwise_move_write_axes", + reason="fuse_elemwise_move_write_axes", ) worklist.append(node) continue - indexed_reads = {i for reads, _ in idx_groups.values() for i in reads} + # If a reduced output also feeds non-reduce consumers, duplicate it via + # Composite so the reduction consumes the duplicate while the original + # stays materialised for the other consumers. We still fuse the + # reduction loop, even if the full output must also be produced. 
+ # This runs before the strip-inplace pass below, so any reduced output + # reaching it is sole-client (truly de-materialized) and an inplace on a + # materialized original survives the fusion. + def _has_non_reduce_clients(out_idx): + car_node = reduced_outputs[out_idx] + return any( + c is not car_node for c, _ in fgraph.clients[node.outputs[out_idx]] + ) - # If any inplace targets an indexed-read input, - # strip and re-run inplace with those inputs protected - if any( - inp_idx in indexed_reads for inp_idx in node.op.inplace_pattern.values() - ): - stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) - fgraph.replace_all( - zip(node.outputs, stripped_node.outputs), - reason="fuse_indexed_elemwise_strip_inplace", + if reduce_and_direct_use_outs := { + out_idx + for out_idx in reduced_outputs + if _has_non_reduce_clients(out_idx) + }: + new_node, dup_map = self._duplicate_multi_client_outputs( + node, reduce_and_direct_use_outs ) - protected = frozenset(stripped_node.inputs[i] for i in indexed_reads) - # try_inplace_on_node does its own fgraph.replace_all internally, - # so the returned node is already in the fgraph - new_inplace_node = InplaceElemwiseOptimizer().try_inplace_on_node( - fgraph, - stripped_node, - reason="fuse_indexed_elemwise_inplace_read_buffers", - extra_protected_inputs=protected, + replacements = list( + zip(node.outputs, new_node.outputs[: len(node.outputs)]) ) - worklist.append(new_inplace_node) + for out_idx, dup_idx in dup_map.items(): + car_node = reduced_outputs[out_idx] + new_reduced = car_node.op(new_node.outputs[dup_idx]) + replacements.append((car_node.outputs[0], new_reduced)) + fgraph.replace_all( + replacements, + reason="fuse_reduce_and_direct_outputs", + ) + worklist.append(new_node) continue # If any indexed-write output also has other consumers, # duplicate it via Composite so the write replaces the duplicate # while the original stays available for non-write consumers. 
# We still avoid one extra write loop, - # even if we can't skip the output materialization altogether + # even if we can't skip the output materialization altogether. + # Like the reduce duplication above, this runs before the strip-inplace + # pass, so an inplace on the materialized original survives the fusion. def _has_non_write_clients(out_idx): update = write_targets[out_idx] for c, _ in fgraph.clients[node.outputs[out_idx]]: @@ -653,17 +822,67 @@ def _has_non_write_clients(out_idx): replacements.append((update_node.outputs[0], new_update_out)) fgraph.replace_all( replacements, - reason="fuse_indexed_elemwise_write_and_direct_outputs", + reason="fuse_elemwise_write_and_direct_outputs", ) worklist.append(new_node) continue + indexed_reads = {i for reads, _ in idx_groups.values() for i in reads} + + # If any inplace targets an indexed-read input, or claims an output that is + # not materialized as a plain output — a fused reduction (reusing the input + # buffer as the reduce accumulator would skip the identity init and + # accumulate on top of stale input values) or an indexed write (the loop + # writes to the write buffer instead, so the input destruction would happen + # only in the Python-mode fallback, undeclared by the outer destroy map) — + # strip and re-run inplace with those inputs protected and outputs excluded. + # The duplications above ran first, so such outputs here are sole-client. 
+ if any( + inp_idx in indexed_reads for inp_idx in node.op.inplace_pattern.values() + ) or any( + out_idx in reduced_outputs or out_idx in write_targets + for out_idx in node.op.inplace_pattern + ): + stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) + fgraph.replace_all( + zip(node.outputs, stripped_node.outputs), + reason="fuse_elemwise_strip_inplace", + ) + optimizer = InplaceElemwiseOptimizer() + protected = optimizer._get_protected_inputs(fgraph) + protected.update(stripped_node.inputs[i] for i in indexed_reads) + # Candidates are plain-Elemwise (output, input) pairs; exclude + # outputs the fusion is about to consume as reductions or indexed writes + candidate_pairs = [ + pair + for pair in optimizer.filter_candidate_pairs( + fgraph, stripped_node, protected + ) + if pair[0][0] not in reduced_outputs + and pair[0][0] not in write_targets + ] + # try_inplace_on_node does its own fgraph.replace_all internally, + # so the returned node is already in the fgraph + new_inplace_node = optimizer.try_inplace_on_node( + fgraph, + stripped_node, + candidate_pairs=candidate_pairs, + reason="fuse_elemwise_inplace_read_buffers", + ) + worklist.append(new_inplace_node) + continue + idx_vars = [idx for idx, _axis in idx_groups] + # The strip-inplace pass above guarantees that reduced and indexed-write + # outputs carry no inplace + assert not any( + out_idx in reduced_outputs or out_idx in write_targets + for out_idx in node.op.inplace_pattern + ) fgraph_destroy_map = { out_idx: [inp_idx] for out_idx, inp_idx in node.op.inplace_pattern.items() - if out_idx not in write_targets } # Fgraph inputs: substitute indexed sources back to their @@ -715,6 +934,41 @@ def _has_non_write_clients(out_idx): fgraph_outputs[out_idx] = write_out fgraph_destroy_map[out_idx] = [target_pos] + # Inner fgraph reduced outputs: wrap the Elemwise output in the real + # CAReduce so non-Numba backends compute it faithfully via perform. 
+ # reduced_spec carries the (scalar_op, axes, identity, acc_dtype) the + # Numba backend reads to fuse the reduction into the loop instead. + reduced_spec_by_idx = {} + for out_idx, car_node in sorted(reduced_outputs.items()): + car_op = car_node.op + ndim = node.outputs[out_idx].type.ndim + axes = ( + tuple(range(ndim)) + if car_op.axis is None + else tuple(sorted(car_op.axis)) + ) + acc_dtype = ( + car_op.acc_dtype + if car_op.acc_dtype is not None + else car_node.outputs[0].type.dtype + ) + reduced_spec_by_idx[out_idx] = ( + car_op.scalar_op, + axes, + car_op.scalar_op.identity, + acc_dtype, + ) + fgraph_outputs[out_idx] = car_op(node.outputs[out_idx]) + + reduced_spec = ( + tuple( + reduced_spec_by_idx.get(out_idx) + for out_idx in range(len(node.outputs)) + ) + if reduced_outputs + else () + ) + # indexed_inputs_spec: ((read_positions, axis) | None, ...) # indexed_outputs_spec: ((write_positions, axis, "inc"|"set") | None, ...) indexed_inputs_spec = tuple( @@ -737,12 +991,13 @@ def _has_non_write_clients(out_idx): val = outer_write_targets.get(i, inp) outer_inputs.append(val.copy() if i in copy_positions else val) - new_outs = IndexedElemwise( + new_outs = FusedElemwise( fgraph_inputs, fgraph_outputs, destroy_map=fgraph_destroy_map, indexed_inputs=indexed_inputs_spec, indexed_outputs=indexed_outputs_spec, + reduced_outputs=reduced_spec, )(*outer_inputs, return_list=True) replacements = [] @@ -751,6 +1006,10 @@ def _has_non_write_clients(out_idx): replacements.append( (write_targets[out_idx].outputs[0], new_outs[out_idx]) ) + elif out_idx in reduced_outputs: + replacements.append( + (reduced_outputs[out_idx].outputs[0], new_outs[out_idx]) + ) else: replacements.append((node.outputs[out_idx], new_outs[out_idx])) @@ -766,9 +1025,9 @@ def _has_non_write_clients(out_idx): continue -indexed_elemwise_optdb.register( - "fuse_indexed_elemwise", - FuseIndexedElemwise(), +fused_elemwise_optdb.register( + "fuse_elemwise", + FuseElemwise(), "numba", position=1, ) diff 
--git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index ecb3435030..d2c41aeef3 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -1,4 +1,4 @@ -import pytensor.tensor.rewriting.indexed_elemwise # noqa: F401 +import pytensor.tensor.rewriting.fused_elemwise # noqa: F401 from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter diff --git a/tests/benchmarks/test_gather_fusion.py b/tests/benchmarks/test_gather_fusion.py index 0c634e354d..46d9446343 100644 --- a/tests/benchmarks/test_gather_fusion.py +++ b/tests/benchmarks/test_gather_fusion.py @@ -12,7 +12,7 @@ import pytensor.tensor as pt from pytensor import config from pytensor.compile.mode import get_mode -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 @@ -50,11 +50,11 @@ def read_benchmark_setup(request): ) assert any( - isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() - ), "IndexedElemwise not found in fused graph" + isinstance(n.op, FusedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "FusedElemwise not found in fused graph" assert not any( - isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() - ), "IndexedElemwise found in unfused graph" + isinstance(n.op, FusedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "FusedElemwise found in unfused graph" rng = np.random.default_rng(1) vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] @@ -117,15 +117,15 @@ def write_benchmark_setup(request): ) assert any( - isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() - ), "IndexedElemwise not found in fused graph" + isinstance(n.op, FusedElemwise) for n in fn_fused.maker.fgraph.toposort() + 
), "FusedElemwise not found in fused graph" assert not any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn_fused.maker.fgraph.toposort() ), "AdvancedIncSubtensor1 still present in fused graph" assert not any( - isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() - ), "IndexedElemwise found in unfused graph" + isinstance(n.op, FusedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "FusedElemwise found in unfused graph" rng = np.random.default_rng(1) vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_fused_elemwise.py similarity index 61% rename from tests/link/numba/test_indexed_elemwise.py rename to tests/link/numba/test_fused_elemwise.py index 22ffbe5927..4bf1e83a8d 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_fused_elemwise.py @@ -1,11 +1,12 @@ -"""Tests for IndexedElemwise fusion (indexed reads and updates in Elemwise loops).""" +"""Tests for FusedElemwise (indexed reads/writes and reductions fused into Elemwise loops).""" import numpy as np import pytest import pytensor.tensor as pt from pytensor import Mode, function, get_mode -from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.elemwise import CAReduce, Elemwise +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, AdvancedSubtensor, @@ -26,9 +27,9 @@ def fused_and_unfused(inputs, output): def assert_fused(fn): - """Assert that the compiled graph contains an IndexedElemwise node.""" - assert any(isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort()), ( - "IndexedElemwise not found in fused graph" + """Assert that the compiled graph contains a FusedElemwise node.""" + assert any(isinstance(n.op, FusedElemwise) for n in fn.maker.fgraph.toposort()), ( + "FusedElemwise not found in fused graph" ) 
@@ -251,7 +252,7 @@ def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): fn, fn_u = fused_and_unfused([x, y, target], out) # Write not fused — the Elemwise loop dim is the non-indexed axis, # not the indexed axis. Read fusion may still create an - # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. + # FusedElemwise, but the AdvancedIncSubtensor1 must remain outside. assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() ) @@ -360,6 +361,71 @@ def test_inc_subtensor(self): fn(xv, yv, tv.copy()), fn_u(xv, yv, tv.copy()), rtol=1e-10 ) + def test_write_of_inplace_elemwise(self): + """Inplace must not survive on a write-target output. + + The loop writes the elementwise result to the write buffer, never to the + inplaced input, so a surviving inner inplace entry would destroy the dot + intermediate undeclared in the Python-mode fallback. + """ + rng = np.random.default_rng(9) + idx = rng.integers(8, size=30).astype(np.int64) + x = pt.matrix("x", shape=(30, 4)) + y = pt.matrix("y", shape=(4, 4)) + z = pt.matrix("z", shape=(30, 4)) + t = pt.matrix("t", shape=(8, 4)) + out = t[idx].inc((x @ y) * z) + fn, fn_u = fused_and_unfused([x, y, z, t], out) + assert_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert not any( + n.op.inplace_pattern + for n in node.op.fgraph.toposort() + if isinstance(n.op, Elemwise) + ) + xv, yv, zv, tv = ( + rng.normal(size=(30, 4)), + rng.normal(size=(4, 4)), + rng.normal(size=(30, 4)), + rng.normal(size=(8, 4)), + ) + np.testing.assert_allclose( + fn(xv, yv, zv, tv.copy()), fn_u(xv, yv, zv, tv.copy()), rtol=1e-10 + ) + + def test_write_with_direct_use_keeps_inplace(self): + """Inplace survives on an output that is both written and used directly. 
+ + The write-and-direct duplication keeps the original output materialized + (the write consumes a duplicate), so the inplace claimed on the dot + intermediate stays valid and must not be stripped. + """ + rng = np.random.default_rng(10) + idx = rng.integers(8, size=30).astype(np.int64) + x = pt.matrix("x", shape=(30, 4)) + y = pt.matrix("y", shape=(4, 4)) + z = pt.matrix("z", shape=(30, 4)) + t = pt.matrix("t", shape=(8, 4)) + w = (x @ y) * z + fn, fn_u = fused_and_unfused([x, y, z, t], [t[idx].inc(w), w]) + assert_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + # Two destroy entries: the write buffer, and the dot intermediate kept + # inplace by the materialized output + assert len(node.op.destroy_map) == 2 + xv, yv, zv, tv = ( + rng.normal(size=(30, 4)), + rng.normal(size=(4, 4)), + rng.normal(size=(30, 4)), + rng.normal(size=(8, 4)), + ) + for res, res_u in zip(fn(xv, yv, zv, tv.copy()), fn_u(xv, yv, zv, tv.copy())): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + def test_set_subtensor(self): rng = np.random.default_rng(42) idx = rng.integers(85, size=919).astype(np.int64) @@ -466,7 +532,7 @@ def test_write_target_aliases_read_source(self, read_idx, write_idx): write_idx = np.array(write_idx, dtype=np.int64) out = b[write_idx].set(b[read_idx] * 2.0) fn, fn_u = fused_and_unfused([x], out) - # The read fuses into an IndexedElemwise; the aliasing write stays external. + # The read fuses into a FusedElemwise; the aliasing write stays external. assert_fused(fn) assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() @@ -500,7 +566,7 @@ def test_non_inplace_aliasing_write_preserves_input(self): # The inner write must destroy exactly the input the op's destroy_map names. 
[node] = [ - n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) ] [(_out_idx, [destroyed_pos])] = node.op.destroy_map.items() [inner_write] = [ @@ -630,3 +696,325 @@ def test_loop_shape_regression(self): res = f(beta=test_beta, mask=test_mask) ref_res = ref_f(beta=test_beta, mask=test_mask) np.testing.assert_allclose(res, ref_res, strict=True) + + +def assert_reduce_fused(fn): + """Assert the graph contains a FusedElemwise with a fused reduction.""" + nodes = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise)] + assert nodes, "FusedElemwise not found in fused graph" + assert any(any(r is not None for r in n.op.reduced_outputs) for n in nodes), ( + "No fused reduction (reduced_outputs) found" + ) + + +class TestReductionFusion: + """Reductions (CAReduce) fused into the Elemwise loop, no indexing.""" + + @pytest.mark.parametrize("axis", [None, 0, 1, 2, (0, 2), (0, 1), (1, 2)], ids=str) + def test_sum_axes(self, axis): + rng = np.random.default_rng(0) + x = pt.tensor3("x") + y = pt.tensor3("y") + out = pt.sum(pt.exp(x) + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(3, 4, 5)), rng.normal(size=(3, 4, 5)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_prod(self, axis): + rng = np.random.default_rng(1) + x = pt.matrix("x") + out = pt.prod(pt.exp(x * 0.1), axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(4, 5)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-8) + + @pytest.mark.parametrize("reduce_fn", [pt.max, pt.min], ids=["max", "min"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_max_min(self, reduce_fn, axis): + rng = np.random.default_rng(2) + x = pt.matrix("x") + y = pt.matrix("y") + out = reduce_fn(x + y, 
axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(6, 7)), rng.normal(size=(6, 7)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("reduce_fn", [pt.all, pt.any], ids=["all", "any"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_all_any(self, reduce_fn, axis): + rng = np.random.default_rng(3) + x = pt.matrix("x", dtype="bool") + y = pt.matrix("y", dtype="bool") + out = reduce_fn(x & y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv = rng.integers(0, 2, size=(4, 5)).astype(bool) + yv = rng.integers(0, 2, size=(4, 5)).astype(bool) + np.testing.assert_array_equal(fn(xv, yv), fn_u(xv, yv)) + + @pytest.mark.parametrize("dtype", ["int8", "int32", "uint8"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_acc_dtype_widening(self, dtype, axis): + """Sum of small int dtype accumulates in a wider acc_dtype.""" + rng = np.random.default_rng(4) + x = pt.matrix("x", dtype=dtype) + out = pt.sum(x + x, axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + info = np.iinfo(dtype) + xv = rng.integers(0, min(info.max // 2, 50), size=(40, 40)).astype(dtype) + np.testing.assert_array_equal(fn(xv), fn_u(xv)) + # Result must be the wide acc dtype, not overflow the input dtype + assert fn(xv).dtype == fn_u(xv).dtype + + def test_scalar_and_1d(self): + rng = np.random.default_rng(5) + x = pt.vector("x") + out = pt.sum(pt.exp(x)) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(17,)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_non_c_contiguous_input(self): + """Reduction over a transposed (non-C-contiguous) intermediate.""" + rng = np.random.default_rng(6) + x = pt.matrix("x") + out = pt.sum((x + 1.0).T, axis=0) + fn, fn_u = fused_and_unfused([x], out) + xv = rng.normal(size=(8, 5)) + 
np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_reduce_of_inplace_elemwise(self): + """Inplace must not survive on an output fused as a reduction. + + The dot output is a destroyable intermediate the inplace pass claims for the + Mul before fusion. If the fused op kept that destroy entry, the reduce + accumulator would alias the input buffer and skip the identity init, + folding stale input values into the result. + """ + rng = np.random.default_rng(7) + x, y, z = pt.matrix("x"), pt.matrix("y"), pt.matrix("z") + out = ((x @ y) * z).sum() + fn, fn_u = fused_and_unfused([x, y, z], out) + assert_reduce_fused(fn) + for node in fn.maker.fgraph.toposort(): + if isinstance(node.op, FusedElemwise): + assert not any( + out_idx in node.op.destroy_map + for out_idx, r in enumerate(node.op.reduced_outputs) + if r is not None + ) + xv, yv, zv = (rng.normal(size=(4, 4)) for _ in range(3)) + np.testing.assert_allclose(fn(xv, yv, zv), fn_u(xv, yv, zv), rtol=1e-10) + + def test_reduce_with_direct_use_keeps_inplace(self): + """Inplace survives on an output that is both reduced and used directly. + + The reduce-and-direct duplication keeps the original output materialized + (the CAReduce consumes a duplicate), so the inplace claimed on the dot + intermediate stays valid and must not be stripped. 
+ """ + rng = np.random.default_rng(8) + x, y, z = pt.matrix("x"), pt.matrix("y"), pt.matrix("z") + w = (x @ y) * z + fn, fn_u = fused_and_unfused([x, y, z], [w.sum(), w]) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + reduced_idxs = { + i for i, r in enumerate(node.op.reduced_outputs) if r is not None + } + # The materialized output keeps its destroy entry; the reduced one has none + assert node.op.destroy_map + assert not set(node.op.destroy_map) & reduced_idxs + xv, yv, zv = (rng.normal(size=(4, 4)) for _ in range(3)) + for res, res_u in zip(fn(xv, yv, zv), fn_u(xv, yv, zv)): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + + +class TestReductionWithIndexing: + """Reductions composed with indexed reads in a single fused loop.""" + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_gather(self, axis): + rng = np.random.default_rng(10) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + out = pt.sum(x[idx] + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, yv, idxv), fn_u(xv, yv, idxv), rtol=1e-10) + + def test_max_gather(self): + rng = np.random.default_rng(11) + x = pt.matrix("x") + idx = pt.lvector("idx") + out = pt.max(pt.exp(x[idx]), axis=0) + fn, fn_u = fused_and_unfused([x, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + + @pytest.mark.parametrize( + "make_out", + [ + lambda x, idx: pt.sum(x[idx]), + lambda x, idx: pt.max(x[idx], axis=0), + lambda x, idx: pt.sum(x[:, idx]), + ], + ids=["sum_all", "max_axis0", "sum_axis1_gather"], + ) + def test_reduce_of_bare_gather(self, make_out): + """Reduction of an indexed read with no elemwise in 
between. + + wrap_reduced_gather_in_elemwise inserts an identity Elemwise so the + gather and the reduction still collapse into one fused loop. + """ + rng = np.random.default_rng(13) + x = pt.matrix("x") + idx = pt.lvector("idx") + fn, fn_u = fused_and_unfused([x, idx], make_out(x, idx)) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert any(spec is not None for spec in node.op.indexed_inputs) + xv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 5, size=4) + np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + + def test_reduce_of_bare_nd_gather(self): + """Reduction of an ND-index read (Reshape-flattened form).""" + rng = np.random.default_rng(14) + x = pt.matrix("x") + mat_idx = pt.lmatrix("mat_idx") + fn, fn_u = fused_and_unfused([x, mat_idx], pt.sum(x[mat_idx])) + assert_reduce_fused(fn) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert any(spec is not None for spec in node.op.indexed_inputs) + xv = rng.normal(size=(8, 5)) + mv = rng.integers(0, 8, size=(3, 2)) + np.testing.assert_allclose(fn(xv, mv), fn_u(xv, mv), rtol=1e-10) + + def test_gather_scatter_and_reduce_mix(self): + """Gather + elemwise + scatter + reduce all fuse into one loop. + + Sibling fusion merges the two elemwise expressions into one multi-output + Composite; FuseElemwise then absorbs the indexed read, the indexed + write, and the reduction into a single FusedElemwise. 
+ """ + rng = np.random.default_rng(12) + x = pt.matrix("x") + y = pt.matrix("y") + t = pt.matrix("t") + idx = pt.lvector("idx") + scattered = t[idx].inc(x[idx] * y) + reduced = pt.sum(x[idx] + y) + fn, fn_u = fused_and_unfused([x, y, t, idx], [scattered, reduced]) + nodes = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert len(nodes) == 1 + [node] = nodes + assert any(spec is not None for spec in node.op.indexed_inputs) + assert any(spec is not None for spec in node.op.indexed_outputs) + assert any(r is not None for r in node.op.reduced_outputs) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + tv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 8, size=4) + for res, res_u in zip( + fn(xv, yv, tv.copy(), idxv), fn_u(xv, yv, tv.copy(), idxv) + ): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + + +class TestReductionMultiOutput: + """Reduction plus direct use of the same Elemwise output (duplication).""" + + def test_sum_and_direct(self): + rng = np.random.default_rng(20) + x = pt.matrix("x") + y = pt.matrix("y") + f = pt.exp(x) + y + out = [pt.sum(f, axis=0), f] + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(4, 5)), rng.normal(size=(4, 5)) + r, ru = fn(xv, yv), fn_u(xv, yv) + np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + def test_two_reductions_same_source(self): + """sum and max of the same elemwise output (two CAReduce clients).""" + rng = np.random.default_rng(21) + x = pt.matrix("x") + f = pt.exp(x) + out = [pt.sum(f, axis=0), pt.max(f, axis=0)] + fn, fn_u = fused_and_unfused([x], out) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert sum(r is not None for r in node.op.reduced_outputs) == 2 + assert not any(isinstance(n.op, CAReduce) for n in fn.maker.fgraph.toposort()) + xv = rng.normal(size=(4, 5)) + r, ru = fn(xv), fn_u(xv) + 
np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + def test_three_reductions_and_direct_use(self): + """Three reductions plus a direct consumer of one shared output.""" + rng = np.random.default_rng(22) + x = pt.matrix("x") + y = pt.matrix("y") + f = x * y + out = [pt.sum(f), pt.max(f, axis=0), pt.prod(f, axis=1), f] + fn, fn_u = fused_and_unfused([x, y], out) + [node] = [ + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ] + assert sum(r is not None for r in node.op.reduced_outputs) == 3 + assert not any(isinstance(n.op, CAReduce) for n in fn.maker.fgraph.toposort()) + xv, yv = rng.normal(size=(4, 5)), rng.normal(size=(4, 5)) + for res, res_u in zip(fn(xv, yv), fn_u(xv, yv)): + np.testing.assert_allclose(res, res_u, rtol=1e-10) + + +class TestReductionPythonMode: + """The fused op evaluates correctly outside JIT (OpFromGraph.perform).""" + + def test_perform_matches(self): + rng = np.random.default_rng(30) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + fn = function( + [x, y, idx], pt.sum(x[idx] + y, axis=1), mode=NUMBA_MODE, trust_input=True + ) + node = next( + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, FusedElemwise) + ) + # Re-apply the exact fused op and evaluate it via OpFromGraph.perform + fresh = [inp.type() for inp in node.inputs] + perform_fn = function( + fresh, node.op(*fresh, return_list=True), mode="FAST_COMPILE" + ) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose( + perform_fn(xv, yv, idxv)[0], np.sum(xv[idxv] + yv, axis=1), rtol=1e-10 + ) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7ac8f4c65b..dac8690abd 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -249,7 +249,11 @@ def _raise_on_opt_error(self): "add_mul_fusion", "inplace", ], - 
exclude=["cxx_only", "BlasOpt"], + # Exclude both careduce-fusion paths so reductions stay unfused here: + # cxx_only covers local_careduce_fusion (C backend); the indexed/reduce + # fusion is the Numba-specific equivalent. This class tests the generic + # Composite fusion, independent of the active backend. + exclude=["cxx_only", "BlasOpt", "fuse_indexed_into_elemwise"], ) mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 7042ddf598..a8fc733667 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2827,7 +2827,13 @@ class TestLocalSumProd: """Test sum/prod rewrites.""" def setup_method(self): - self.mode = get_default_mode().including("canonicalize", "specialize") + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize") + .excluding("fuse_indexed_into_elemwise") + ) def test_local_sum_prod_of_scalar_mul(self): # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and @@ -3336,7 +3342,9 @@ def test_local_prod_of_div(self): c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX) d_val = np.asarray(rng.standard_normal(), config.floatX) - default_mode = get_default_mode() + # Exclude the Numba reduction fusion so the outer reduction op stays + # visible in the toposort for the structural assertions below. + default_mode = get_default_mode().excluding("fuse_indexed_into_elemwise") # `FusionOptimizer` is included to make sure that `expected_outer_operator` # remains the same for all rewrite modes. 
mode_with_rewrite = default_mode.including( @@ -3393,8 +3401,12 @@ def test_local_prod_of_div(self): class TestLocalReduce: def setup_method(self): - self.mode = get_default_mode().including( - "canonicalize", "specialize", "uncanonicalize" + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize", "uncanonicalize") + .excluding("fuse_indexed_into_elemwise") ) def test_local_reduce_broadcast_all_0(self): diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index ecb671e3ce..98b7e71219 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -18,6 +18,7 @@ from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot, dot, exp, sqr +from pytensor.tensor.rewriting.fused_elemwise import FusedElemwise from pytensor.tensor.rewriting.subtensor import ( _slice_to_arange, local_add_of_sparse_write, @@ -490,11 +491,9 @@ def test_local_useless_subtensor_6(self, idx, res): else: # Arange-with-offset indices get rewritten to a Subtensor slice; # other advanced indices stay as AdvancedSubtensor1 (or get - # absorbed into IndexedElemwise by FuseIndexedElemwise). - from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise - + # absorbed into FusedElemwise by FuseElemwise). 
assert any( - isinstance(node.op, AdvancedSubtensor1 | Subtensor | IndexedElemwise) + isinstance(node.op, AdvancedSubtensor1 | Subtensor | FusedElemwise) for node in prog ) diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index 9d5011b6db..e842703d45 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -23,8 +23,12 @@ class TestMinMax: def setup_method(self): - self.mode = pytensor.compile.mode.get_default_mode().including( - "canonicalize", "fast_run" + # Exclude the Numba reduction fusion so the Max/Min CAReduce nodes stay + # visible in the toposort for the structural assertions below. + self.mode = ( + pytensor.compile.mode.get_default_mode() + .including("canonicalize", "fast_run") + .excluding("fuse_indexed_into_elemwise") ) def test_optimization_min(self):