Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.scan.op import Scan
from pytensor.tensor.type import TensorType


def idx_to_str(
Expand Down Expand Up @@ -66,7 +66,6 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
destroy_map = op.destroy_map
fgraph = op.fgraph
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
Expand All @@ -88,16 +87,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
if outer_mitsot.type.shape[0] == abs(min(taps))
]
# Untraced sit_sot or destroyable if on destroy_map
destroyable_untraced_sit_sot = [
inner_u_sit_sot
for (outer_u_sit_sot_idx, _), inner_u_sit_sot in zip(
op.outer_untraced_sit_sot_outs(node.inputs, with_idx=True),
op.inner_untraced_sit_sot(fgraph.inputs),
strict=True,
)
if outer_u_sit_sot_idx in destroy_map
]
# Always allow the inner function to destroy untraced_sit_sot inputs.
# After the first iteration, these come from the previous output so
# destroying is always safe. For the first iteration, the codegen
# copies the outer input if the Scan's destroy_map doesn't allow it.
destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs))
destroyable = {
*destroyable_sitsot,
*destroyable_mitsot,
Expand All @@ -116,6 +110,17 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)

# Track which untraced_sit_sot outputs have their inner input destroyed
# by the optimized inner function (transitively, via DestroyHandler).
untraced_start = (
op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot
)
inner_destroyed_untraced_out_idxs = set()
if hasattr(fgraph, "destroyers"):
for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)):
if fgraph.destroyers(inner_inp):
inner_destroyed_untraced_out_idxs.add(untraced_start + j)

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan"
)
Expand Down Expand Up @@ -308,8 +313,8 @@ def add_output_storage_post_proc_stmt(
if outer_in_name not in outer_in_nit_sot_names:
storage_name = outer_in_to_storage_name[outer_in_name]

is_tensor_type = isinstance(outer_in_var.type, TensorType)
if is_tensor_type:
is_tapped = outer_in_name not in outer_in_untraced_sit_sot_names
if is_tapped:
storage_size_name = f"{outer_in_name}_len"
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps = inner_in_names_to_input_taps[outer_in_name]
Expand Down Expand Up @@ -352,10 +357,17 @@ def add_output_storage_post_proc_stmt(
inner_out_to_outer_in_stmts.append(storage_name)

output_idx = outer_output_names.index(storage_name)
if output_idx in node.op.destroy_map or not is_tensor_type:
storage_alloc_stmt = f"{storage_name} = {outer_in_name}"
# Copy the outer input when it will be mutated during the loop
# but the Scan's destroy_map doesn't grant ownership.
# Tapped outputs: the loop writes into the buffer via circular indexing.
# Untraced sit_sot: the inner function may destroy the input inplace.
needs_copy = output_idx not in node.op.destroy_map and (
is_tapped or output_idx in inner_destroyed_untraced_out_idxs
)
if needs_copy:
storage_alloc_stmt = f"{storage_name} = numba_deepcopy({outer_in_name})"
else:
storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})"
storage_alloc_stmt = f"{storage_name} = {outer_in_name}"

storage_alloc_stmt = dedent(
f"""
Expand Down Expand Up @@ -472,7 +484,12 @@ def scan({", ".join(outer_in_names)}):
scan_op_fn = compile_numba_function_src(
scan_op_src,
"scan",
globals() | {"np": np, "scan_inner_func": scan_inner_func},
globals()
| {
"np": np,
"scan_inner_func": scan_inner_func,
"numba_deepcopy": numba_deepcopy,
},
)

if inner_func_cache_key is None:
Expand Down
195 changes: 194 additions & 1 deletion pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

# Recreate default buffers with new size
if _is_default_scan_buffer(nw_input, taps):
extra_size = 1 if required_orphan else val - taps
if required_orphan:
extra_size = (
1 if backend_supports_output_pre_allocation else 0
)
else:
extra_size = val - taps
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice
else:
Expand Down Expand Up @@ -1766,6 +1771,186 @@ def scan_save_mem_no_prealloc(fgraph, node):
)


@node_rewriter([Scan])
def scan_sit_sot_to_untraced(fgraph, node):
    """Convert sit_sot with buffer size=1 to untraced_sit_sot.

    After scan_save_mem has reduced buffer sizes, sit_sot outputs that only
    need one state stored (buffer size=1) can be converted to untraced_sit_sot,
    which avoids the overhead of reading/writing circular buffers each iteration.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused directly; required by the
        ``node_rewriter`` interface).
    node
        An apply node of a ``Scan`` op.

    Returns
    -------
    False when no sit_sot output has a size-1 buffer, otherwise a dict mapping
    each old Scan output to its replacement (plus a ``"remove"`` entry listing
    the old node), as expected by the rewriter machinery.
    """
    op = node.op
    info = op.info

    # Nothing to do if the Scan has no sit_sot outputs at all.
    if info.n_sit_sot == 0:
        return False

    outer_sitsot = op.outer_sitsot(node.inputs)
    # A sit_sot is convertible when its outer buffer statically holds exactly
    # one entry (static shape[0] == 1): only the latest state is ever kept.
    convertible = [
        idx for idx in range(info.n_sit_sot) if outer_sitsot[idx].type.shape[0] == 1
    ]

    if not convertible:
        return False

    convertible_set = set(convertible)

    # Gather current inner inputs/outputs by category
    inner_inputs = list(op.inner_inputs)
    inner_outputs = list(op.inner_outputs)

    inner_sitsot_ins = op.inner_sitsot(inner_inputs)
    inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs)
    inner_untraced_ins = op.inner_untraced_sit_sot(inner_inputs)
    inner_untraced_outs = op.inner_untraced_sit_sot_outs(inner_outputs)

    # Split sit_sot into remaining and converted
    new_sit_sot_in_slices = []
    remaining_inner_sitsot_ins = []
    remaining_inner_sitsot_outs = []
    remaining_outer_sitsot = []
    converted_inner_untraced_ins = []
    converted_inner_untraced_outs = []
    converted_outer_untraced = []

    for idx in range(info.n_sit_sot):
        if idx in convertible_set:
            converted_inner_untraced_ins.append(inner_sitsot_ins[idx])
            converted_inner_untraced_outs.append(inner_sitsot_outs[idx])
            # Untraced states take a plain (unbuffered) initial value, so the
            # size-1 leading buffer dimension is indexed away here.
            converted_outer_untraced.append(outer_sitsot[idx][0])
        else:
            new_sit_sot_in_slices.append(info.sit_sot_in_slices[idx])
            remaining_inner_sitsot_ins.append(inner_sitsot_ins[idx])
            remaining_inner_sitsot_outs.append(inner_sitsot_outs[idx])
            remaining_outer_sitsot.append(outer_sitsot[idx])

    # Rebuild inner inputs:
    # seqs | mit_mot_taps | mit_sot_taps | sit_sot | untraced_sit_sot | non_seqs
    n_taps_before_sitsot = sum(
        len(x) for x in chain(info.mit_mot_in_slices, info.mit_sot_in_slices)
    )
    pre_sitsot_inner = inner_inputs[: info.n_seqs + n_taps_before_sitsot]
    inner_non_seqs = op.inner_non_seqs(inner_inputs)

    # Newly converted untraced inputs go after the pre-existing untraced ones,
    # preserving the relative order of each category.
    new_inner_inputs = (
        pre_sitsot_inner
        + remaining_inner_sitsot_ins
        + list(inner_untraced_ins)
        + converted_inner_untraced_ins
        + inner_non_seqs
    )

    # Rebuild inner outputs:
    # mit_mot_outs | mit_sot | sit_sot | nit_sot | untraced_sit_sot [| while_cond]
    n_mit_mot_outs = sum(len(x) for x in info.mit_mot_out_slices)
    pre_sitsot_inner_outs = inner_outputs[: n_mit_mot_outs + info.n_mit_sot]
    nitsot_outs = op.inner_nitsot_outs(inner_outputs)

    new_inner_outputs = (
        pre_sitsot_inner_outs
        + remaining_inner_sitsot_outs
        + nitsot_outs
        + list(inner_untraced_outs)
        + converted_inner_untraced_outs
    )
    if info.as_while:
        # A while-Scan keeps its termination condition as the last inner output.
        new_inner_outputs.append(inner_outputs[-1])

    # Rebuild outer inputs:
    # n_steps | seqs | mit_mot | mit_sot | sit_sot | untraced_sit_sot | nit_sot | non_seqs
    pre_sitsot_outer = list(
        node.inputs[: 1 + info.n_seqs + info.n_mit_mot + info.n_mit_sot]
    )
    outer_untraced = list(op.outer_untraced_sit_sot(node.inputs))
    outer_nitsot = list(op.outer_nitsot(node.inputs))
    outer_non_seqs = list(op.outer_non_seqs(node.inputs))

    # NOTE(review): outer inputs place untraced_sit_sot *before* nit_sot while
    # inner/outer outputs place untraced *after* nit_sot — this mirrors the
    # layout comments above; assumed to match the Scan op's convention.
    new_outer_inputs = (
        pre_sitsot_outer
        + remaining_outer_sitsot
        + outer_untraced
        + converted_outer_untraced
        + outer_nitsot
        + outer_non_seqs
    )

    # Build new ScanInfo
    new_info = dataclasses.replace(
        info,
        sit_sot_in_slices=tuple(new_sit_sot_in_slices),
        n_untraced_sit_sot=info.n_untraced_sit_sot + len(convertible),
    )

    new_op = Scan(
        new_inner_inputs,
        new_inner_outputs,
        new_info,
        mode=op.mode,
        profile=op.profile,
        truncate_gradient=op.truncate_gradient,
        name=op.name,
        allow_gc=op.allow_gc,
    )
    new_outs = cast(list[TensorVariable], new_op(*new_outer_inputs, return_list=True))

    # Build replacement mapping
    # Old outer outputs: mit_mot | mit_sot | sit_sot | nit_sot | untraced_sit_sot
    # New outer outputs: mit_mot | mit_sot | remaining_sit_sot | nit_sot | old_untraced | converted_untraced
    old_outputs = node.outputs
    replacements: dict = {}

    # mit_mot + mit_sot: same relative positions
    n_pre = info.n_mit_mot + info.n_mit_sot
    for i in range(n_pre):
        replacements[old_outputs[i]] = new_outs[i]

    # sit_sot: remaining keep position, converted become untraced
    old_sitsot_offset = n_pre
    new_remaining_offset = new_info.n_mit_mot + new_info.n_mit_sot
    # Converted outputs sit past all pre-existing untraced outputs in the new op.
    new_converted_offset = (
        new_info.n_mit_mot
        + new_info.n_mit_sot
        + new_info.n_sit_sot
        + new_info.n_nit_sot
        + info.n_untraced_sit_sot
    )
    remaining_count = 0
    converted_count = 0
    for idx in range(info.n_sit_sot):
        old_out = old_outputs[old_sitsot_offset + idx]
        if idx in convertible_set:
            new_untraced = new_outs[new_converted_offset + converted_count]
            # Restore the size-1 leading dim so the replacement has the same
            # type as the old buffered sit_sot output.
            replacements[old_out] = pt.expand_dims(new_untraced, 0)
            converted_count += 1
        else:
            replacements[old_out] = new_outs[new_remaining_offset + remaining_count]
            remaining_count += 1

    # nit_sot
    old_nitsot_offset = n_pre + info.n_sit_sot
    new_nitsot_offset = new_info.n_mit_mot + new_info.n_mit_sot + new_info.n_sit_sot
    for i in range(info.n_nit_sot):
        replacements[old_outputs[old_nitsot_offset + i]] = new_outs[
            new_nitsot_offset + i
        ]

    # Original untraced_sit_sot
    old_untraced_offset = n_pre + info.n_sit_sot + info.n_nit_sot
    new_untraced_offset = (
        new_info.n_mit_mot
        + new_info.n_mit_sot
        + new_info.n_sit_sot
        + new_info.n_nit_sot
    )
    for i in range(info.n_untraced_sit_sot):
        replacements[old_outputs[old_untraced_offset + i]] = new_outs[
            new_untraced_offset + i
        ]

    # "remove" tells the rewriter machinery to drop the old node outright.
    replacements["remove"] = [node]
    return replacements


class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops.

Expand Down Expand Up @@ -2524,6 +2709,14 @@ def apply(self, fgraph, start_from=None):
use_db_name_as_tag=False,
position=1.61,
)
# After scan_save_mem (it could be merged with it, but that rewrite is already a beast as is)
# Position 1.62 runs this immediately after scan_save_mem (registered at 1.61),
# which is what creates the size-1 sit_sot buffers this rewrite targets.
optdb.register(
    "scan_sit_sot_to_untraced",
    dfs_rewriter(scan_sit_sot_to_untraced, ignore_newtrees=True),
    "fast_run",
    "scan",
    position=1.62,
)
optdb.register(
"scan_make_inplace",
ScanInplaceOptimizer(),
Expand Down
Loading
Loading