Skip to content

Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor#2015

Draft
ricardoV94 wants to merge 5 commits intopymc-devs:v3from
ricardoV94:gather_scatter_fusion
Draft

Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor#2015
ricardoV94 wants to merge 5 commits intopymc-devs:v3from
ricardoV94:gather_scatter_fusion

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented Mar 29, 2026

Summary

Introduce IndexedElemwise, an OpFromGraph that wraps AdvancedSubtensor + Elemwise + AdvancedIncSubtensor subgraphs so the Numba backend can generate a single loop with indirect indexing, avoiding materializing AvancedSubtensor input arrays, and writing directly on the output buffer, doing the job of AdvancedIncSubtensor in the same loop, without having to loop again through the intermediate elemwise output

Commit 1 fuses indexed reads (AdvancedSubtensor1 on inputs).
Commit 2 fuses indexed updates (AdvancedIncSubtensor1 on outputs).
Commit 3 extends to AdvancedSubtensor inputs, on arbitrary (1d) indexed (consecutive) axes

Motivation

In hierarchical models with mu = beta[idx] * x + ..., the logp+gradient graph combines indexed reads and indexed updates in the same Elemwise (the forward reads county-level parameters via an index, and the gradient accumulates back into county-level buffers via the same index).

A simler

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

numba_mode = get_mode("NUMBA")
numba_mode_before = numba_mode.excluding("fuse_indexed_elemwise")

x = pt.vector("x")
idx = pt.vector("idx", dtype=int)
value = pt.vector("value")

y = pt.zeros(100)
out = ((x[idx] - value) ** 2).sum()
grad_wrt_x = pt.grad(out, x)
fn_before = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode_before, trust_input=True)
fn_before.dprint(print_op_info=True, print_destroy_map=True)
# Sum{axes=None} [id A] 5
#  └─ Composite{...}.0 [id B] d={0: [0]} 1
#     ├─ AdvancedSubtensor1 [id C] 0
#     │  ├─ x [id D]
#     │  └─ idx [id E]
#     └─ value [id F]
# AdvancedIncSubtensor1{inplace,inc} [id G] d={0: [0]} 4
#  ├─ Alloc [id H] 3
#  │  ├─ [0.] [id I]
#  │  └─ Shape_i{0} [id J] 2
#  │     └─ x [id D]
#  ├─ Composite{...}.1 [id B] d={0: [0]} 1
#  │  └─ ···
#  └─ idx [id E]

# Inner graphs:

# Composite{...} [id B] d={0: [0]}
#  ← sqr [id K] 'o0'
#     └─ sub [id L] 't5'
#        ├─ i0 [id M]
#        └─ i1 [id N]
#  ← mul [id O] 'o1'
#     ├─ 2.0 [id P]
#     └─ sub [id L] 't5'
#        └─ ···

fn = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode, trust_input=True)
fn.dprint(print_op_info=True, print_destroy_map=True)

# Sum{axes=None} [id A] 3
#  └─ IndexedElemwise{Composite{...}}.0 [id B] d={1: [3]} 2
#     ├─ x [id C] (indexed read (idx_0))
#     ├─ value [id D]
#     ├─ idx [id E] (idx_0)
#     └─ Alloc [id F] 1 (buf_0)
#        ├─ [0.] [id G]
#        └─ Shape_i{0} [id H] 0
#           └─ x [id C]
# IndexedElemwise{Composite{...}}.1 [id B] d={1: [3]} 2 (indexed inc (buf_0, idx_0))
#  └─ ···

# Inner graphs:

# IndexedElemwise{Composite{...}} [id B] d={1: [3]}
#  ← Composite{...}.0 [id I]
#     ├─ AdvancedSubtensor1 [id J]
#     │  ├─ *0-<Vector(float64, shape=(?,))> [id K]
#     │  └─ *2-<Vector(int64, shape=(?,))> [id L]
#     └─ *1-<Vector(float64, shape=(?,))> [id M]
#  ← AdvancedIncSubtensor1{inplace,inc} [id N] d={0: [0]}
#     ├─ *3-<Vector(float64, shape=(?,))> [id O]
#     ├─ Composite{...}.1 [id I]
#     │  └─ ···
#     └─ *2-<Vector(int64, shape=(?,))> [id L]

