From 96982c2c2c3f5f9e076bb87647647776e930beb3 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 15 Mar 2026 18:17:50 +0100 Subject: [PATCH 1/7] Fix Scan higher order derivatives --- pytensor/scan/op.py | 30 ++++++++++++++++++++++++++++++ tests/scan/test_basic.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 553c538296..6037332f65 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2687,7 +2687,37 @@ def compute_all_gradients(known_grads): else: dC_dXtm1s.append(safe_new(x)) + # Skip accumulation for "overlapping" mit-mot taps. + # + # A mit-mot tap "overlaps" when the same tap index appears in both the input + # and output slices of a single mit-mot state. This means the output *overwrites* + # the input at that buffer position — analogous to set_subtensor(x, y, i). + # + # The gradient of an overwrite must zero out the direct pass-through from the + # old value; the only gradient path is through the output expression that replaced + # it (already captured by compute_all_gradients via known_grads). + # + # The gradient for an overlapping tap is NOT zero — the chain-rule contribution + # through the output expression remains. We only skip the extra dC_dXtm1 + # accumulation term, which would incorrectly treat the old value as if it also + # passes through unchanged to future steps, double-counting the gradient. + # + # Overlapping taps arise naturally when differentiating sit-sot or mit-sot scans: + # their L_op converts the recurrence into a mit-mot where one tap serves as both + # read and write (e.g. in_taps=(0,1), out_taps=(1,) — tap 1 overlaps). + overlapping_taps = set() + dx_offset = 0 + for idx in range(info.n_mit_mot): + in_taps = info.mit_mot_in_slices[idx] + out_taps = info.mit_mot_out_slices[idx] + for k, tap in enumerate(in_taps): + if tap in out_taps: + overlapping_taps.add(dx_offset + k) + dx_offset += len(in_taps) + for dx, dC_dXtm1 in enumerate(dC_dXtm1s): + if dx in overlapping_taps: + continue # gradient truncates here if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType): # The accumulated gradient is undefined pass diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index f121dc9e58..318abb363a 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -1332,6 +1332,46 @@ def inner_fct(mitsot_m2, mitsot_m1, sitsot): sum_of_grads = sum(g.sum() for g in gradients) grad(sum_of_grads, inputs[0]) + def test_high_order_grad_sitsot(self): + """Test higher-order derivatives through a sit-sot scan. + + The L_op of a sit-sot scan creates a mit-mot backward scan where + one buffer position is both read and written. + This is analogous to set_subtensor(x, y, i): the gradient w.r.t. x + must zero out position i, routing gradient only through y. + + A bug in the accumulation logic added a spurious gradient at + the overwritten position, as if the old value also passed + through unchanged. The 2nd derivative graph was wrong but + evaluated correctly (the spurious contribution only affected + the mit-mot output, which is not on the gradient path for + scalar derivatives). The error became visible at the 3rd + derivative, where symbolic differentiation through the wrong + graph produced incorrect values. + """ + # Avoid costly rewrite/compilation of Scans + mode = Mode(linker="py", optimizer=None) + x = pt.scalar("x") + x_val = np.float64(0.95) + ys = scan( + fn=lambda xtm1: xtm1**2, outputs_info=[x], n_steps=4, return_updates=False + ) + y = ys[-1] + + # Sanity check + np.testing.assert_allclose(y.eval({x: x_val}, mode=mode), x_val**16) + + # Evaluate higher order derivatives + deriv = y + for order in range(1, 5): + deriv = grad(deriv, x) + deriv_value = deriv.eval({x: x_val}, mode=mode) + # xs[-1] = x^16, so the n-th derivative is 16!/(16-n)! * x^(16-n) + expected_deriv_value = np.prod((16, 15, 14, 13)[:order]) * x_val ** ( + 16 - order + ) + np.testing.assert_allclose(deriv_value, expected_deriv_value) + def test_grad_dtype_change(self): x = fscalar("x") y = fscalar("y") From 1e86a3a58bfb063f445f36fc6bdb559b9273a1c7 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 17 Mar 2026 00:46:13 +0100 Subject: [PATCH 2/7] Comment Scan.L_op --- pytensor/scan/op.py | 128 ++++++++++++++++++------------------- pytensor/scan/rewriting.py | 6 ++ pytensor/scan/utils.py | 21 ++++-- 3 files changed, 84 insertions(+), 71 deletions(-) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 6037332f65..d678b1850e 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2437,12 +2437,19 @@ def connection_pattern(self, node): return connection_pattern def L_op(self, inputs, outs, dC_douts): - # `grad_step` equals the number of steps the original scan node has - # done (if the original scan is a while loop than this number is the - # length of the output sequence) - # We do not know what kind of outputs the original scan has, so we - # try first to see if it has a nit_sot output, then a sit_sot and - # then a mit_sot + # Computes the gradient of this Scan by constructing a new backward Scan + # that runs in reverse. The method: + # 1. Differentiates the inner function symbolically (compute_all_gradients) + # 2. Adds accumulation terms for state inputs at preserved buffer positions + # 3. Builds reversed sequences from the forward outputs + # 4. Converts all recurrent states (sit-sot, mit-sot, mit-mot) into mit-mot + # form in the backward scan (initialized with output gradients, accumulate + # total gradients after evaluation) + # 5. Constructs and runs the backward Scan, then re-orders its outputs + + # Determine the number of gradient steps from the output shapes (not from + # inputs[0] directly, because while-loop scans may execute fewer steps than + # the allocated buffer size). info = self.info if info.n_nit_sot > 0: grad_steps = self.outer_nitsot_outs(outs)[0].shape[0] @@ -2457,8 +2464,7 @@ def L_op(self, inputs, outs, dC_douts): if info.as_while: n_steps = outs[0].shape[0] - # Restrict the number of grad steps according to - # self.truncate_gradient + # Restrict the number of grad steps according to self.truncate_gradient if self.truncate_gradient != -1: grad_steps = minimum(grad_steps, self.truncate_gradient) @@ -2540,13 +2546,11 @@ def compute_all_gradients(known_grads): ] gmp = {} - # Required in case there is a pair of variables X and Y, with X - # used to compute Y, for both of which there is an external - # gradient signal. Without this, the total gradient signal on X - # will be the external gradient signalknown_grads[X]. With this, - # it will be the sum of the external gradient signal and the - # gradient obtained by propagating Y's external gradient signal - # to X. + # The .copy() creates fresh variable nodes so that grad() treats them + # as new outputs "equal to" the originals, rather than matching them by + # identity to variables already in the graph. This forces grad() to + # propagate the known_grads values backward through the computation + # instead of short-circuiting at a wrt target. known_grads = {k.copy(): v for (k, v) in known_grads.items()} grads = grad( @@ -2588,17 +2592,15 @@ def compute_all_gradients(known_grads): Xt_placeholder = safe_new(Xt) Xts.append(Xt_placeholder) - # Different processing based on whether Xt is a nitsot output - # or not. NOTE : This cannot be done by using - # "if Xt not in self.inner_nitsot_outs(self_outputs)" because - # the exact same variable can be used as multiple outputs. + # Different processing based on whether Xt is a nitsot output or not. + # NOTE : This cannot be done by using "if Xt not in self.inner_nitsot_outs(self_outputs)" + # because the exact same variable can be used as multiple outputs. if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end: - # What we do here is loop through dC_douts and collect all + # loop through dC_douts and collect all # those that are connected to the specific one and do an # upcast on all of their dtypes to get the dtype for this # specific output. Deciding if the gradient with this - # specific previous step is defined or not is done somewhere - # else. + # specific previous step is defined or not is done somewhere else. dtypes = [] for pos, inp in enumerate(states): if inp in graph_inputs([Xt]): @@ -2637,9 +2639,9 @@ def compute_all_gradients(known_grads): continue # Just some trouble to avoid a +0 - if diff_outputs[i] in known_grads: + try: known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: + except KeyError: known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] dc_dxts_idx += 1 @@ -2655,6 +2657,9 @@ def compute_all_gradients(known_grads): ) else: disconnected_dC_dinps_t[dx] = False + # Replace inner output subexpressions with placeholders wired to the + # saved forward values, so the backward scan reuses them instead of + # recomputing them. See forced_replace docstring for details. for Xt, Xt_placeholder in zip( diff_outputs[info.n_mit_mot_outs :], Xts, strict=True ): @@ -2663,21 +2668,20 @@ def compute_all_gradients(known_grads): # construct dX_dtm1 dC_dXtm1s = [] + n_internal_recurrent_states = sum( + len(t) + for t in chain( + info.mit_mot_in_slices, + info.mit_sot_in_slices, + info.sit_sot_in_slices, + ) + ) for pos, x in enumerate(dC_dinps_t[info.n_seqs :]): - # Get the index of the first inner input corresponding to the - # pos-ieth inner input state + # Get the index of the first inner input corresponding to the pos-ieth inner input state idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos] - # Check if the pos-th input is associated with one of the - # recurrent states - x_is_state = pos < sum( - len(t) - for t in chain( - info.mit_mot_in_slices, - info.mit_sot_in_slices, - info.sit_sot_in_slices, - ) - ) + # Check if the pos-th input is associated with one of the recurrent states + x_is_state = pos < n_internal_recurrent_states if x_is_state and len(idxs) > 0: opos = idxs[0] @@ -2688,23 +2692,12 @@ def compute_all_gradients(known_grads): dC_dXtm1s.append(safe_new(x)) # Skip accumulation for "overlapping" mit-mot taps. - # # A mit-mot tap "overlaps" when the same tap index appears in both the input # and output slices of a single mit-mot state. This means the output *overwrites* # the input at that buffer position — analogous to set_subtensor(x, y, i). - # # The gradient of an overwrite must zero out the direct pass-through from the # old value; the only gradient path is through the output expression that replaced # it (already captured by compute_all_gradients via known_grads). - # - # The gradient for an overlapping tap is NOT zero — the chain-rule contribution - # through the output expression remains. We only skip the extra dC_dXtm1 - # accumulation term, which would incorrectly treat the old value as if it also - # passes through unchanged to future steps, double-counting the gradient. - # - # Overlapping taps arise naturally when differentiating sit-sot or mit-sot scans: - # their L_op converts the recurrence into a mit-mot where one tap serves as both - # read and write (e.g. in_taps=(0,1), out_taps=(1,) — tap 1 overlaps). overlapping_taps = set() dx_offset = 0 for idx in range(info.n_mit_mot): @@ -2791,8 +2784,7 @@ def compute_all_gradients(known_grads): outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] - # Restrict the length of the outer sequences to the number of grad - # steps + # Restrict the length of the outer sequences to the number of grad steps outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs] inner_inp_seqs = self.inner_seqs(self_inputs) @@ -2801,7 +2793,14 @@ def compute_all_gradients(known_grads): inner_inp_seqs += self.inner_sitsot(self_inputs) inner_inp_seqs += self.inner_nitsot_outs(dC_dXts) inner_inp_seqs += Xts - # mitmot + # Build backward scan's mit-mot states. + # Every forward recurrent state (sit-sot, mit-sot, mit-mot) becomes + # a mit-mot in the backward scan. The conversion negates the taps: + # forward output tap t → backward input tap -t (gradient signal) + # forward input tap t → backward output tap -t (gradient to propagate) + # Each backward output tap also needs a backward input tap at the same + # position to carry the accumulated gradient (the recurrence). If one + # already exists from the first rule, they share the buffer slot. outer_inp_mitmot = [] inner_inp_mitmot = [] inner_out_mitmot = [] @@ -2840,8 +2839,8 @@ def compute_all_gradients(known_grads): inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) if isinstance(dC_dinps_t[ins_pos].type, NullType): - # We cannot use Null in the inner graph, so we - # use a zero tensor of the appropriate shape instead. + # We cannot use Null in the inner graph, + # so we use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX) ) @@ -2949,9 +2948,8 @@ def compute_all_gradients(known_grads): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: if isinstance(dC_dinps_t[ins_pos].type, NullType): - # Cannot use dC_dinps_t[ins_pos].dtype, so we use - # floatX instead, as it is a dummy value that will not - # be used anyway. + # Cannot use dC_dinps_t[ins_pos].dtype, so we use floatX instead, + # as it is a dummy value that will not be used anyway. outer_inp_mitmot.append( pt.zeros(outs[idx + offset].shape, dtype=config.floatX) ) @@ -2963,8 +2961,8 @@ def compute_all_gradients(known_grads): ) if isinstance(dC_dinps_t[ins_pos].type, NullType): - # We cannot use Null in the inner graph, so we - # use a zero tensor of the appropriate shape instead. + # We cannot use Null in the inner graph, + # so we use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX) ) @@ -3004,8 +3002,7 @@ def compute_all_gradients(known_grads): through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) - # Replace the inner output with a zero tensor of - # the right shape + # Replace the inner output with a zero tensor of the right shape inner_out_sitsot[_p] = pt.zeros( diff_inputs[ins_pos + _p].shape, dtype=config.floatX ) @@ -3023,8 +3020,7 @@ def compute_all_gradients(known_grads): through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) - # Replace the inner output with a zero tensor of - # the right shape + # Replace the inner output with a zero tensor of the right shape inner_out_nitsot[_p] = pt.zeros( diff_inputs[_p].shape, dtype=config.floatX ) @@ -3119,9 +3115,8 @@ def compute_all_gradients(known_grads): ) ): if t == "connected": - # If the forward scan is in as_while mode, we need to pad - # the gradients, so that they match the size of the input - # sequences. + # If the forward scan is in as_while mode, we need to pad the gradients, + # so that they match the size of the input sequences. if info.as_while: n_zeros = inputs[0] - n_steps shp = (n_zeros,) @@ -3147,9 +3142,8 @@ def compute_all_gradients(known_grads): end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)): if t == "connected": - # If the forward scan is in as_while mode, we need to pad - # the gradients, so that they match the size of the input - # sequences. + # If the forward scan is in as_while mode, we need to pad the gradients, + # so that they match the size of the input sequences. if info.as_while: n_zeros = inputs[0] - grad_steps shp = (n_zeros,) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index abd6216110..b61f1b115e 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1279,6 +1279,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: The scan perform implementation takes the output sizes into consideration, saving the newest results over the oldest ones whenever the buffer is filled. + This rewrite must only run at compilation time, after grad() has already + built the backward scan. The backward scan needs all intermediate forward + states as sequence inputs (to evaluate f'(x[t])). If this rewrite truncates + buffers before grad() is called, the gradient will be silently wrong. + TODO: Use a subclass that raises explicitly on `L_op` + Paramaters ---------- backend_supports_output_pre_allocation: bool diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index 8ac3c58a99..b1426fa463 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -1099,9 +1099,22 @@ def __eq__(self, other): def forced_replace(out, x, y): """ - Check all internal values of the graph that compute the variable ``out`` - for occurrences of values identical with ``x``. If such occurrences are - encountered then they are replaced with variable ``y``. + Replace subexpressions in ``out`` that are structurally equal to ``x`` + with ``y``, using ``equal_computations`` for matching. + + Unlike ``graph_replace`` (which matches by variable identity), + this detects when a subexpression *recomputes* ``x`` without + being the same variable object. This is used by ``Scan.L_op`` + to substitute inner-function outputs with placeholders wired to + the saved forward values, avoiding redundant recomputation in + the backward scan. For example, if ``exp(x).L_op`` returns + ``output_gradient * exp(x)`` by recreating ``exp(x)`` instead + of referencing the existing output variable, a plain identity + check would miss it, but ``equal_computations`` catches it. + + This is not comprehensive: structurally different but semantically + equivalent expressions (e.g. ``exp(x + 0)`` vs ``exp(x)``) will + not match. Parameters ---------- @@ -1117,7 +1130,7 @@ def forced_replace(out, x, y): Notes ----- - When it find a match, it don't continue on the corresponding inputs. + When it finds a match, it does not continue into that node's inputs. """ if out is None: return None From 48f65423e51bdc3c2948ba6785ce4a901d3687e5 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 27 Mar 2026 20:14:52 +0100 Subject: [PATCH 3/7] Make ScalarOp.impl an abstract method All concrete ScalarOp subclasses already define impl (either directly or via ScalarInnerGraphOp). Making it abstract enforces this at instantiation time rather than at call time. Co-Authored-By: Claude Opus 4.6 (1M context) --- pytensor/scalar/basic.py | 4 +++- tests/scalar/test_basic.py | 4 ++-- tests/tensor/test_elemwise.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index b59cc9992f..1a89413aa7 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -10,6 +10,7 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar! """ +import abc import builtins import math from collections.abc import Callable @@ -1250,8 +1251,9 @@ def perform(self, node, inputs, output_storage): ): storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype) + @abc.abstractmethod def impl(self, *inputs): - raise MethodNotDefined("impl", type(self), self.__class__.__name__) + raise NotImplementedError() def grad(self, inputs, output_gradients): raise MethodNotDefined("grad", type(self), self.__class__.__name__) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 3167a20149..dacd8ab487 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -193,8 +193,8 @@ class MultiOutOp(ScalarOp): def make_node(self, x): return Apply(self, [x], [x.type(), x.type()]) - def perform(self, node, inputs, outputs): - outputs[1][0] = outputs[0][0] = inputs[0] + def impl(self, x): + return x, x def c_code(self, *args): return "dummy" diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 8bd35bca28..f67760e893 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1118,7 +1118,7 @@ def make_node(self, *inputs): outputs = [float_op(), int_op()] return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): + def impl(self, *inputs): raise NotImplementedError() def L_op(self, inputs, outputs, output_gradients): From 94859a8c69746e448642f3b707abf535a963d36c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 27 Mar 2026 20:16:25 +0100 Subject: [PATCH 4/7] Cleanup ScalarInnerGraphOp py_perform_fn to use impl directly Change py_perform_fn to convert inner ops via op.impl instead of wrapping op.perform with storage allocation. This removes per-element storage allocation overhead in Composite.impl and Composite.perform. Co-Authored-By: Claude Opus 4.6 (1M context) --- pytensor/scalar/basic.py | 45 ++++++++++++++-------------------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 1a89413aa7..0c25934f18 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -37,7 +37,6 @@ from pytensor.utils import ( apply_across_args, difference, - to_return_values, ) @@ -3842,7 +3841,7 @@ class Real(UnaryScalarOp): """ - # numpy.real(float32) return a view on the inputs. + # numpy.real(float32) return a view on the inputs, which ain't good for elemwise. # nfunc_spec = ('real', 1, 1) def impl(self, x): @@ -4056,39 +4055,27 @@ def inner_outputs(self): @property def py_perform_fn(self): - if hasattr(self, "_py_perform_fn"): + """Compiled Python function that chains inner ops' ``impl`` methods. + + Returns a callable that takes scalar inputs and returns a tuple of outputs. + """ + try: return self._py_perform_fn + except AttributeError: + pass from pytensor.link.utils import fgraph_to_python - def python_convert(op, node=None, **kwargs): - assert node is not None - - n_outs = len(node.outputs) - - if n_outs > 1: + def impl_convert(op, node=None, **kwargs): + return op.impl - def _perform(*inputs, outputs=[[None]] * n_outs): - op.perform(node, inputs, outputs) - return tuple(o[0] for o in outputs) - - else: - - def _perform(*inputs, outputs=[[None]]): - op.perform(node, inputs, outputs) - return outputs[0][0] - - return _perform - - self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) + self._py_perform_fn = fgraph_to_python(self.fgraph, impl_convert) return self._py_perform_fn def impl(self, *inputs): - output_storage = [[None] for i in range(self.nout)] - self.perform(None, inputs, output_storage) - ret = to_return_values([storage[0] for storage in output_storage]) - if self.nout > 1: - ret = tuple(ret) + ret = self.py_perform_fn(*inputs) + if self.nout == 1: + return ret[0] return ret def c_code_cache_version(self): @@ -4313,9 +4300,7 @@ def make_node(self, *inputs): return node def perform(self, node, inputs, output_storage): - outputs = self.py_perform_fn(*inputs) - # zip strict not specified because we are in a hot loop - for storage, out_val in zip(output_storage, outputs): + for storage, out_val in zip(output_storage, self.py_perform_fn(*inputs)): storage[0] = out_val def grad(self, inputs, output_grads): From 8f5303ce170a40b030c06d4129acdb0d395e5810 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 27 Mar 2026 20:17:09 +0100 Subject: [PATCH 5/7] Cleanup Elemwise perform method Rewrite Elemwise.perform to use a single _create_node_ufunc method that builds a closure with inplace logic, dtype handling, and ufunc selection: - nfunc_spec path: uses numpy/scipy ufuncs directly with out= for inplace and sig= for discrete->float dtype coercion - frompyfunc path (<=32 operands): C iteration loop via np.frompyfunc, .astype() for object->correct dtype conversion. No inplace (destroy_map is a permission, not an obligation; frompyfunc already allocates). - Blockwise vectorize fallback (>32 operands): _vectorize_node_perform with inplace_mapping support - Scalar (0-d) outputs without nfunc_spec: calls impl directly with np.asarray wrapper Removes stale self.ufunc/self.nfunc attributes and __getstate__/__setstate__. Renames fake_node to dummy_node. Adds out= and inplace_mapping parameters to _vectorize_node_perform in blockwise. Co-Authored-By: Claude Opus 4.6 (1M context) --- pytensor/tensor/blockwise.py | 32 +++- pytensor/tensor/elemwise.py | 273 ++++++++++++++++++---------------- tests/tensor/test_elemwise.py | 5 +- 3 files changed, 180 insertions(+), 130 deletions(-) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 9c8fa86ac7..ab281bd352 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -49,10 +49,17 @@ def _vectorize_node_perform( batch_bcast_patterns: Sequence[tuple[bool, ...]], batch_ndim: int, impl: str | None, + inplace_mapping: tuple[int | None, ...] | None = None, ) -> Callable: """Creates a vectorized `perform` function for a given core node. Similar behavior of np.vectorize, but specialized for PyTensor Blockwise Op. + + Parameters + ---------- + inplace_mapping + Optional tuple of length ``nout``. Entry ``i`` is the input index that + output ``i`` should be written into, or ``None`` to allocate a fresh array. """ storage_map = {var: [None] for var in core_node.inputs + core_node.outputs} @@ -75,6 +82,7 @@ def _vectorize_node_perform( def vectorized_perform( *args, + out=None, batch_bcast_patterns=batch_bcast_patterns, batch_ndim=batch_ndim, single_in=single_in, @@ -82,7 +90,11 @@ def vectorized_perform( core_input_storage=core_input_storage, core_output_storage=core_output_storage, core_storage=core_storage, + inplace_mapping=inplace_mapping, ): + if inplace_mapping is not None: + out = tuple(args[j] if j is not None else None for j in inplace_mapping) + if single_in: batch_shape = args[0].shape[:batch_ndim] else: @@ -106,10 +118,22 @@ def vectorized_perform( for core_input, arg in zip(core_input_storage, args): core_input[0] = np.asarray(arg[index0]) core_thunk() - outputs = tuple( - empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype) - for core_output in core_output_storage - ) + if out is None: + outputs = tuple( + empty( + batch_shape + core_output[0].shape, dtype=core_output[0].dtype + ) + for core_output in core_output_storage + ) + else: + outputs = tuple( + o + if o is not None + else empty( + batch_shape + core_output[0].shape, dtype=core_output[0].dtype + ) + for o, core_output in zip(out, core_output_storage) + ) for output, core_output in zip(outputs, core_output_storage): output[index0] = core_output[0] diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 96446237a4..19a1a5280f 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from copy import copy from textwrap import dedent from typing import Literal @@ -357,6 +356,8 @@ def __init__( assert not isinstance(scalar_op, type(self)) if inplace_pattern is None: inplace_pattern = frozendict({}) + elif not isinstance(inplace_pattern, frozendict): + inplace_pattern = frozendict(inplace_pattern) self.name = name self.scalar_op = scalar_op self.inplace_pattern = inplace_pattern @@ -365,22 +366,8 @@ def __init__( if nfunc_spec is None: nfunc_spec = getattr(scalar_op, "nfunc_spec", None) self.nfunc_spec = nfunc_spec - self.__setstate__(self.__dict__) super().__init__(openmp=openmp) - def __getstate__(self): - d = copy(self.__dict__) - d.pop("ufunc") - d.pop("nfunc") - d.pop("__epydoc_asRoutine", None) - return d - - def __setstate__(self, d): - super().__setstate__(d) - self.ufunc = None - self.nfunc = None - self.inplace_pattern = frozendict(self.inplace_pattern) - def get_output_info(self, *inputs): """Return the outputs dtype and broadcastable pattern and the dimshuffled inputs. @@ -599,53 +586,149 @@ def transform(r): return ret - def prepare_node(self, node, storage_map, compute_map, impl): - # Postpone the ufunc building to the last minutes due to: - # - NumPy ufunc support only up to 32 operands (inputs and outputs) - # But our c code support more. - # - nfunc is reused for scipy and scipy is optional - if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py": - impl = "c" + def _create_node_ufunc(self, node: Apply): + """Define (or retrieve) the node ufunc used in `perform`. - if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = import_func_from_string(self.nfunc_spec[0]) + For scalar (0-d) outputs, calls ``scalar_op.impl`` directly. + For tensor outputs with ``nfunc_spec``, uses the numpy/scipy ufunc. + Otherwise, ``np.frompyfunc`` (≤32 operands) or Blockwise vectorize (>32). - if ( - (len(node.inputs) + len(node.outputs)) <= 32 - and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) - and self.ufunc is None - and impl == "py" - ): + All returned callables accept ``(*inputs)`` and return a tuple of outputs. + The ``inplace_pattern`` is baked into the closure so that inplace outputs + are written directly into the corresponding input arrays. + + The ufunc is stored in the tag of the node. + """ + inplace_pattern = self.inplace_pattern + nout = len(node.outputs) + out_dtypes = tuple(out.type.numpy_dtype for out in node.outputs) + # Pre-compute output→input index mapping for inplace + out_to_in = ( + tuple(inplace_pattern.get(i) for i in range(nout)) + if inplace_pattern + else () + ) + + if (nfunc_spec := self.nfunc_spec) is not None and len( + node.inputs + ) == nfunc_spec[1]: + ufunc = import_func_from_string(nfunc_spec[0]) + if ufunc is None: + raise ValueError(f"Could not import gufunc {nfunc_spec[0]} for {self}") + # When inputs are discrete and output is float, pass a signature + # to prevent numpy from computing in float16 for int8 inputs + ufunc_kwargs = {} + if ( + isinstance(ufunc, np.ufunc) + and any(inp.dtype in discrete_dtypes for inp in node.inputs) + and any(out.dtype in float_dtypes for out in node.outputs) + ): + in_sig = "".join(np.dtype(inp.dtype).char for inp in node.inputs) + out_sig = "".join(np.dtype(out.dtype).char for out in node.outputs) + ufunc_kwargs["sig"] = f"{in_sig}->{out_sig}" + + if out_to_in and isinstance(ufunc, np.ufunc): + # Only numpy ufuncs support out=; other nfunc_spec functions (e.g. np.where) don't + if nout == 1: + + def ufunc_fn( + *inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs, _j=out_to_in[0] + ): + _ufunc(*inputs, out=inputs[_j], **_kwargs) + return (inputs[_j],) + else: + + def ufunc_fn( + *inputs, + _ufunc=ufunc, + _kwargs=ufunc_kwargs, + _out_to_in=out_to_in, + ): + out = tuple( + inputs[j] if j is not None else None for j in _out_to_in + ) + return _ufunc(*inputs, out=out, **_kwargs) + elif nout == 1: + + def ufunc_fn(*inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs): + return (_ufunc(*inputs, **_kwargs),) + else: + + def ufunc_fn(*inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs): + return _ufunc(*inputs, **_kwargs) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + # No nfunc_spec path + if node.outputs[0].type.ndim == 0: + # Scalar outputs: call impl directly, wrap with np.asarray + impl = self.scalar_op.impl + if nout == 1: + + def ufunc_fn(*inputs, _impl=impl, _dt=out_dtypes[0]): + return (np.asarray(_impl(*inputs), dtype=_dt),) + else: + + def ufunc_fn(*inputs, _impl=impl, _dts=out_dtypes): + return tuple( + np.asarray(r, dtype=dt) for r, dt in zip(_impl(*inputs), _dts) + ) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + # ndim > 0 without nfunc_spec: frompyfunc (≤32 operands) or Blockwise vectorize (>32) + # frompyfunc returns object arrays — .astype() converts to the correct dtype. + # No inplace: destroy_map is a permission, not an obligation. frompyfunc + # already allocates an object array, so copying into the input would just waste time. + if len(node.inputs) + len(node.outputs) <= 32: ufunc = np.frompyfunc( self.scalar_op.impl, len(node.inputs), self.scalar_op.nout ) - if self.scalar_op.nin > 0: - # We can reuse it for many nodes - self.ufunc = ufunc + + if nout == 1: + + def ufunc_fn(*inputs, _ufunc=ufunc, _dt=out_dtypes[0]): + return (_ufunc(*inputs).astype(_dt),) else: - node.tag.ufunc = ufunc - - # Numpy ufuncs will sometimes perform operations in - # float16, in particular when the input is int8. - # This is not something that we want, and we do not - # do it in the C code, so we specify that the computation - # should be carried out in the returned dtype. - # This is done via the "sig" kwarg of the ufunc, its value - # should be something like "ff->f", where the characters - # represent the dtype of the inputs and outputs. - - # NumPy 1.10.1 raise an error when giving the signature - # when the input is complex. So add it only when inputs is int. - out_dtype = node.outputs[0].dtype - if ( - out_dtype in float_dtypes - and isinstance(self.nfunc, np.ufunc) - and node.inputs[0].dtype in discrete_dtypes - ): - char = np.dtype(out_dtype).char - sig = char * node.nin + "->" + char * node.nout - node.tag.sig = sig - node.tag.fake_node = Apply( + + def ufunc_fn(*inputs, _ufunc=ufunc, _dts=out_dtypes): + return tuple(r.astype(dt) for r, dt in zip(_ufunc(*inputs), _dts)) + else: + # frompyfunc limited to 32 operands, fall back to Blockwise vectorize + from pytensor.tensor.blockwise import _vectorize_node_perform + + core_node = Apply( + self.scalar_op, + [ + get_scalar_type(dtype=inp.type.dtype).make_variable() + for inp in node.inputs + ], + [ + get_scalar_type(dtype=out.type.dtype).make_variable() + for out in node.outputs + ], + ) + batch_ndim = node.outputs[0].type.ndim + batch_bcast_patterns = tuple(inp.type.broadcastable for inp in node.inputs) + ufunc_fn = _vectorize_node_perform( + core_node, + batch_bcast_patterns, + batch_ndim, + impl="py", + inplace_mapping=out_to_in or None, + ) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + def prepare_node(self, node, storage_map, compute_map, impl=None): + if impl != "c": + node.tag.ufunc = self._create_node_ufunc(node) + + # Create a dummy scalar node for the scalar_op to prepare itself + node.tag.dummy_node = dummy_node = Apply( self.scalar_op, [ get_scalar_type(dtype=input.type.dtype).make_variable() @@ -656,77 +739,18 @@ def prepare_node(self, node, storage_map, compute_map, impl): for output in node.outputs ], ) - - self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) + self.scalar_op.prepare_node(dummy_node, None, None, impl) def perform(self, node, inputs, output_storage): - if (len(node.inputs) + len(node.outputs)) > 32: - # Some versions of NumPy will segfault, other will raise a - # ValueError, if the number of operands in an ufunc is more than 32. - # In that case, the C version should be used, or Elemwise fusion - # should be disabled. - # FIXME: This no longer calls the C implementation! - super().perform(node, inputs, output_storage) - self._check_runtime_broadcast(node, inputs) - - ufunc_args = inputs - ufunc_kwargs = {} - # We supported in the past calling manually op.perform. - # To keep that support we need to sometimes call self.prepare_node - if self.nfunc is None and self.ufunc is None: - self.prepare_node(node, None, None, "py") - if self.nfunc and len(inputs) == self.nfunc_spec[1]: - ufunc = self.nfunc - nout = self.nfunc_spec[2] - if hasattr(node.tag, "sig"): - ufunc_kwargs["sig"] = node.tag.sig - # Unfortunately, the else case does not allow us to - # directly feed the destination arguments to the nfunc - # since it sometimes requires resizing. Doing this - # optimization is probably not worth the effort, since we - # should normally run the C version of the Op. - else: - # the second calling form is used because in certain versions of - # numpy the first (faster) version leads to segfaults - if self.ufunc: - ufunc = self.ufunc - elif not hasattr(node.tag, "ufunc"): - # It happen that make_thunk isn't called, like in - # get_underlying_scalar_constant_value - self.prepare_node(node, None, None, "py") - # prepare_node will add ufunc to self or the tag - # depending if we can reuse it or not. So we need to - # test both again. - if self.ufunc: - ufunc = self.ufunc - else: - ufunc = node.tag.ufunc - else: - ufunc = node.tag.ufunc - - nout = ufunc.nout + try: + ufunc = node.tag.ufunc + except AttributeError: + ufunc = node.tag.ufunc = self._create_node_ufunc(node) with np.errstate(all="ignore"): - variables = ufunc(*ufunc_args, **ufunc_kwargs) - - if nout == 1: - variables = [variables] - - # zip strict not specified because we are in a hot loop - for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs) - ): - storage[0] = variable = np.asarray(variable, dtype=nout.dtype) - - if i in self.inplace_pattern: - odat = inputs[self.inplace_pattern[i]] - odat[...] = variable - storage[0] = odat - - # numpy.real return a view! - if not variable.flags.owndata: - storage[0] = variable.copy() + for s, result in zip(output_storage, ufunc(*inputs)): + s[0] = result @staticmethod def _check_runtime_broadcast(node, inputs): @@ -754,8 +778,7 @@ def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...] def _c_all(self, node, nodename, inames, onames, sub): # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` # To not request all of them to call prepare_node(), do it here. - # There is no harm if it get called multiple times. - if not hasattr(node.tag, "fake_node"): + if not hasattr(node.tag, "dummy_node"): self.prepare_node(node, None, None, "c") _inames = inames _onames = onames @@ -903,7 +926,7 @@ def _c_all(self, node, nodename, inames, onames, sub): else: fail = sub["fail"] task_code = self.scalar_op.c_code( - node.tag.fake_node, + node.tag.dummy_node, nodename + "_scalar_", [f"{s}_i" for s in _inames], [f"{s}_i" for s in onames], diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index f67760e893..95a9122237 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -417,7 +417,10 @@ def test_fill(self): xv = rval((5, 5)) yv = rval((1, 1)) f(xv, yv) - assert (xv == yv).all() + # destroy_map is a permission, not an obligation. + # PerformLinker with frompyfunc may not write inplace. + if not isinstance(linker(), PerformLinker): + assert (xv == yv).all() def test_fill_var(self): x = matrix() From a792f9f6772fc9cd128c786c8cba1a5b78edfb04 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 27 Mar 2026 20:17:17 +0100 Subject: [PATCH 6/7] Add regression test for Composite with ops without C code Test that Composite.impl works with scalar ops that only define impl (no c_code). This exercises the py_perform_fn path with impl_convert. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/scalar/test_basic.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index dacd8ab487..3675721ebb 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -208,6 +208,35 @@ def c_code(self, *args): assert fn(1.0) == [1.0, 1.0] + def test_composite_without_c_code(self): + """Composite of scalar ops without C code should work for py-only execution.""" + from pytensor.scalar.basic import UnaryScalarOp, float64, upcast_out + + class _NoCodeExp(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.exp(x) + + def output_types(self, types): + return upcast_out(*types) + + class _NoCodeLog(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.log(x) + + def output_types(self, types): + return upcast_out(*types) + + xs = float64("xs") + comp = Composite( + [xs], + [_NoCodeExp(name="no_code_exp")(_NoCodeLog(name="no_code_log")(xs))], + ) + assert comp.impl(2.0) == pytest.approx(2.0) + class TestLogical: def test_gt(self): From ffda296ee7e431dd87f82821c7d75868dff5d5db Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 27 Mar 2026 20:17:28 +0100 Subject: [PATCH 7/7] Make FusionOptimizer backend-specific (C and Numba) FusionOptimizer now takes a backend parameter ("c" or "numba") that determines which ops are fuseable: - C fusion: "cxx_only" tag, checks scalar ops have C implementations (cached, since supports_c_code is expensive) - Numba fusion: "numba" tag, fuses unconditionally Python mode does not benefit from fusion: frompyfunc's C iteration loop is faster than fused Composite.impl per-element overhead. Also adds py-mode perform benchmarks (nfunc_spec vs frompyfunc paths) and dummy scalar ops for testing. Co-Authored-By: Claude Opus 4.6 (1M context) --- pytensor/tensor/rewriting/elemwise.py | 82 ++++++++++++------- tests/tensor/rewriting/test_elemwise.py | 104 +++++++++++++++++++++--- 2 files changed, 144 insertions(+), 42 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index dc30beedf3..4bc843fea6 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -516,15 +516,26 @@ def flatten_nested_add_mul(fgraph, node): return [output] -def elemwise_max_operands_fct(node) -> int: - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs) - if not config.cxx: - return 32 - return 1024 - - class FusionOptimizer(GraphRewriter): - """Graph optimizer that fuses consecutive Elemwise operations.""" + """Graph optimizer that fuses consecutive Elemwise operations. + + Parameters + ---------- + backend : str + The compilation backend: ``"c"`` or ``"numba"``. + ``"c"`` checks that every scalar op has a C implementation before fusing. + ``"numba"`` fuses unconditionally. + Python mode does not benefit from fusion (``frompyfunc`` iteration + in C is faster than fused ``Composite.impl`` per-element overhead). + """ + + def __init__(self, backend: str): + super().__init__() + if backend not in ("c", "numba"): + raise ValueError( + f"Unsupported backend {backend!r}. Expected 'c' or 'numba'." + ) + self.backend = backend def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) @@ -578,18 +589,23 @@ def find_fuseable_subgraphs( This function yields subgraph in reverse topological order so they can be safely replaced one at a time """ - @cache - def elemwise_scalar_op_has_c_code( - node: Apply, optimizer_verbose=config.optimizer_verbose - ) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + if self.backend == "c": + # supports_c_code is expensive, cache results + @cache + def elemwise_scalar_op_is_fuseable( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." + ) + return False + else: + # numba: fuse unconditionally + def elemwise_scalar_op_is_fuseable(node: Apply) -> bool: return True - elif optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." - ) - return False # Create a map from node to a set of fuseable client (successor) nodes # A node and a client are fuseable if they are both single output Elemwise @@ -608,7 +624,7 @@ def elemwise_scalar_op_has_c_code( out_node is not None and len(out_node.outputs) == 1 and isinstance(out_node.op, Elemwise) - and elemwise_scalar_op_has_c_code(out_node) + and elemwise_scalar_op_is_fuseable(out_node) ): continue @@ -621,7 +637,7 @@ def elemwise_scalar_op_has_c_code( len(client.outputs) == 1 and isinstance(client.op, Elemwise) and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) + and elemwise_scalar_op_is_fuseable(client) ) } if out_fuseable_clients: @@ -872,17 +888,10 @@ def elemwise_scalar_op_has_c_code( # Yield from sorted_subgraphs, discarding the subgraph_bitset yield from (io for _, io in sorted_subgraphs) - max_operands = elemwise_max_operands_fct(None) reason = self.__class__.__name__ nb_fused = 0 nb_replacement = 0 for inputs, outputs in find_fuseable_subgraphs(fgraph): - if (len(inputs) + len(outputs)) > max_operands: - warn( - "Loop fusion failed because the resulting node would exceed the kernel argument limit." - ) - continue - scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables @@ -1172,6 +1181,11 @@ def constant_fold_branches_of_add_mul(fgraph, node): ) # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) +# The outer SequenceDB is backend-agnostic; the actual FusionOptimizer inside +# is registered per-backend (C with "cxx_only", Numba with "numba"). +# Python mode does not benefit from fusion: frompyfunc's C iteration loop is +# faster than fused Composite.impl per-element overhead. +# Shared cleanup rewrites run for any backend that performed fusion. fuse_seqopt = SequenceDB() optdb.register( "elemwise_fusion", @@ -1182,13 +1196,22 @@ def constant_fold_branches_of_add_mul(fgraph, node): "FusionOptimizer", position=49, ) +# C backend fusion: checks that scalar ops have C implementations fuse_seqopt.register( "composite_elemwise_fusion", - FusionOptimizer(), + FusionOptimizer(backend="c"), "fast_run", "fusion", + "cxx_only", position=1, ) +# Numba backend fusion: fuses unconditionally +fuse_seqopt.register( + "numba_composite_elemwise_fusion", + FusionOptimizer(backend="numba"), + "numba", + position=1.5, +) fuse_seqopt.register( "local_useless_composite_outputs", dfs_rewriter(local_useless_composite_outputs), @@ -1201,6 +1224,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): dfs_rewriter(local_careduce_fusion), "fast_run", "fusion", + "cxx_only", position=10, ) fuse_seqopt.register( diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 2c196401a2..07ddf10cba 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -18,7 +18,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.raise_op import assert_op -from pytensor.scalar.basic import Composite, float64 +from pytensor.scalar.basic import Composite, UnaryScalarOp, float64, upcast_out from pytensor.tensor.basic import MakeVector from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import abs as pt_abs @@ -234,6 +234,7 @@ def test_local_useless_expand_dims_in_reshape(): assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)]) +@pytest.mark.skipif(not config.cxx, reason="Fusion requires a C compiler (cxx_only)") class TestFusion: rewrites = RewriteDatabaseQuery( include=[ @@ -242,7 +243,7 @@ class TestFusion: "add_mul_fusion", "inplace", ], - exclude=["cxx_only", "BlasOpt"], + exclude=["BlasOpt"], ) mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) @@ -315,7 +316,7 @@ def test_diamond_graph(): e = c + d fg = FunctionGraph([a], [e], clone=False) - _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + _, nb_fused, nb_replacement, *_ = FusionOptimizer(backend="c").apply(fg) assert nb_fused == 1 assert nb_replacement == 4 @@ -334,7 +335,7 @@ def test_expansion_order(self): e2 = d + b # test both orders fg = FunctionGraph([a], [e1, e2], clone=False) - _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + _, nb_fused, nb_replacement, *_ = FusionOptimizer(backend="c").apply(fg) fg.dprint() assert nb_fused == 1 assert nb_replacement == 3 @@ -1075,7 +1076,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): assert od == o.dtype def test_fusion_35_inputs(self): - r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" + r"""Make sure we can fuse 35 inputs with the C backend.""" inpts = vectors([f"i{i}" for i in range(35)]) # Make an elemwise graph looking like: @@ -1084,16 +1085,16 @@ def test_fusion_35_inputs(self): for idx in range(1, 35): out = sin(inpts[idx] + out) - with config.change_flags(cxx=""): - f = function(inpts, out, mode=self.mode) + f = function(inpts, out, mode=self.mode) - # Make sure they all weren't fused + # With the C backend, everything should be fused composite_nodes = [ node for node in f.maker.fgraph.toposort() if isinstance(getattr(node.op, "scalar_op", None), ps.basic.Composite) ] - assert not any(len(node.inputs) > 31 for node in composite_nodes) + assert len(composite_nodes) == 1 + assert composite_nodes[0].inputs.__len__() == 35 @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") def test_big_fusion(self): @@ -1173,7 +1174,10 @@ def test_fusion_multiout_inplace(self, linker): inp = np.array([0, 1, 2], dtype=config.floatX) res = f(inp) - assert not np.allclose(inp, [0, 1, 2]) + # destroy_map is a permission, not an obligation. + # The C linker writes inplace; the py linker may not (e.g. frompyfunc path). + if linker != "py": + assert not np.allclose(inp, [0, 1, 2]) assert np.allclose(res[0], [1, 2, 3]) assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2])) @@ -1241,7 +1245,8 @@ def test_test_values(self, test_value): pt_all, np.all, marks=pytest.mark.xfail( - reason="Rewrite logic does not support all CAReduce" + strict=False, + reason="Rewrite logic does not support all CAReduce", ), ), ], @@ -1400,7 +1405,7 @@ def test_eval_benchmark(self, benchmark): def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) - opt = FusionOptimizer() + opt = FusionOptimizer(backend="c") def rewrite_func(): fg_clone = fg.clone() @@ -1444,7 +1449,9 @@ def test_joint_circular_dependency(self): for out_order in [(sub, add), (add, sub)]: fgraph = FunctionGraph([x], out_order, clone=True) - _, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph) + _, nb_fused, nb_replaced, *_ = FusionOptimizer(backend="c").apply( + fgraph + ) # (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion assert (nb_fused, nb_replaced) in ((2, 4), (1, 3)) fused_nodes = { @@ -1559,6 +1566,7 @@ def test_local_useless_composite_outputs(): utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) +@pytest.mark.skipif(not config.cxx, reason="Fusion requires a C compiler (cxx_only)") @pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)]) @pytest.mark.parametrize("op, np_op", [(pt.pow, np.power), (pt.add, np.add)]) def test_local_inline_composite_constants(op, np_op, const_shape): @@ -1654,3 +1662,73 @@ def test_InplaceElemwiseOptimizer_bug(): finally: # Restore original value to avoid affecting other tests pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value + + +# Dummy scalar ops without nfunc_spec — same impl as Exp/Log but forces +# the frompyfunc path (no numpy ufunc shortcut). +class _DummyExp(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.exp(x) + + def output_types(self, types): + return upcast_out(*types) + + +class _DummyLog(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.log(x) + + def output_types(self, types): + return upcast_out(*types) + + +_dummy_exp = Elemwise(_DummyExp(name="dummy_exp")) +_dummy_log = Elemwise(_DummyLog(name="dummy_log")) + + +class TestPyPerformBenchmarks: + """Benchmarks for the Python Elemwise perform path. + + These verify that: + 1. Ops with nfunc_spec (exp, log) use numpy ufuncs directly (SIMD). + 2. Ops without nfunc_spec use frompyfunc (C iteration loop). + """ + + rewrites = RewriteDatabaseQuery( + include=["fusion", "inplace"], + ) + py_mode = Mode("py", rewrites) + + def test_nfunc_spec(self, benchmark): + """sin(cos(x)) with nfunc_spec uses numpy ufuncs directly.""" + x = dvector("x") + out = pt.sin(pt.cos(x)) + f = function([x], out, mode=self.py_mode, trust_input=True) + + # Should be two separate Elemwise nodes (no py fusion) + elemwise_nodes = [ + n for n in f.maker.fgraph.toposort() if isinstance(n.op, Elemwise) + ] + assert len(elemwise_nodes) == 2 + + data = np.random.random(10_000) + benchmark(f, data) + + def test_no_nfunc_spec(self, benchmark): + """dummy_exp(dummy_log(x)) without nfunc_spec uses frompyfunc.""" + x = dvector("x") + out = _dummy_exp(_dummy_log(x)) + f = function([x], out, mode=self.py_mode, trust_input=True) + + # No py fusion — should be two separate Elemwise nodes + elemwise_nodes = [ + n for n in f.maker.fgraph.toposort() if isinstance(n.op, Elemwise) + ] + assert len(elemwise_nodes) == 2 + + data = np.random.random(10_000) + benchmark(f, data)