From c367625bc31c73534c5f986a9d20814837b90d63 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 28 Mar 2026 00:18:25 +0100 Subject: [PATCH 1/3] Scan rewrite sit_sot -> untraced_sit_sot when only one entry is kept Reduces read and write overhead --- pytensor/link/numba/dispatch/scan.py | 7 +- pytensor/scan/rewriting.py | 195 ++++++++++++++++++++++++++- tests/benchmarks/test_scan.py | 10 +- tests/link/numba/test_scan.py | 18 ++- tests/scan/test_rewriting.py | 48 +++++++ 5 files changed, 267 insertions(+), 11 deletions(-) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 763489fbf2..2f9901731e 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -16,7 +16,6 @@ ) from pytensor.link.numba.dispatch.string_codegen import create_tuple_string from pytensor.scan.op import Scan -from pytensor.tensor.type import TensorType def idx_to_str( @@ -308,8 +307,8 @@ def add_output_storage_post_proc_stmt( if outer_in_name not in outer_in_nit_sot_names: storage_name = outer_in_to_storage_name[outer_in_name] - is_tensor_type = isinstance(outer_in_var.type, TensorType) - if is_tensor_type: + is_tapped = outer_in_name not in outer_in_untraced_sit_sot_names + if is_tapped: storage_size_name = f"{outer_in_name}_len" storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]" input_taps = inner_in_names_to_input_taps[outer_in_name] @@ -352,7 +351,7 @@ def add_output_storage_post_proc_stmt( inner_out_to_outer_in_stmts.append(storage_name) output_idx = outer_output_names.index(storage_name) - if output_idx in node.op.destroy_map or not is_tensor_type: + if output_idx in node.op.destroy_map or not is_tapped: storage_alloc_stmt = f"{storage_name} = {outer_in_name}" else: storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 8dcb89c9e3..ebd1636ed3 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1564,7 +1564,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: # Recreate default buffers with new size if _is_default_scan_buffer(nw_input, taps): - extra_size = 1 if required_orphan else val - taps + if required_orphan: + extra_size = ( + 1 if backend_supports_output_pre_allocation else 0 + ) + else: + extra_size = val - taps nw_input = expand_empty(nw_input.owner.inputs[1], extra_size) # Otherwise, just trim with a slice else: @@ -1766,6 +1771,186 @@ def scan_save_mem_no_prealloc(fgraph, node): ) +@node_rewriter([Scan]) +def scan_sit_sot_to_untraced(fgraph, node): + """Convert sit_sot with buffer size=1 to untraced_sit_sot. + + After scan_save_mem has reduced buffer sizes, sit_sot outputs that only + need one state stored (buffer size=1) can be converted to untraced_sit_sot, + which avoids the overhead of reading/writing circular buffers each iteration. + """ + op = node.op + info = op.info + + if info.n_sit_sot == 0: + return False + + outer_sitsot = op.outer_sitsot(node.inputs) + convertible = [ + idx for idx in range(info.n_sit_sot) if outer_sitsot[idx].type.shape[0] == 1 + ] + + if not convertible: + return False + + convertible_set = set(convertible) + + # Gather current inner inputs/outputs by category + inner_inputs = list(op.inner_inputs) + inner_outputs = list(op.inner_outputs) + + inner_sitsot_ins = op.inner_sitsot(inner_inputs) + inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs) + inner_untraced_ins = op.inner_untraced_sit_sot(inner_inputs) + inner_untraced_outs = op.inner_untraced_sit_sot_outs(inner_outputs) + + # Split sit_sot into remaining and converted + new_sit_sot_in_slices = [] + remaining_inner_sitsot_ins = [] + remaining_inner_sitsot_outs = [] + remaining_outer_sitsot = [] + converted_inner_untraced_ins = [] + converted_inner_untraced_outs = [] + converted_outer_untraced = [] + + for idx in range(info.n_sit_sot): + if idx in convertible_set: + converted_inner_untraced_ins.append(inner_sitsot_ins[idx]) + converted_inner_untraced_outs.append(inner_sitsot_outs[idx]) + converted_outer_untraced.append(outer_sitsot[idx][0]) + else: + new_sit_sot_in_slices.append(info.sit_sot_in_slices[idx]) + remaining_inner_sitsot_ins.append(inner_sitsot_ins[idx]) + remaining_inner_sitsot_outs.append(inner_sitsot_outs[idx]) + remaining_outer_sitsot.append(outer_sitsot[idx]) + + # Rebuild inner inputs: + # seqs | mit_mot_taps | mit_sot_taps | sit_sot | untraced_sit_sot | non_seqs + n_taps_before_sitsot = sum( + len(x) for x in chain(info.mit_mot_in_slices, info.mit_sot_in_slices) + ) + pre_sitsot_inner = inner_inputs[: info.n_seqs + n_taps_before_sitsot] + inner_non_seqs = op.inner_non_seqs(inner_inputs) + + new_inner_inputs = ( + pre_sitsot_inner + + remaining_inner_sitsot_ins + + list(inner_untraced_ins) + + converted_inner_untraced_ins + + inner_non_seqs + ) + + # Rebuild inner outputs: + # mit_mot_outs | mit_sot | sit_sot | nit_sot | untraced_sit_sot [| while_cond] + n_mit_mot_outs = sum(len(x) for x in info.mit_mot_out_slices) + pre_sitsot_inner_outs = inner_outputs[: n_mit_mot_outs + info.n_mit_sot] + nitsot_outs = op.inner_nitsot_outs(inner_outputs) + + new_inner_outputs = ( + pre_sitsot_inner_outs + + remaining_inner_sitsot_outs + + nitsot_outs + + list(inner_untraced_outs) + + converted_inner_untraced_outs + ) + if info.as_while: + new_inner_outputs.append(inner_outputs[-1]) + + # Rebuild outer inputs: + # n_steps | seqs | mit_mot | mit_sot | sit_sot | untraced_sit_sot | nit_sot | non_seqs + pre_sitsot_outer = list( + node.inputs[: 1 + info.n_seqs + info.n_mit_mot + info.n_mit_sot] + ) + outer_untraced = list(op.outer_untraced_sit_sot(node.inputs)) + outer_nitsot = list(op.outer_nitsot(node.inputs)) + outer_non_seqs = list(op.outer_non_seqs(node.inputs)) + + new_outer_inputs = ( + pre_sitsot_outer + + remaining_outer_sitsot + + outer_untraced + + converted_outer_untraced + + outer_nitsot + + outer_non_seqs + ) + + # Build new ScanInfo + new_info = dataclasses.replace( + info, + sit_sot_in_slices=tuple(new_sit_sot_in_slices), + n_untraced_sit_sot=info.n_untraced_sit_sot + len(convertible), + ) + + new_op = Scan( + new_inner_inputs, + new_inner_outputs, + new_info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + new_outs = cast(list[TensorVariable], new_op(*new_outer_inputs, return_list=True)) + + # Build replacement mapping + # Old outer outputs: mit_mot | mit_sot | sit_sot | nit_sot | untraced_sit_sot + # New outer outputs: mit_mot | mit_sot | remaining_sit_sot | nit_sot | old_untraced | converted_untraced + old_outputs = node.outputs + replacements: dict = {} + + # mit_mot + mit_sot: same relative positions + n_pre = info.n_mit_mot + info.n_mit_sot + for i in range(n_pre): + replacements[old_outputs[i]] = new_outs[i] + + # sit_sot: remaining keep position, converted become untraced + old_sitsot_offset = n_pre + new_remaining_offset = new_info.n_mit_mot + new_info.n_mit_sot + new_converted_offset = ( + new_info.n_mit_mot + + new_info.n_mit_sot + + new_info.n_sit_sot + + new_info.n_nit_sot + + info.n_untraced_sit_sot + ) + remaining_count = 0 + converted_count = 0 + for idx in range(info.n_sit_sot): + old_out = old_outputs[old_sitsot_offset + idx] + if idx in convertible_set: + new_untraced = new_outs[new_converted_offset + converted_count] + replacements[old_out] = pt.expand_dims(new_untraced, 0) + converted_count += 1 + else: + replacements[old_out] = new_outs[new_remaining_offset + remaining_count] + remaining_count += 1 + + # nit_sot + old_nitsot_offset = n_pre + info.n_sit_sot + new_nitsot_offset = new_info.n_mit_mot + new_info.n_mit_sot + new_info.n_sit_sot + for i in range(info.n_nit_sot): + replacements[old_outputs[old_nitsot_offset + i]] = new_outs[ + new_nitsot_offset + i + ] + + # Original untraced_sit_sot + old_untraced_offset = n_pre + info.n_sit_sot + info.n_nit_sot + new_untraced_offset = ( + new_info.n_mit_mot + + new_info.n_mit_sot + + new_info.n_sit_sot + + new_info.n_nit_sot + ) + for i in range(info.n_untraced_sit_sot): + replacements[old_outputs[old_untraced_offset + i]] = new_outs[ + new_untraced_offset + i + ] + + replacements["remove"] = [node] + return replacements + + class ScanMerge(GraphRewriter): r"""Graph optimizer that merges different scan ops. @@ -2524,6 +2709,14 @@ def apply(self, fgraph, start_from=None): use_db_name_as_tag=False, position=1.61, ) +# After scan_save_mem (it could be merged with it, but that rewrite is already a beast as is) +optdb.register( + "scan_sit_sot_to_untraced", + dfs_rewriter(scan_sit_sot_to_untraced, ignore_newtrees=True), + "fast_run", + "scan", + position=1.62, +) optdb.register( "scan_make_inplace", ScanInplaceOptimizer(), diff --git a/tests/benchmarks/test_scan.py b/tests/benchmarks/test_scan.py index c4944c6341..a040e3021c 100644 --- a/tests/benchmarks/test_scan.py +++ b/tests/benchmarks/test_scan.py @@ -194,6 +194,7 @@ def _test_sit_sot_buffer_benchmark( if buffer_size == "unit": xs_kept = xs[-1] expected_buffer_size = 1 + mode_preallocs_output + expected_untraced_sit_sot = not mode_preallocs_output elif buffer_size == "aligned": xs_kept = xs[-2:] expected_buffer_size = 2 @@ -217,8 +218,13 @@ def _test_sit_sot_buffer_benchmark( [scan_node] = [ node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan) ] - buffer = scan_node.inputs[1] - assert buffer.type.shape[0] == expected_buffer_size + if buffer_size == "unit" and expected_untraced_sit_sot: + # sit_sot was converted to untraced_sit_sot (no buffer dimension) + assert scan_node.op.info.n_sit_sot == 0 + assert scan_node.op.info.n_untraced_sit_sot == 1 + else: + buffer = scan_node.inputs[1] + assert buffer.type.shape[0] == expected_buffer_size benchmark(fn, x_test) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 905bcb010d..ae02f03dfa 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -373,7 +373,8 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)], mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)], ] - [sit_sot_inp] = scan_op.inner_sitsot(inner_inps) + sit_sot_inps = scan_op.inner_sitsot(inner_inps) + untraced_sit_sot_inps = scan_op.inner_untraced_sit_sot(inner_inps) destroyed_inputs = [] for inner_out in scan_op.fgraph.outputs: @@ -385,8 +386,12 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ) if n_steps_constant: + # The scalar sit_sot (x) is converted to untraced_sit_sot + # by the scan_sit_sot_to_untraced rewrite when only the last value is used + assert len(sit_sot_inps) == 0 + assert len(untraced_sit_sot_inps) == 1 assert len(destroyed_inputs) == 3 - assert set(destroyed_inputs) == {*oldest_mit_sot_inps, sit_sot_inp} + assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} else: # This is not a feature, but a current limitation # https://github.com/pymc-devs/pytensor/issues/1283 @@ -435,8 +440,13 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None): for node in numba_fn.maker.fgraph.toposort() if isinstance(node.op, Scan) ] - buffer = scan_node.inputs[1] - assert buffer.type.shape[0] == expected_buffer_size + if expected_buffer_size == 1: + # sit_sot_to_untraced converts unit-buffer sit_sot to untraced_sit_sot + assert scan_node.op.info.n_sit_sot == 0 + assert scan_node.op.info.n_untraced_sit_sot == 1 + else: + buffer = scan_node.inputs[1] + assert buffer.type.shape[0] == expected_buffer_size if benchmark is not None: numba_fn.trust_input = True diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index c578426e87..a89b77641b 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -1828,6 +1828,54 @@ def test_broadcasted_init(self, keep_beginning, val_ndim): assert buffer_size_fn(val_test) == 52 if keep_beginning else 50 +def test_scan_sit_sot_to_untraced(): + """Test sit_sot to untraced_sit_sot conversion. + + 4 outputs: xs (sit_sot, all values used → stays), ys (sit_sot, only last + → converted), ws (nit_sot, unaffected), rs (sit_sot, required orphan + → converted). Result: 1 sit_sot, 1 nit_sot, 2 untraced_sit_sot. + """ + mode = ( + get_default_mode() + .excluding("scan_save_mem") + .including("scan_save_mem_no_prealloc", "scan_sit_sot_to_untraced") + ) + + x0 = vector("x0") + y0 = vector("y0") + r0 = vector("r0") + + def step(x_tm1, y_tm1, r_tm1): + r = 1.0 - x_tm1 + x = x_tm1 + 0.5 * r + 0.3 * r_tm1 + y = y_tm1 + 1 + w = x_tm1 * 2 + return x, y, w, r + + [xs, ys, ws, _rs] = scan( + step, outputs_info=[x0, y0, None, r0], n_steps=10, return_updates=False + ) + # xs: all values used (stays sit_sot) + # ys[-1]: only last value (converted) + # ws[-1]: nit_sot (unaffected) + # rs: never used externally, required orphan (converted) + f = function([x0, y0, r0], [xs, ys[-1], ws[-1]], mode=mode) + + [scan_node] = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)] + assert scan_node.op.info.n_sit_sot == 1 + assert scan_node.op.info.n_nit_sot == 1 + assert scan_node.op.info.n_untraced_sit_sot == 2 + + x0_val = np.zeros(3, dtype=config.floatX) + y0_val = np.zeros(3, dtype=config.floatX) + r0_val = np.zeros(3, dtype=config.floatX) + res_xs, res_y, res_w = f(x0_val, y0_val, r0_val) + np.testing.assert_allclose(res_y, y0_val + 10) + assert res_xs.shape == (10, 3) + assert np.all(np.isfinite(res_xs)) + assert np.isfinite(res_w).all() + + def test_inner_replace_dot(): """ This tests that rewrites are applied to the inner-graph. From 4747a65af8a250b3cc8598bb3ab61c6a75c6a6e0 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 28 Mar 2026 02:38:14 +0100 Subject: [PATCH 2/3] Generalize useless slice canonicalization to other Subtensor Ops Also make local_useless_inc_subtensor a bit more powerful. It's not really useless because it may still need to broadcast or reverse arrays --- pytensor/tensor/rewriting/subtensor.py | 158 ++++++++++-------- tests/link/numba/test_scan.py | 9 +- tests/tensor/rewriting/test_elemwise.py | 8 + tests/tensor/rewriting/test_subtensor.py | 50 ++++-- tests/tensor/rewriting/test_subtensor_lift.py | 6 +- tests/tensor/test_basic.py | 2 +- 6 files changed, 142 insertions(+), 91 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 8af51fc992..7b49644075 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -79,6 +79,7 @@ advanced_subtensor1, as_index_constant, basic_subtensor, + flatten_index_variables, get_canonical_form_slice, get_constant_idx, get_idx_list, @@ -271,29 +272,32 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): @register_canonicalize @register_specialize @register_stabilize -@node_rewriter([Subtensor]) +@node_rewriter([Subtensor, IncSubtensor, AdvancedSubtensor, AdvancedIncSubtensor]) def local_useless_slice(fgraph, node): - """ - Remove useless slice(None) of the form: - 1. X[0, :] -> X[0] - 2. X[:] -> X + """Remove useless slices and canonicalize redundant slice bounds to ``None``. - Also, canonicalize slices of the form: - X[0:7:1] -> X[None:None:None] - where X is a vector of length 7 - - And: - X[-1:-8:-1] -> X[::-1] - where x is a vector of length 7 + Applies to all Subtensor Ops with slices (basic and advanced, get and set). + - ``X[0, :]`` → ``X[0]`` (trailing full slices dropped) + - ``X[:]`` → ``X`` + - ``X[0:7:1]`` → ``X[:]`` when ``X.shape[0] <= 7`` + - ``X[-1:-8:-1]`` → ``X[::-1]`` when ``X.shape[0] <= 7`` """ - idxs = get_idx_list(node.inputs, node.op.idx_list) - x = node.inputs[0] + op = node.op + idx_list = op.idx_list + if not idx_list: + if isinstance(op, Subtensor | AdvancedSubtensor): + return [node.inputs[0]] + else: + # We let local_useless_inc_subtensor handle these + return None - if not idxs: - return [node.inputs[0]] + if is_inc_subtensor := isinstance(op, IncSubtensor | AdvancedIncSubtensor): + x, y, *idx_vars = node.inputs + else: + x, *idx_vars = node.inputs - new_idxs = list(idxs) + new_idxs = list(indices_from_subtensor(idx_vars, idx_list)) change_flag = False last_useful_idx = -1 for dim, s in enumerate(new_idxs): @@ -322,32 +326,53 @@ def local_useless_slice(fgraph, node): start = s.start stop = s.stop - if start is not None and get_scalar_constant_value( - start, only_process_constants=True, raise_not_constant=False - ) == (0 if positive_step else -1): - change_flag = True - start = None - - if ( - stop is not None - and x.type.shape[dim] is not None - and get_scalar_constant_value( - stop, only_process_constants=True, raise_not_constant=False - ) - == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1) - ): - change_flag = True - stop = None + dim_length = x.type.shape[dim] if dim < x.type.ndim else None + if start is not None and isinstance(start, Constant): + start_val = start.data + if positive_step: + if ( + start_val == 0 + # Negative start that wraps to or before index 0 + or (dim_length is not None and -start_val >= dim_length) + ): + change_flag = True + start = None + else: + if ( + start_val == -1 + # Positive start at or beyond the last index + or (dim_length is not None and start_val >= dim_length - 1) + ): + change_flag = True + start = None + + if dim_length is not None and stop is not None and isinstance(stop, Constant): + stop_val = stop.data + if positive_step: + # Positive stop at or beyond the length + if stop_val >= dim_length: + change_flag = True + stop = None + else: + # Negative stop that wraps to or before index 0 + if -stop_val > dim_length: + change_flag = True + stop = None if start is not None or stop is not None or step is not None: last_useful_idx = dim new_idxs[dim] = slice(start, stop, step) - if change_flag or ((last_useful_idx + 1) < len(idxs)): - new_idxs = tuple(new_idxs[: last_useful_idx + 1]) - out = x[new_idxs] if new_idxs else x - # Copy over previous output stacktrace + if change_flag or (last_useful_idx + 1) < len(idx_list): + new_idxs = new_idxs[: last_useful_idx + 1] + new_idx_list, new_flat_vars = flatten_index_variables(new_idxs) + props = op._props_dict() | {"idx_list": new_idx_list} + if is_inc_subtensor: + # We let local_useless_inc_subtensor handle empty new_idx_list + out = type(op)(**props)(x, y, *new_flat_vars) + else: + out = type(op)(**props)(x, *new_flat_vars) if new_idx_list else x copy_stack_trace(node.outputs, out) return [out] @@ -515,26 +540,17 @@ def local_subtensor_inc_subtensor(fgraph, node): return -@register_useless @register_canonicalize @register_specialize -@node_rewriter([IncSubtensor]) +@node_rewriter([IncSubtensor, AdvancedIncSubtensor]) def local_useless_inc_subtensor(fgraph, node): r"""Remove redundant `IncSubtensor`\s. - More specifically, ``set_subtensor(x[indices], y)`` is replaced by - ``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is - equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced - by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices`` - are full slices, and the shapes are equal. + Replace set_subtensor (or inc_subtensor on zero) that overwrite their whole buffers + by the written value (perhaps broadcasted and/or reversed). """ - if not isinstance(node.op, IncSubtensor): - return - - if not hasattr(fgraph, "shape_feature"): - return - x, y, *index_inputs = node.inputs + x, y, *index_vars = node.inputs if node.op.set_instead_of_inc is False: # This is an increment operation, so the array being incremented must @@ -546,12 +562,9 @@ def local_useless_inc_subtensor(fgraph, node): except NotScalarConstantError: return - idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list) + indices = indices_from_subtensor(index_vars, node.op.idx_list) - # Check that all indices are full slices with only reversals and no step - # sizes - # TODO: It seems like there should be a basic `IncSubtensor` - # canonicalization that removes these redundant slices. + # Check that all indices are full slices or full reversals if all( isinstance(e, slice) and e.start is None @@ -563,23 +576,32 @@ def local_useless_inc_subtensor(fgraph, node): ) == -1 ) - for e in idx_cst + for e in indices ): - # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we - # must check that they are the same - if not fgraph.shape_feature.same_shape(x, y): - return + # IncSubtensor casts y to x's dtype and broadcasts y onto x's shape + out_dtype = node.outputs[0].type.dtype - # There are no reversals, so we don't need a replacement. - if all(e.step is None for e in node.op.idx_list): - # They are exactly the same shapes, so we can remove this `IncSubtensor` - return [y] + # Check shapes before casting, as cast creates a new node not in the fgraph + static_same = x.type.shape == y.type.shape and all( + s is not None for s in x.type.shape + ) + if not static_same: + if hasattr(fgraph, "shape_feature") and fgraph.shape_feature.same_shape( + x, y + ): + static_same = True - new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs) - new_out = new_node.outputs[0] - copy_stack_trace(node.outputs, new_out) + if y.type.dtype != out_dtype: + y = cast(y, out_dtype) - return [new_out] + if not static_same: + y = alloc(y, *x.shape) + copy_stack_trace(node.outputs[0], y) + + if not all(e.step is None for e in node.op.idx_list): + y = Subtensor(node.op.idx_list)(y, *index_vars) + + return [y] @register_canonicalize diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index ae02f03dfa..a4b15dc072 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -387,11 +387,14 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): if n_steps_constant: # The scalar sit_sot (x) is converted to untraced_sit_sot - # by the scan_sit_sot_to_untraced rewrite when only the last value is used + # by the scan_sit_sot_to_untraced rewrite when only the last value is used. + # With constant n_steps, scan_save_mem + local_useless_slice strip the + # AllocEmpty buffers, so inputs become raw function inputs that can't + # be inplaced. assert len(sit_sot_inps) == 0 assert len(untraced_sit_sot_inps) == 1 - assert len(destroyed_inputs) == 3 - assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} + assert len(destroyed_inputs) == 2 + assert set(destroyed_inputs) == set(oldest_mit_sot_inps) else: # This is not a feature, but a current limitation # https://github.com/pymc-devs/pytensor/issues/1283 diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7a729cac90..6dbaa75b85 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1095,6 +1095,14 @@ def test_fusion_35_inputs(self): assert not any(len(node.inputs) > 31 for node in composite_nodes) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") + @pytest.mark.xfail( + reason="Elemwise.perform doesn't support >32 operands. " + "local_useless_inc_subtensor now triggers get_underlying_scalar_constant_value " + "on large fused Add nodes, exposing a pre-existing bug where prepare_node " + "refuses to create a ufunc for >32 inputs but perform falls through to the " + "ufunc code path anyway (missing return after super().perform()).", + raises=AttributeError, + ) def test_big_fusion(self): # Make sure that C compilation is used mode = Mode("cvm", self.rewrites) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 8c50362840..7c4427fb02 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -129,28 +129,42 @@ def test_local_useless_inc_subtensor(s): x = matrix("x") y = matrix("y") - o = set_subtensor(x[:, s], y) - mode = get_default_mode().including("local_useless_inc_subtensor") - # Test without shape info (i.e. don't apply the opt) + # Without shape info: rewrite fires but inserts alloc to handle broadcast + o = set_subtensor(x[:, s], y) f = function([x, y], o, mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, IncSubtensor) + assert not any(isinstance(n.op, IncSubtensor) for n in topo) + out = f([[2, 3]], [[3, 4]]) + assert np.array_equal(out, np.asarray([[3, 4]])[::, s]) - # Test with shape info + # With shape info: rewrite fires without alloc o_shape = set_subtensor(x[:, s], specify_shape(y, x.shape)) f_shape = function([x, y], o_shape, mode=mode) - topo = f_shape.maker.fgraph.toposort() - assert not any(isinstance(n.op, IncSubtensor) for n in topo) - + assert not any(isinstance(n.op, IncSubtensor | Alloc) for n in topo) out = f_shape([[2, 3]], [[3, 4]]) assert np.array_equal(out, np.asarray([[3, 4]])[::, s]) +def test_local_useless_setsubtensor_alloc_empty(): + """SetSubtensor(AllocEmpty(1, n), y, :1) -> y when y fills the buffer. + + This is the pattern produced by scan_save_mem for sit_sot with buffer=1. + local_useless_slice canonicalizes [:1] to [:] on the size-1 dim, + then local_useless_inc_subtensor removes the SetSubtensor entirely. + """ + from pytensor.graph.rewriting.utils import rewrite_graph + from pytensor.tensor.basic import AllocEmpty + + y = matrix("y", shape=(1, 5)) + x = AllocEmpty("float64")(pt.constant(1), pt.constant(5)) + buf = set_subtensor(x[:1], y) + result = rewrite_graph(buf, include=("specialize",)) + utt.assert_equal_computations([result], [y], original=[buf]) + + def test_local_useless_inc_subtensor_increment_zeros(): r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays.""" y = matrix("y") @@ -1335,7 +1349,7 @@ def test_incsubtensor_allocs1(self): def test_incsubtensor_x_zeros(self): x = pt.constant(np.asarray(np.zeros((4, 4)), dtype=config.floatX)) y = matrix() - z = inc_subtensor(x[:4], y) + z = inc_subtensor(x[:3], y) f = function([y], z) inc_nodes = [ n for n in f.maker.fgraph.toposort() if isinstance(n.op, IncSubtensor) @@ -1344,23 +1358,27 @@ def test_incsubtensor_x_zeros(self): assert len(inc_nodes) == 1 node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc assert node_is_set_instead_of_inc - test_X = np.random.random((4, 4)).astype(config.floatX) - utt.assert_allclose(f(test_X), test_X) + test_y = np.random.random((3, 4)).astype(config.floatX) + expected = np.zeros((4, 4), dtype=config.floatX) + expected[:3] += test_y + utt.assert_allclose(f(test_y), expected) # also check the flag doesn't get set if first input is not zeros: not_all_zeros = np.zeros((4, 4)) not_all_zeros[1, 0] = 0.001 x = pt.constant(np.asarray(not_all_zeros, dtype=config.floatX)) y = matrix() - z = inc_subtensor(x[:4], y) + z = inc_subtensor(x[:3], y) f = function([y], z) inc_nodes = [ n for n in f.maker.fgraph.toposort() if isinstance(n.op, IncSubtensor) ] assert len(inc_nodes) == 1 assert inc_nodes[0].op.set_instead_of_inc is False - test_X = np.random.random((4, 4)).astype(config.floatX) - utt.assert_allclose(f(test_X), test_X + not_all_zeros) + test_y = np.random.random((3, 4)).astype(config.floatX) + expected = not_all_zeros.copy() + expected[:3] += test_y + utt.assert_allclose(f(test_y), expected) def test_advancedincsubtensor1_allocs0(self): x = matrix() diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 6885c47278..40c2e49645 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -309,10 +309,10 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn): (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]), (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]), ( - lambda x: softmax(x, axis=0)[0, :5:2], - lambda x: softmax(x[:, :5:2], axis=0)[0], + lambda x: softmax(x, axis=0)[0, :2:2], + lambda x: softmax(x[:, :2:2], axis=0)[0], ), - (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]), + (lambda x: softmax(x, axis=1)[0, :2:2], lambda x: softmax(x[0], axis=0)[:2:2]), ], ) def test_local_subtensor_of_softmax(original_fn, expected_fn): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index ce29321d0d..a5a3f941c0 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -741,7 +741,7 @@ def setup_method(self): "subtensor_fn, expected_grad_n_alloc", [ # IncSubtensor1 - (lambda x: x[:60], 1), + (lambda x: x[:59], 1), # AdvancedIncSubtensor1 (lambda x: x[np.arange(60)], 1), # AdvancedIncSubtensor From 88c2f53cedac77df41b8fec1542811d074ffd7bb Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 28 Mar 2026 18:07:37 +0100 Subject: [PATCH 3/3] Numba scan: Always try to inplace on untraced_sit_sot --- pytensor/link/numba/dispatch/scan.py | 48 +++++++++++++++++++--------- tests/link/numba/test_scan.py | 8 +---- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 2f9901731e..312114344e 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -14,6 +14,7 @@ numba_funcify_and_cache_key, register_funcify_and_cache_key, ) +from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.string_codegen import create_tuple_string from pytensor.scan.op import Scan @@ -65,7 +66,6 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): .excluding(*NUMBA._optimizer.exclude) .optimizer ) - destroy_map = op.destroy_map fgraph = op.fgraph # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps, # We must always discard the oldest tap, so it's safe to destroy it in the inner function. @@ -87,16 +87,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ) if outer_mitsot.type.shape[0] == abs(min(taps)) ] - # Untraced sit_sot or destroyable if on destroy_map - destroyable_untraced_sit_sot = [ - inner_u_sit_sot - for (outer_u_sit_sot_idx, _), inner_u_sit_sot in zip( - op.outer_untraced_sit_sot_outs(node.inputs, with_idx=True), - op.inner_untraced_sit_sot(fgraph.inputs), - strict=True, - ) - if outer_u_sit_sot_idx in destroy_map - ] + # Always allow the inner function to destroy untraced_sit_sot inputs. + # After the first iteration, these come from the previous output so + # destroying is always safe. For the first iteration, the codegen + # copies the outer input if the Scan's destroy_map doesn't allow it. + destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs)) destroyable = { *destroyable_sitsot, *destroyable_mitsot, @@ -115,6 +110,17 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + # Track which untraced_sit_sot outputs have their inner input destroyed + # by the optimized inner function (transitively, via DestroyHandler). + untraced_start = ( + op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot + ) + inner_destroyed_untraced_out_idxs = set() + if hasattr(fgraph, "destroyers"): + for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)): + if fgraph.destroyers(inner_inp): + inner_destroyed_untraced_out_idxs.add(untraced_start + j) + scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( op.fgraph, fgraph_name="numba_scan" ) @@ -351,10 +357,17 @@ def add_output_storage_post_proc_stmt( inner_out_to_outer_in_stmts.append(storage_name) output_idx = outer_output_names.index(storage_name) - if output_idx in node.op.destroy_map or not is_tapped: - storage_alloc_stmt = f"{storage_name} = {outer_in_name}" + # Copy the outer input when it will be mutated during the loop + # but the Scan's destroy_map doesn't grant ownership. + # Tapped outputs: the loop writes into the buffer via circular indexing. + # Untraced sit_sot: the inner function may destroy the input inplace. + needs_copy = output_idx not in node.op.destroy_map and ( + is_tapped or output_idx in inner_destroyed_untraced_out_idxs + ) + if needs_copy: + storage_alloc_stmt = f"{storage_name} = numba_deepcopy({outer_in_name})" else: - storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" + storage_alloc_stmt = f"{storage_name} = {outer_in_name}" storage_alloc_stmt = dedent( f""" @@ -471,7 +484,12 @@ def scan({", ".join(outer_in_names)}): scan_op_fn = compile_numba_function_src( scan_op_src, "scan", - globals() | {"np": np, "scan_inner_func": scan_inner_func}, + globals() + | { + "np": np, + "scan_inner_func": scan_inner_func, + "numba_deepcopy": numba_deepcopy, + }, ) if inner_func_cache_key is None: diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index a4b15dc072..e6c95e3bee 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -386,15 +386,9 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ) if n_steps_constant: - # The scalar sit_sot (x) is converted to untraced_sit_sot - # by the scan_sit_sot_to_untraced rewrite when only the last value is used. - # With constant n_steps, scan_save_mem + local_useless_slice strip the - # AllocEmpty buffers, so inputs become raw function inputs that can't - # be inplaced. assert len(sit_sot_inps) == 0 assert len(untraced_sit_sot_inps) == 1 - assert len(destroyed_inputs) == 2 - assert set(destroyed_inputs) == set(oldest_mit_sot_inps) + assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} else: # This is not a feature, but a current limitation # https://github.com/pymc-devs/pytensor/issues/1283