# Composite{...} [id I]
#  ← sqr [id P] 'o0'
#     └─ sub [id Q] 't0'
#        ├─ i0 [id R]
#        └─ i1 [id S]
#  ← mul [id T] 'o1'
#     ├─ 2.0 [id U]
#     └─ sub [id Q] 't0'
#        └─ ···

x_test = np.arange(15, dtype="float64")
idx_test = np.random.randint(15, size=(10_000,))
value_test = np.random.normal(size=idx_test.shape)

logp_before, dlogp_before = fn_before(x_test, value_test, idx_test)
logp, dlogp = fn(x_test, value_test, idx_test)
np.testing.assert_allclose(logp_before, logp)
np.testing.assert_allclose(dlogp_before, dlogp)

%timeit fn_before(x_test, value_test, idx_test)  # 29.4 μs ± 2.57 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fn(x_test, value_test, idx_test)  # 13.8 μs ± 136 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Next step would be to also fuse the sum directly on the elemwise, so we end up with a single loop over the data. This is important as the sum can easily break our fusion, as we don't fuse if the elemwise output is needed elsewhere (like in a sum).

@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 6d875d8 to 0ad6e2e Compare March 29, 2026 18:14
@ricardoV94 ricardoV94 changed the title Numba: fuse AdvancedSubtensor+Elemwise Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor Mar 29, 2026
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch from 0ad6e2e to 9e32400 Compare March 30, 2026 21:12
Set view_map = True for when n_steps may be zero
Fuse single-client AdvancedSubtensor1 nodes into Elemwise loops,
replacing indirect array reads with a single iteration loop that
uses index arrays for input access.

Before (2 nodes):
  temp = x[idx]                    # AdvancedSubtensor1, shape (919,)
  result = temp + y                # Elemwise

After (1 fused loop, x is read directly via idx):
  for k in range(919):
      result[k] = x[idx[k]] + y[k]

- Introduce IndexedElemwise Op (in rewriting/indexed_elemwise.py)
- Add FuseIndexedElemwise rewrite with SequenceDB
- Merge _vectorized intrinsics into one with NO_SIZE/NO_INDEXED sentinels
- Fix Numba missing getitem(0d_array, Ellipsis)
- Index arrays participate in iter_shape with correct static bc
- zext for unsigned index types
- Add op_debug_information for dprint(print_op_info=True)
- Add correctness tests and benchmarks
Extend the IndexedElemwise fusion to also absorb
AdvancedIncSubtensor1 (indexed set/inc) on the output side.

Before (3 nodes):
  temp = Elemwise(x[idx], y)               # shape (919,)
  result = IncSubtensor(target, temp, idx)  # target shape (85,)

After (1 fused loop, target is an input):
  for k in range(919):
      target[idx[k]] += scalar_fn(x[idx[k]], y[k])

- FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers
- Reject fusion when val broadcasts against target's non-indexed axes
- store_core_outputs supports inc mode via o[...] += val
- Inner fgraph always uses inplace IncSubtensor
- op_debug_information shows buf_N / idx_N linkage
- Add indexed-update tests, broadcast guard test, and benchmarks
Support AdvancedSubtensor on any axis (not just axis 0) and multi-index
patterns like x[idx_row, idx_col] where multiple 1D index arrays address
consecutive source axes.

Arbitrary axis:
  x[:, idx] + y → fused loop with indirect indexing on axis 1

Multi-index:
  x[idx0, idx1] + y → out[i, j] = x[idx0[i], idx1[i], j] + y[i, j]

- Add undo_take_dimshuffle_for_fusion pre-fusion rewrite
- Generalize indexed_inputs encoding: ((positions, axis, idx_bc), ...)
- input_read_spec uses tuple of (idx_k, axis) pairs per input
- source_input_types for array struct access, input_types (effective)
  for core_ndim / _compute_vectorized_types
- n_index_loop_dims = max(idx.ndim for group) for future ND support
- Index arrays participate in iter_shape with correct per-index static bc
- Reject fusion when val broadcasts against target's non-indexed axes
- Add correctness, broadcast, and shape validation tests
Extend FusionOptimizer to merge independent subgraphs that share
inputs but have no producer-consumer edge (siblings like f(x) and g(x)).
The eager expansion only walks producer-consumer edges, missing these.

Also extract InplaceGraphOptimizer.try_inplace_on_node helper and
_insert_sorted_subgraph to deduplicate insertion-point logic.
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch from 9e32400 to e3c59c6 Compare March 31, 2026 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant