From 168468eab6d3afd6b992198c88c968e7c48e3ea8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 22:32:25 -0500 Subject: [PATCH 1/8] Add c_funcify C dispatch registry and route linkers through it Introduce the singledispatch c_funcify registry returning detached CImpl implementations, resolve CLinker through it, and route OpWiseCLinker, the VM, and DebugMode to the dispatched C thunk with a Python fallback. --- pytensor/compile/debug/debugmode.py | 40 ++-- pytensor/link/c/basic.py | 97 +++++++--- pytensor/link/c/dispatch/__init__.py | 4 + pytensor/link/c/dispatch/basic.py | 203 ++++++++++++++++++++ pytensor/link/vm.py | 12 +- tests/link/c/test_dispatch.py | 266 +++++++++++++++++++++++++++ 6 files changed, 578 insertions(+), 44 deletions(-) create mode 100644 pytensor/link/c/dispatch/__init__.py create mode 100644 pytensor/link/c/dispatch/basic.py create mode 100644 tests/link/c/test_dispatch.py diff --git a/pytensor/compile/debug/debugmode.py b/pytensor/compile/debug/debugmode.py index 9a9617774c..81ae22fad8 100644 --- a/pytensor/compile/debug/debugmode.py +++ b/pytensor/compile/debug/debugmode.py @@ -1306,7 +1306,10 @@ def printstuff(self): # List of default version of make thunk. # This is needed to know if the user overrode it. -default_make_thunk = [get_unbound_function(COp.make_thunk)] +default_make_thunk = [ + get_unbound_function(Op.make_thunk), + get_unbound_function(COp.make_thunk), +] # Debug mode cheats and initializes the linker in a different way in @@ -1352,6 +1355,7 @@ def make_all( # can't import at toplevel because of circular import TODO: # don't do this ugly hacky way of setting the # filter_checks_isfinite + from pytensor.link.c.dispatch.basic import c_thunk_from_dispatch fgraph = self.fgraph input_storage_ = input_storage @@ -1404,15 +1408,10 @@ def make_all( debug = hasattr(node.op, "debug_perform") try: - if ( - not self.maker.mode.check_c_code - or debug - or not isinstance(node.op, COp) - ): + if not self.maker.mode.check_c_code or debug: raise MethodNotDefined() - node.op.prepare_node(node, storage_map, compute_map, "c") - thunk = node.op.make_c_thunk( + thunk = c_thunk_from_dispatch( node, storage_map, compute_map, no_recycling ) thunks_c.append(thunk) @@ -1434,19 +1433,18 @@ def make_all( else: thunks_py.append(None) - if ( - not self.maker.mode.check_c_code - and thunks_py[-1] is None - and isinstance(node.op, COp) - ): - _logger.warning( - f"Op {node.op} doesn't have a perform, forcing check of the C code" - ) - node.op.prepare_node(node, storage_map, compute_map, "c") - thunk = node.op.make_c_thunk( - node, storage_map, compute_map, no_recycling - ) - thunks_c[-1] = thunk + if not self.maker.mode.check_c_code and thunks_py[-1] is None: + try: + thunk = c_thunk_from_dispatch( + node, storage_map, compute_map, no_recycling + ) + except (NotImplementedError, MethodNotDefined): + pass + else: + _logger.warning( + f"Op {node.op} doesn't have a perform, forcing check of the C code" + ) + thunks_c[-1] = thunk # If the op defined its own make_thunk, use the generated thunk if thunk_other is not None: diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index d39a6a9881..81e953c642 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -13,6 +13,7 @@ from pytensor.compile.compilelock import lock_ctx from pytensor.configdefaults import config from pytensor.graph.basic import ( + Apply, AtomicVariable, Constant, ) @@ -553,6 +554,7 @@ class CLinker(Linker): def __init__(self, schedule=None): self.fgraph = None + self._node_impls: dict[Apply, CLinkerOp] = {} super().__init__(scheduler=schedule) def accept( @@ -573,6 +575,23 @@ def accept( self.no_recycling = no_recycling return self + def _impl_for(self, node) -> CLinkerOp: + """Return `node`'s C implementation, resolved and memoized via `c_funcify`.""" + try: + return self._node_impls[node] + except KeyError: + pass + # Imported lazily so `import pytensor` does not load the dispatch + # registrations; the import triggers them on first compilation. + from pytensor.link.c.dispatch.basic import c_funcify + + try: + impl = c_funcify(node.op, node=node) + except NotImplementedError as exc: + raise NotImplementedError(f"{node.op} cannot produce C code") from exc + self._node_impls[node] = impl + return impl + def fetch_variables(self): """Fills the inputs, outputs, variables, orphans, temps and node_order fields.""" fgraph = self.fgraph @@ -591,10 +610,14 @@ def fetch_variables(self): # that needs it self.node_params = dict() for node in self.node_order: - if not isinstance(node.op, CLinkerOp): + try: + impl = self._impl_for(node) + except NotImplementedError: + # No C implementation; code_gen will raise if this node is + # actually compiled. continue try: - params = node.op.get_params(node) + params = impl.get_params(node) except MethodNotDefined: params = NoParams if params is not NoParams: @@ -602,10 +625,10 @@ def fetch_variables(self): # same params. if params in self.node_params: var = self.node_params[params] - assert var.type == node.params_type + assert var.type == impl.params_type fgraph.clients[var].append((node, "params")) else: - var = Constant(node.params_type, params) + var = Constant(impl.params_type, params) fgraph.clients[var] = [(node, "params")] self.node_params[params] = var self.variables.append(var) @@ -775,10 +798,7 @@ def code_gen(self): id += 2 for node_num, node in enumerate(self.node_order): - op = node.op - - if not isinstance(op, CLinkerOp): - raise NotImplementedError(f"{op} cannot produce C code") + op = self._impl_for(node) sub = dict(failure_var=failure_var) @@ -905,7 +925,9 @@ def support_code(self): """ ) # generic support code - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: support_code = x.c_support_code() if isinstance(support_code, list): ret.extend(support_code) @@ -941,7 +963,9 @@ def compile_args(self): c_compiler = self.c_compiler() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_compile_args(c_compiler=c_compiler) @@ -949,7 +973,9 @@ def compile_args(self): # The args set by the compiler include the user flags. We do not want # to reorder them ret += c_compiler.compile_args() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): no_comp = x.c_no_compile_args(c_compiler=c_compiler) @@ -970,7 +996,9 @@ def headers(self): """ ret = [] c_compiler = self.c_compiler() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_headers(c_compiler=c_compiler) return uniq(ret) @@ -984,14 +1012,18 @@ def init_code(self): """ ret = [] - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_init_code() return uniq(ret) def c_compiler(self): c_compiler = None - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: # FIXME: Why would a `Type` have a `c_compiler` field?! if hasattr(x, "c_compiler"): x_compiler = x.c_compiler() @@ -1021,7 +1053,9 @@ def header_dirs(self): """ ret = [] c_compiler = self.c_compiler() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_header_dirs(c_compiler=c_compiler) # filter out empty strings/None @@ -1037,7 +1071,9 @@ def libraries(self): """ ret = [] c_compiler = self.c_compiler() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_libraries(c_compiler=c_compiler) return uniq(ret) @@ -1052,7 +1088,9 @@ def lib_dirs(self): """ ret = [] c_compiler = self.c_compiler() - for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: + for x in [y.type for y in self.variables] + [ + self._impl_for(y) for y in self.node_order + ]: if isinstance(x, CLinkerObject): ret += x.c_lib_dirs(c_compiler=c_compiler) # filter out empty strings/None @@ -1433,8 +1471,14 @@ def in_sig(i, topological_pos, i_idx): version = [] for node_pos, node in enumerate(order): - if hasattr(node.op, "c_code_cache_version_apply"): - version.append(node.op.c_code_cache_version_apply(node)) + try: + impl = self._impl_for(node) + except NotImplementedError: + # No C implementation: contributes no version entry (code_gen + # raises later if this graph is compiled). + pass + else: + version.append(impl.c_code_cache_version_apply(node)) props = getattr(node.op, "__props__", None) @@ -1819,6 +1863,8 @@ def accept(self, fgraph, no_recycling=None, profile=None): def make_all( self, profiler=None, input_storage=None, output_storage=None, storage_map=None ): + from pytensor.link.c.dispatch.basic import make_node_thunk_with_c_dispatch + fgraph = self.fgraph order = self.schedule(fgraph) no_recycling = self.no_recycling @@ -1838,9 +1884,16 @@ def make_all( thunks = [] for node in order: - # make_thunk will try by default C code, otherwise - # it fall back to python. - thunks += [node.op.make_thunk(node, storage_map, compute_map, no_recycling)] + # Try the C dispatch first, otherwise fall back to Python. + thunks += [ + make_node_thunk_with_c_dispatch( + node, + storage_map, + compute_map, + no_recycling, + try_c=bool(config.cxx), + ) + ] thunks[-1].inputs = [storage_map[v] for v in node.inputs] thunks[-1].outputs = [storage_map[v] for v in node.outputs] diff --git a/pytensor/link/c/dispatch/__init__.py b/pytensor/link/c/dispatch/__init__.py new file mode 100644 index 0000000000..2cf370a098 --- /dev/null +++ b/pytensor/link/c/dispatch/__init__.py @@ -0,0 +1,4 @@ +# isort: off +from pytensor.link.c.dispatch.basic import c_funcify + +# isort: on diff --git a/pytensor/link/c/dispatch/basic.py b/pytensor/link/c/dispatch/basic.py new file mode 100644 index 0000000000..c8abbf7fcb --- /dev/null +++ b/pytensor/link/c/dispatch/basic.py @@ -0,0 +1,203 @@ +import warnings +from collections.abc import Collection +from functools import singledispatch +from typing import NoReturn + +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType +from pytensor.graph.utils import MethodNotDefined +from pytensor.link.c.interface import CLinkerOp +from pytensor.link.c.op import ( + COp, + CThunkWrapperType, + is_cthunk_wrapper_type, +) + + +@singledispatch +def c_funcify(op: Op, node: Apply | None = None, **kwargs) -> CLinkerOp: + """Return the C implementation of `op` at `node`. + + By default an op implementing `CLinkerOp` (every `COp`) is its own + implementation; otherwise raise `NotImplementedError` and let the caller fall + back to the Python thunk. + """ + if isinstance(op, CLinkerOp): + return op + raise NotImplementedError(f"No C implementation registered for {type(op).__name__}") + + +def _hashable_aliasing_map(aliasing_map: dict[int, list[int]]) -> tuple: + return tuple(sorted((idx, tuple(vals)) for idx, vals in aliasing_map.items())) + + +class CImpl(CLinkerOp): + """A C implementation of an `Op`, detached from the op. + + Returned by `c_funcify`; never a graph op. Subclasses that add configuration + must extend `_impl_props`, which backs equality, hashing, and the cache key. + """ + + # `Apply.clone_with_new_inputs` reads this off the node's op when the + # single-node graph is cloned for compilation; impl outputs never depend on + # input values (the graph op already fixed the output types). + _output_type_depends_on_input_value = False + + def __init__( + self, + op: Op, + *, + destroy_map: dict[int, list[int]] | None = None, + view_map: dict[int, list[int]] | None = None, + ): + self.op = op + if destroy_map is None: + destroy_map = getattr(op, "destroy_map", {}) + if view_map is None: + view_map = getattr(op, "view_map", {}) + self.destroy_map = destroy_map + self.view_map = view_map + + def _impl_props(self) -> tuple: + return ( + self.op, + _hashable_aliasing_map(self.destroy_map), + _hashable_aliasing_map(self.view_map), + ) + + def __eq__(self, other) -> bool: + return type(self) is type(other) and self._impl_props() == other._impl_props() + + def __hash__(self) -> int: + return hash((type(self), *self._impl_props())) + + def __str__(self) -> str: + return f"{type(self).__name__}{{{self.op}}}" + + def make_node(self, *inputs) -> NoReturn: + raise RuntimeError( + f"{type(self).__name__} is a C implementation, not a graph op." + ) + + def prepare_node( + self, + node: Apply, + storage_map: StorageMapType, + compute_map: ComputeMapType | None, + impl: str | None, + ) -> None: + """No-op: C preparation happens when `c_funcify` constructs the impl.""" + + +def c_thunk_from_dispatch( + node: Apply, + storage_map: StorageMapType, + compute_map: ComputeMapType | None, + no_recycling: Collection[Variable], +) -> CThunkWrapperType: + """Compile a C thunk for `node`, taking its implementation from `c_funcify`. + + Raises + ------ + NotImplementedError + If `node.op` has no C implementation, or has float16 inputs/outputs. + MethodNotDefined + If the implementation declines this node (e.g. an unsupported dtype). + + Callers fall back to a Python thunk on either. + """ + # Imported here to avoid an import cycle. + import pytensor.link.c.basic + + # Resolve eagerly so an unimplemented op raises before prepare_node runs and + # before any compilation work; CLinker re-resolves (memoized) during codegen. + c_funcify(node.op, node=node) + + node.op.prepare_node( + node, storage_map=storage_map, compute_map=compute_map, impl="c" + ) + + node_input_storage = [storage_map[r] for r in node.inputs] + node_output_storage = [storage_map[r] for r in node.outputs] + + fgraph = FunctionGraph(node.inputs, node.outputs) + fgraph_no_recycling = [ + new_o + for (new_o, old_o) in zip(fgraph.outputs, node.outputs, strict=True) + if old_o in no_recycling + ] + cl = pytensor.link.c.basic.CLinker().accept( + fgraph, no_recycling=fgraph_no_recycling + ) + + # float16 gets special treatment since running unprepared C code will get bad + # results. + if not getattr(node.op, "_f16_ok", False): + + def is_f16(t): + return getattr(t, "dtype", "") == "float16" + + if any(is_f16(i.type) for i in node.inputs) or any( + is_f16(o.type) for o in node.outputs + ): + # get_dynamic_module just tries to build the C code; it raises for + # impls without C code, in which case we don't want to warn. + cl.get_dynamic_module() + warnings.warn(f"Disabling C code for {node.op} due to unsupported float16") + raise NotImplementedError("float16") + + outputs = cl.make_thunk( + input_storage=node_input_storage, output_storage=node_output_storage + ) + thunk, _node_input_filters, _node_output_filters = outputs + + if compute_map is None: + rval = is_cthunk_wrapper_type(thunk) + else: + cm_entries = [compute_map[o] for o in node.outputs] + + @is_cthunk_wrapper_type + def rval(thunk=thunk, cm_entries=cm_entries): + thunk() + for entry in cm_entries: + entry[0] = True + + rval.thunk = thunk + rval.cthunk = thunk.cthunk + rval.inputs = node_input_storage + rval.outputs = node_output_storage + rval.lazy = False + return rval + + +# Ops whose `make_thunk` is one of these run through the dispatch; anything else +# overrode `make_thunk` (e.g. `IfElse`, `Scan`) and keeps its custom path. +_DEFAULT_MAKE_THUNKS = (Op.make_thunk, COp.make_thunk) + + +def make_node_thunk_with_c_dispatch( + node: Apply, + storage_map: StorageMapType, + compute_map: ComputeMapType | None, + no_recycling: Collection[Variable], + *, + try_c: bool, + fallback_impl: str | None = None, +) -> ThunkType: + """Make a thunk for `node`, trying the C dispatch first when `try_c`. + + When the C attempt fails (no implementation, or the implementation declines + the node) the fallback passes ``impl="py"`` so `COp.make_thunk` does not + retry the C path. + """ + if try_c and type(node.op).make_thunk in _DEFAULT_MAKE_THUNKS: + try: + return c_thunk_from_dispatch(node, storage_map, compute_map, no_recycling) + except (NotImplementedError, MethodNotDefined): + fallback_impl = "py" + # Op.make_thunk is untyped upstream; pin the result to its real type. + thunk: ThunkType = node.op.make_thunk( + node, storage_map, compute_map, no_recycling, impl=fallback_impl + ) + return thunk diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 239f73df80..a2dfe91f43 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -1210,6 +1210,8 @@ def make_all( output_storage=None, storage_map=None, ): + from pytensor.link.c.dispatch.basic import make_node_thunk_with_c_dispatch + fgraph = self.fgraph order = self.schedule(fgraph) @@ -1227,6 +1229,7 @@ def make_all( impl = None if self.c_thunks is False: impl = "py" + use_c_dispatch = self.c_thunks is not False and bool(config.cxx) for node in order: try: thunk_start = time.perf_counter() @@ -1234,7 +1237,14 @@ def make_all( # no need to cause duplicate c code by passing # no_recycling here. thunks.append( - node.op.make_thunk(node, storage_map, compute_map, [], impl=impl) + make_node_thunk_with_c_dispatch( + node, + storage_map, + compute_map, + [], + try_c=use_c_dispatch, + fallback_impl=impl, + ) ) linker_make_thunk_time[node] = time.perf_counter() - thunk_start if not hasattr(thunks[-1], "lazy"): diff --git a/tests/link/c/test_dispatch.py b/tests/link/c/test_dispatch.py new file mode 100644 index 0000000000..3afbb5b5bb --- /dev/null +++ b/tests/link/c/test_dispatch.py @@ -0,0 +1,266 @@ +import numpy as np +import pytest + +import pytensor +import pytensor.scalar as ps +import pytensor.tensor as pt +from pytensor.compile.debug.debugmode import BadThunkOutput, DebugMode +from pytensor.compile.mode import Mode +from pytensor.configdefaults import config +from pytensor.graph.basic import Apply +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import Op +from pytensor.graph.utils import MethodNotDefined +from pytensor.link.c.basic import CLinker +from pytensor.link.c.dispatch.basic import ( + CImpl, + c_funcify, + c_thunk_from_dispatch, +) +from pytensor.link.vm import VMLinker +from pytensor.tensor.shape import Shape, Shape_i + + +pytestmark = pytest.mark.skipif( + not config.cxx, reason="A C compiler is required to test the C dispatch" +) + +CVM_MODE = Mode(linker="cvm", optimizer=None) +PY_MODE = Mode(linker="py", optimizer=None) + + +class ScalarOpBase(Op): + """A pure scalar op: only `make_node` and `perform`.""" + + __props__ = () + increment = 1.0 + + def make_node(self, x): + x = ps.as_scalar(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + (x,) = inputs + output_storage[0][0] = np.dtype(node.outputs[0].dtype).type(x + self.increment) + + +class IncOne(ScalarOpBase): + pass + + +class IncOneNoImpl(ScalarOpBase): + pass + + +class IncOneDeclining(ScalarOpBase): + pass + + +class IncOneImpl(CImpl): + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs + return f"{z} = {x} + 1;" + + def c_code_cache_version(self): + return (1,) + + +class DecliningImpl(CImpl): + def c_code(self, node, name, inputs, outputs, sub): + raise MethodNotDefined("c_code") + + def c_code_cache_version(self): + return (1,) + + +@c_funcify.register(IncOne) +def c_funcify_inc_one(op, node=None, **kwargs): + return IncOneImpl(op) + + +@c_funcify.register(IncOneDeclining) +def c_funcify_declining(op, node=None, **kwargs): + return DecliningImpl(op) + + +def make_thunk_for(op, x_value=2.0, dtype="float64"): + x = ps.ScalarType(dtype)("x") + out = op(x) + node = out.owner + storage_map = {x: [np.dtype(dtype).type(x_value)], out: [None]} + compute_map = {x: [True], out: [False]} + thunk = c_thunk_from_dispatch(node, storage_map, compute_map, []) + return thunk, storage_map, compute_map, out + + +def test_pure_op_gains_c_thunk(): + thunk, storage_map, compute_map, out = make_thunk_for(IncOne()) + + assert hasattr(thunk, "cthunk") + assert thunk.lazy is False + assert thunk.inputs == [storage_map[out.owner.inputs[0]]] + assert thunk.outputs == [storage_map[out]] + + thunk() + assert storage_map[out][0] == 3.0 + assert compute_map[out][0] is True + + +def test_pure_op_cvm_function_matches_perform(): + x = ps.float64("x") + out = IncOne()(x) + + f_c = pytensor.function([x], out, mode=CVM_MODE) + f_py = pytensor.function([x], out, mode=PY_MODE) + assert f_c(2.0) == f_py(2.0) == 3.0 + + +def test_unregistered_pure_op_falls_back(): + op = IncOneNoImpl() + with pytest.raises(NotImplementedError, match="No C implementation registered"): + c_funcify(op) + + x = ps.float64("x") + f = pytensor.function([x], op(x), mode=CVM_MODE) + assert f(2.0) == 3.0 + + +def test_declining_impl_falls_back(): + op = IncOneDeclining() + with pytest.raises(MethodNotDefined): + make_thunk_for(op) + + x = ps.float64("x") + f = pytensor.function([x], op(x), mode=CVM_MODE) + assert f(2.0) == 3.0 + + +def test_cop_is_its_own_impl(): + op = Shape() + assert c_funcify(op) is op + + +def test_float16_guard_falls_back(): + # The guard raises NotImplementedError either way; the warning is only + # emitted when the impl's C code builds for f16, but ScalarType's own C + # support rejects f16 first. + op = IncOne() + with pytest.raises(NotImplementedError): + make_thunk_for(op, dtype="float16") + + x = ps.ScalarType("float16")("x") + f = pytensor.function([x], op(x), mode=CVM_MODE) + assert f(np.float16(2.0)) == np.float16(3.0) + + +def test_vm_without_c_thunks_skips_dispatch(monkeypatch): + def fail_dispatch(*args, **kwargs): + raise AssertionError("dispatch should not run when c_thunks=False") + + monkeypatch.setattr( + "pytensor.link.c.dispatch.basic.c_thunk_from_dispatch", fail_dispatch + ) + + x = ps.float64("x") + mode = Mode(linker=VMLinker(use_cloop=False, c_thunks=False), optimizer=None) + f = pytensor.function([x], IncOne()(x), mode=mode) + assert f(2.0) == 3.0 + + +@pytest.mark.parametrize("linker", ["c", "c|py"]) +def test_whole_graph_linkers_use_dispatch(linker): + x = ps.float64("x") + f = pytensor.function([x], IncOne()(x), mode=Mode(linker=linker, optimizer=None)) + assert f(2.0) == 3.0 + + +def test_whole_graph_c_linker_unregistered_raises(): + x = ps.float64("x") + with pytest.raises(NotImplementedError, match="cannot produce C code"): + pytensor.function([x], IncOneNoImpl()(x), mode=Mode(linker="c", optimizer=None)) + + +def test_cmodule_key_stable_and_versioned(): + def key_for_fresh_graph(): + x = ps.float64("x") + out = IncOne()(x) + fgraph = FunctionGraph([x], [out]) + return CLinker().accept(fgraph).cmodule_key() + + key_a = key_for_fresh_graph() + key_b = key_for_fresh_graph() + assert key_a == key_b + + version, _sig = key_a + # The registered impl's cache version makes the module versioned (cacheable + # across processes), even though the graph op itself has no C methods. + assert version != () + assert IncOneImpl(IncOne()).c_code_cache_version() in version + + +def test_params_constants_deduplicated_across_nodes(): + x = pt.matrix("x") + y = pt.matrix("y") + out = Shape_i(0)(x) + Shape_i(0)(y) + + fgraph = FunctionGraph([x, y], [out]) + cl = CLinker().accept(fgraph) + shape_i_nodes = [n for n in cl.node_order if isinstance(n.op, Shape_i)] + assert len(shape_i_nodes) == 2 + # Both Shape_i(0) nodes share one params Constant. + assert len(cl.node_params) == 1 + + f = pytensor.function([x, y], out, mode=Mode(linker="c", optimizer=None)) + assert f(np.ones((3, 2)), np.ones((5, 2))) == 8 + + +class IncOneWrongImpl(ScalarOpBase): + pass + + +class WrongImpl(CImpl): + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs + return f"{z} = {x} + 2;" # disagrees with perform on purpose + + def c_code_cache_version(self): + return (1,) + + +@c_funcify.register(IncOneWrongImpl) +def c_funcify_wrong(op, node=None, **kwargs): + return WrongImpl(op) + + +def test_cop_graph_resolves_to_identity(): + # The parity guarantee: every COp node resolves to itself, so CLinker calls + # the op's own c_code/cache-version methods and produces byte-identical + # source and cache keys. + x = pt.matrix("x") + out = (x.T + 1.0).sum(axis=0) + fgraph = FunctionGraph([x], [out]) + cl = CLinker().accept(fgraph) + + for node in cl.node_order: + assert cl._impl_for(node) is node.op + + # Source generation works and the module is versioned (cacheable). + assert isinstance(cl.get_src_code(), str) + version, _sig = cl.cmodule_key() + assert version != () + + +def test_debugmode_cross_checks_dispatch_impl(): + x = ps.float64("x") + f = pytensor.function( + [x], IncOne()(x), mode=DebugMode(optimizer=None, check_py_code=True) + ) + assert f(2.0) == 3.0 + + f_wrong = pytensor.function( + [x], IncOneWrongImpl()(x), mode=DebugMode(optimizer=None, check_py_code=True) + ) + with pytest.raises(BadThunkOutput): + f_wrong(2.0) From 9a4a72076ffc9c3bf6beddece8d850ba79fee510 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 22:33:27 -0500 Subject: [PATCH 2/8] Migrate CheckAndRaise to the C dispatch registry Make CheckAndRaise a plain Op and register CheckAndRaiseImpl, moving its c_code and ParamsType into the detached impl. --- pytensor/link/c/dispatch/__init__.py | 3 + pytensor/link/c/dispatch/raise_op.py | 87 ++++++++++++++++++++++++++++ pytensor/raise_op.py | 64 +------------------- 3 files changed, 93 insertions(+), 61 deletions(-) create mode 100644 pytensor/link/c/dispatch/raise_op.py diff --git a/pytensor/link/c/dispatch/__init__.py b/pytensor/link/c/dispatch/__init__.py index 2cf370a098..18bd97ee9a 100644 --- a/pytensor/link/c/dispatch/__init__.py +++ b/pytensor/link/c/dispatch/__init__.py @@ -1,4 +1,7 @@ # isort: off from pytensor.link.c.dispatch.basic import c_funcify +# Load dispatch specializations +import pytensor.link.c.dispatch.raise_op + # isort: on diff --git a/pytensor/link/c/dispatch/raise_op.py b/pytensor/link/c/dispatch/raise_op.py new file mode 100644 index 0000000000..384178f9fd --- /dev/null +++ b/pytensor/link/c/dispatch/raise_op.py @@ -0,0 +1,87 @@ +from collections.abc import Hashable +from textwrap import indent + +from pytensor.graph.basic import Apply +from pytensor.link.c.dispatch.basic import CImpl, c_funcify +from pytensor.link.c.params_type import Params, ParamsType +from pytensor.link.c.type import Generic +from pytensor.raise_op import CheckAndRaise +from pytensor.scalar.basic import ScalarType +from pytensor.tensor.type import DenseTensorType + + +class ExceptionType(Generic): + def __eq__(self, other): + return type(self) is type(other) + + def __hash__(self): + return hash(type(self)) + + +exception_type = ExceptionType() + + +class CheckAndRaiseImpl(CImpl): + """C implementation of `CheckAndRaise`. + + The exception type is a runtime ``PyObject`` passed through a `ParamsType`; + the message is baked into the generated code (it is part of the op's props). + """ + + op: CheckAndRaise + params_type = ParamsType(exc_type=exception_type) + + def get_params(self, node: Apply) -> Params: + return self.params_type.get_params(self.op) + + def c_code_cache_version(self) -> tuple[Hashable, ...]: + return (2,) + + def c_code( + self, + node: Apply, + name: str, + inputs: list[str], + outputs: list[str], + sub: dict[str, str], + ) -> str: + if not isinstance(node.inputs[0].type, DenseTensorType | ScalarType): + raise NotImplementedError( + f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" + ) + value_name, *cond_names = inputs + out_name = outputs[0] + fail_code = sub["fail"] + param_struct_name = sub["params"] + msg = self.op.msg.replace('"', '\\"').replace("\n", "\\n") + + all_conds = " && ".join(cond_names) + check = f""" + if(!({all_conds})) {{ + PyObject * exc_type = {param_struct_name}->exc_type; + Py_INCREF(exc_type); + PyErr_SetString(exc_type, "{msg}"); + Py_XDECREF(exc_type); + {indent(fail_code, " " * 4)} + }} + """ + + if isinstance(node.inputs[0].type, DenseTensorType): + res = f""" + {check} + Py_XDECREF({out_name}); + {out_name} = {value_name}; + Py_INCREF({value_name}); + """ + else: + res = f""" + {check} + {out_name} = {value_name}; + """ + + return "\n".join((check, res)) + + +@c_funcify.register(CheckAndRaise) +def c_funcify_check_and_raise(op, node=None, **kwargs) -> CheckAndRaiseImpl: + return CheckAndRaiseImpl(op) diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index c2962e4d7d..afb1f553ca 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -1,34 +1,17 @@ """Symbolic Op for raising an exception.""" -from textwrap import indent - from pytensor.gradient import disconnected_type from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node -from pytensor.link.c.op import COp -from pytensor.link.c.params_type import ParamsType -from pytensor.link.c.type import Generic -from pytensor.scalar.basic import ScalarType, as_scalar -from pytensor.tensor.type import DenseTensorType - - -class ExceptionType(Generic): - def __eq__(self, other): - return type(self) is type(other) - - def __hash__(self): - return hash(type(self)) +from pytensor.scalar.basic import as_scalar -exception_type = ExceptionType() - - -class CheckAndRaise(COp): +class CheckAndRaise(Op): """An `Op` that checks conditions and raises an exception if they fail. This `Op` returns its "value" argument if its condition arguments are all ``True``; otherwise, it raises a user-specified exception. - """ _f16_ok = True @@ -36,7 +19,6 @@ class CheckAndRaise(COp): view_map = {0: [0]} check_input = False - params_type = ParamsType(exc_type=exception_type) def __init__(self, exc_type, msg=""): if not issubclass(exc_type, Exception): @@ -97,46 +79,6 @@ def pullback(self, input, outputs, output_gradients): def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) - def c_code(self, node, name, inames, onames, props): - if not isinstance(node.inputs[0].type, DenseTensorType | ScalarType): - raise NotImplementedError( - f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" - ) - value_name, *cond_names = inames - out_name = onames[0] - fail_code = props["fail"] - param_struct_name = props["params"] - msg = self.msg.replace('"', '\\"').replace("\n", "\\n") - - all_conds = " && ".join(cond_names) - check = f""" - if(!({all_conds})) {{ - PyObject * exc_type = {param_struct_name}->exc_type; - Py_INCREF(exc_type); - PyErr_SetString(exc_type, "{msg}"); - Py_XDECREF(exc_type); - {indent(fail_code, " " * 4)} - }} - """ - - if isinstance(node.inputs[0].type, DenseTensorType): - res = f""" - {check} - Py_XDECREF({out_name}); - {out_name} = {value_name}; - Py_INCREF({value_name}); - """ - else: - res = f""" - {check} - {out_name} = {value_name}; - """ - - return "\n".join((check, res)) - - def c_code_cache_version(self): - return (2,) - def infer_shape(self, fgraph, node, input_shapes): return [input_shapes[0]] From a1b6a6bc0c289e4f18437e1f9c909d8f04e3b0f3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 22:34:44 -0500 Subject: [PATCH 3/8] Migrate DimShuffle to the C dispatch registry Make DimShuffle a plain Op and register DimShuffleImpl, which emits per-node specialized C from the static permutation and shape, replacing the runtime-spec dimshuffle.c kernel. --- pytensor/link/c/dispatch/__init__.py | 1 + pytensor/link/c/dispatch/elemwise.py | 114 +++++++++++++++++++++++++++ pytensor/tensor/c_code/dimshuffle.c | 86 -------------------- pytensor/tensor/elemwise.py | 28 ++----- tests/link/c/test_dispatch.py | 12 ++- 5 files changed, 133 insertions(+), 108 deletions(-) create mode 100644 pytensor/link/c/dispatch/elemwise.py delete mode 100644 pytensor/tensor/c_code/dimshuffle.c diff --git a/pytensor/link/c/dispatch/__init__.py b/pytensor/link/c/dispatch/__init__.py index 18bd97ee9a..c08169c0e4 100644 --- a/pytensor/link/c/dispatch/__init__.py +++ b/pytensor/link/c/dispatch/__init__.py @@ -2,6 +2,7 @@ from pytensor.link.c.dispatch.basic import c_funcify # Load dispatch specializations +import pytensor.link.c.dispatch.elemwise import pytensor.link.c.dispatch.raise_op # isort: on diff --git a/pytensor/link/c/dispatch/elemwise.py b/pytensor/link/c/dispatch/elemwise.py new file mode 100644 index 0000000000..1fb112bcdb --- /dev/null +++ b/pytensor/link/c/dispatch/elemwise.py @@ -0,0 +1,114 @@ +from collections.abc import Hashable + +from pytensor.graph.basic import Apply +from pytensor.link.c.dispatch.basic import CImpl, c_funcify +from pytensor.tensor.elemwise import DimShuffle + + +class DimShuffleImpl(CImpl): + """Specialized C implementation of `DimShuffle`. + + Reads the op's static permutation and the input's static shape to emit a + straight-line view construction; the only runtime guards are squeeze checks + on dropped axes whose length is statically unknown. + """ + + op: DimShuffle + + def c_code_cache_version(self) -> tuple[Hashable, ...]: + # `new_order`/`input_ndim` ride `DimShuffle.__props__` and the input's + # static shape rides its type signature, so both are keyed automatically; + # bump this only when the emitted C below changes. + return (1,) + + def c_code( + self, + node: Apply, + name: str, + inputs: list[str], + outputs: list[str], + sub: dict[str, str], + ) -> str: + op = self.op + (inp,) = inputs + (out,) = outputs + fail = sub["fail"] + new_order = op._new_order + nd_out = len(new_order) + in_shape = node.inputs[0].type.shape + + # A dropped axis is guaranteed by `make_node` to be length 1 or unknown. + # Only the unknown case needs a runtime check. + guards = "\n".join( + f""" + if (PyArray_DIMS({inp})[{d}] != 1) {{ + PyErr_SetString(PyExc_ValueError, + "DimShuffle: cannot drop axis {d} with length not equal to one."); + {fail} + }}""" + for d in op.drop + if in_shape[d] is None + ) + + assigns = [] + for i, j in enumerate(new_order): + if j == -1: + # An augmented (broadcast) axis. The length-1 stride is set to the + # itemsize rather than zero: the value is never dereferenced, but + # some BLAS implementations mishandle a zero stride. + assigns.append(f"dimensions[{i}] = 1;") + assigns.append(f"strides[{i}] = itemsize;") + else: + assigns.append(f"dimensions[{i}] = PyArray_DIMS({inp})[{j}];") + static_len = in_shape[j] + if static_len == 1: + assigns.append(f"strides[{i}] = itemsize;") + elif static_len is not None: + assigns.append(f"strides[{i}] = PyArray_STRIDES({inp})[{j}];") + else: + assigns.append( + f"strides[{i}] = PyArray_DIMS({inp})[{j}] == 1 ? " + f"itemsize : PyArray_STRIDES({inp})[{j}];" + ) + + if nd_out: + shape_block = ( + f"npy_intp dimensions[{nd_out}];\n" + f"npy_intp strides[{nd_out}];\n" + "\n".join(assigns) + ) + dims_ptr = "dimensions" + strides_ptr = "strides" + else: + shape_block = "" + dims_ptr = "NULL" + strides_ptr = "NULL" + + return f""" + {{ + npy_intp itemsize = PyArray_ITEMSIZE({inp}); + {guards} + {shape_block} + + Py_XDECREF({out}); + // Borrow only the writable flag from the input; NPY_OWNDATA stays 0. + {out} = (PyArrayObject*)PyArray_New( + &PyArray_Type, {nd_out}, {dims_ptr}, + PyArray_TYPE({inp}), {strides_ptr}, + PyArray_DATA({inp}), itemsize, + (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE({inp})), + NULL); + if ({out} == NULL) {{ + {fail} + }} + + // Declare the result a view of the input and recompute its flags. + Py_INCREF((PyObject*){inp}); + PyArray_SetBaseObject({out}, (PyObject*){inp}); + PyArray_UpdateFlags({out}, NPY_ARRAY_UPDATE_ALL); + }} + """ + + +@c_funcify.register(DimShuffle) +def c_funcify_dimshuffle(op, node=None, **kwargs) -> DimShuffleImpl: + return DimShuffleImpl(op) diff --git a/pytensor/tensor/c_code/dimshuffle.c b/pytensor/tensor/c_code/dimshuffle.c deleted file mode 100644 index 0bfc5df3bb..0000000000 --- a/pytensor/tensor/c_code/dimshuffle.c +++ /dev/null @@ -1,86 +0,0 @@ -#section support_code_apply - -int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) { - npy_int64* new_order; - npy_intp nd_in; - npy_intp nd_out; - npy_intp* dimensions; - npy_intp* strides; - - if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { - PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); - return 1; - } - new_order = (npy_int64*) PyArray_DATA(params->_new_order); - nd_in = (npy_intp)(params->input_ndim); - nd_out = PyArray_SIZE(params->_new_order); - - if (PyArray_NDIM(input) != nd_in) { - PyErr_SetString(PyExc_ValueError, "DimShuffle: Input has less dimensions than expected."); - return 1; - } - - // Compute new dimensions and strides - dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); - strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); - if (dimensions == NULL || strides == NULL) { - PyErr_NoMemory(); - free(dimensions); - free(strides); - return 1; - }; - - npy_intp original_size = PyArray_SIZE(input); - npy_intp new_size = 1; - for (npy_intp i = 0; i < nd_out; ++i) { - // We set the strides of length 1 dimensions to PyArray_ITEMSIZE(input). - // The value is arbitrary, because there is never a next element. - // np.expand_dims(x, 0) and x[None] do different things here. - // I would prefer zero, but there are some poorly implemented BLAS operations - // That don't handle zero strides correctly. At least they won't fail because of DimShuffle. - if (new_order[i] != -1) { - dimensions[i] = PyArray_DIMS(input)[new_order[i]]; - strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? PyArray_ITEMSIZE(input) : PyArray_STRIDES(input)[new_order[i]]; - } else { - dimensions[i] = 1; - strides[i] = PyArray_ITEMSIZE(input); - } - new_size *= dimensions[i]; - } - - if (original_size != new_size) { - PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one."); - free(dimensions); - free(strides); - return 1; - } - - if (*res) - Py_XDECREF(*res); - - // Create the new array. - *res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, - PyArray_TYPE(input), strides, - PyArray_DATA(input), PyArray_ITEMSIZE(input), - // borrow only the writable flag from the base - // the NPY_OWNDATA flag will default to 0. - (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)), - NULL); - - if (*res == NULL) { - free(dimensions); - free(strides); - return 1; - } - - // Declare it a view of the original input - Py_INCREF((PyObject*)input); - PyArray_SetBaseObject(*res, (PyObject*)input); - - // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED - PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); - - free(strides); - free(dimensions); - return 0; -} \ No newline at end of file diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a2a6e2a89f..799fce7eca 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -11,16 +11,16 @@ from pytensor.gradient import DisconnectedType, disconnected_type from pytensor.graph.basic import Apply from pytensor.graph.null_type import NullType +from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code -from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp -from pytensor.link.c.params_type import ParamsType +from pytensor.link.c.op import COp, OpenMPOp from pytensor.misc.frozendict import frozendict from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import identity as scalar_identity -from pytensor.scalar.basic import int64, upcast +from pytensor.scalar.basic import upcast from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable @@ -29,7 +29,6 @@ continuous_dtypes, discrete_dtypes, float_dtypes, - lvector, ) from pytensor.tensor.utils import ( broadcast_static_dim_lengths, @@ -40,7 +39,7 @@ from pytensor.utils import uniq, unzip -class DimShuffle(ExternalCOp): +class DimShuffle(Op): """ Allows to reorder the dimensions of a tensor or insert or remove broadcastable dimensions. @@ -114,20 +113,9 @@ class DimShuffle(ExternalCOp): _f16_ok = True check_input = False __props__ = ("input_ndim", "new_order") - c_func_file = "c_code/dimshuffle.c" - c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)" view_map = {0: [0]} - @property - def params_type(self): - return ParamsType( - _new_order=lvector, - input_ndim=int64, - ) - def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): - super().__init__([self.c_func_file], self.c_func_name) - if not isinstance(input_ndim, int): raise TypeError(f"input_ndim must be an integer, got {type(int)}") @@ -187,11 +175,11 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.is_matrix_transpose = not augment and is_left_expanded_matrix_transpose def __setstate__(self, state): + # Old pickles carry ExternalCOp attributes (func_files, ...); drop them, + # the C implementation now comes from the dispatch registry. + for key in ("func_files", "func_codes", "func_name", "code_sections"): + state.pop(key, None) self.__dict__.update(state) - if not hasattr(self, "func_files"): - # Perhaps we are loading an old `Op` version of DimShuffle. - # Let's just build the ExternalCOp. - super().__init__([self.c_func_file], self.c_func_name) def make_node(self, inp): input = as_tensor_variable(inp) diff --git a/tests/link/c/test_dispatch.py b/tests/link/c/test_dispatch.py index 3afbb5b5bb..8f78551c3f 100644 --- a/tests/link/c/test_dispatch.py +++ b/tests/link/c/test_dispatch.py @@ -237,14 +237,22 @@ def c_funcify_wrong(op, node=None, **kwargs): def test_cop_graph_resolves_to_identity(): # The parity guarantee: every COp node resolves to itself, so CLinker calls # the op's own c_code/cache-version methods and produces byte-identical - # source and cache keys. + # source and cache keys. A registered op (DimShuffle) resolves to its + # detached impl instead. + from pytensor.link.c.dispatch.elemwise import DimShuffleImpl + from pytensor.tensor.elemwise import DimShuffle + x = pt.matrix("x") out = (x.T + 1.0).sum(axis=0) fgraph = FunctionGraph([x], [out]) cl = CLinker().accept(fgraph) for node in cl.node_order: - assert cl._impl_for(node) is node.op + impl = cl._impl_for(node) + if isinstance(node.op, DimShuffle): + assert isinstance(impl, DimShuffleImpl) + else: + assert impl is node.op # Source generation works and the module is versioned (cacheable). assert isinstance(cl.get_src_code(), str) From ffa8712da072dbab3cf4d3f440cdeb9e3b984ed5 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 22:52:58 -0500 Subject: [PATCH 4/8] Remove unused ExternalCOp DimShuffle was its last consumer and now uses the C dispatch registry, so the ExternalCOp base class, its section-loading machinery, and the tests exercising it are dead code. --- pytensor/link/c/op.py | 427 +----------------- tests/link/c/c_code/test_quadratic_function.c | 44 -- tests/link/c/test_op.py | 95 ---- tests/link/c/test_params_type.py | 39 +- 4 files changed, 7 insertions(+), 598 deletions(-) delete mode 100644 tests/link/c/c_code/test_quadratic_function.c diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 2a0170f98d..b0956cf21f 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -1,21 +1,12 @@ -import inspect -import re import warnings -from collections.abc import Callable, Collection, Iterable -from pathlib import Path -from re import Pattern -from typing import TYPE_CHECKING, Any, ClassVar, cast - -import numpy as np +from collections.abc import Callable, Collection +from typing import TYPE_CHECKING, cast from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType -from pytensor.graph.type import HasDataType from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.interface import CLinkerOp -from pytensor.link.c.params_type import ParamsType -from pytensor.utils import hash_from_code if TYPE_CHECKING: @@ -224,407 +215,6 @@ def prepare_node(self, node, storage_map, compute_map, impl): self.update_self_openmp() -def lquote_macro(txt: str) -> str: - """Turn the last line of text into a ``\\``-commented line.""" - return " \\\n".join(txt.split("\n")) - - -def get_sub_macros(sub: dict[str, str]) -> tuple[str, str]: - define_macros = [] - undef_macros = [] - define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}") - undef_macros.append("#undef FAIL") - if "params" in sub: - define_macros.append(f"#define PARAMS {sub['params']}") - undef_macros.append("#undef PARAMS") - - return "\n".join(define_macros), "\n".join(undef_macros) - - -def get_io_macros(inputs: list[str], outputs: list[str]) -> tuple[str, str]: - define_inputs = [f"#define INPUT_{int(i)} {inp}" for i, inp in enumerate(inputs)] - define_outputs = [f"#define OUTPUT_{int(i)} {out}" for i, out in enumerate(outputs)] - - undef_inputs = [f"#undef INPUT_{int(i)}" for i in range(len(inputs))] - undef_outputs = [f"#undef OUTPUT_{int(i)}" for i in range(len(outputs))] - - define_all = "\n".join(define_inputs + define_outputs) - undef_all = "\n".join(undef_inputs + undef_outputs) - - return define_all, undef_all - - -class ExternalCOp(COp): - """Class for an `Op` with an external C implementation. - - One can inherit from this class, provide its constructor with a path to - an external C source file and the name of a function within it, and define - an `Op` for said function. - - """ - - section_re: ClassVar[Pattern] = re.compile( - r"^#section ([a-zA-Z0-9_]+)$", re.MULTILINE - ) - backward_re: ClassVar[Pattern] = re.compile( - r"^PYTENSOR_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE - ) - # This is the set of allowed markers - SECTIONS: ClassVar[set[str]] = { - "init_code", - "init_code_apply", - "init_code_struct", - "support_code", - "support_code_apply", - "support_code_struct", - "cleanup_code_struct", - "code", - "code_cleanup", - } - _cop_num_inputs: int | None = None - _cop_num_outputs: int | None = None - - @classmethod - def get_path(cls, f: Path) -> Path: - """Convert a path relative to the location of the class file into an absolute path. - - Paths that are already absolute are passed through unchanged. - - """ - if not f.is_absolute(): - class_file = inspect.getfile(cls) - class_dir = Path(class_file).parent - f = (class_dir / f).resolve() - return f - - def __init__( - self, - func_files: str | Path | list[str] | list[Path], - func_name: str | None = None, - ): - """ - Sections are loaded from files in order with sections in later - files overriding sections in previous files. - - """ - if not isinstance(func_files, list): - self.func_files = [Path(func_files)] - else: - self.func_files = [Path(func_file) for func_file in func_files] - - self.func_codes: list[str] = [] - # Keep the original name. If we reload old pickle, we want to - # find the new path and new version of the file in PyTensor. - self.func_name = func_name - self.code_sections: dict[str, str] = dict() - - self.load_c_code(self.func_files) - - if len(self.code_sections) == 0: - raise ValueError("No sections where defined in the C files") - - if self.func_name is not None: - if "op_code" in self.code_sections: - # maybe a warning instead (and clearing the key) - raise ValueError( - "Cannot have an `op_code` section and specify `func_name`" - ) - if "op_code_cleanup" in self.code_sections: - # maybe a warning instead (and clearing the key) - raise ValueError( - "Cannot have an `op_code_cleanup` section and specify `func_name`" - ) - - def load_c_code(self, func_files: Iterable[Path]) -> None: - """Loads the C code to perform the `Op`.""" - for func_file in func_files: - func_file = self.get_path(func_file) - self.func_codes.append(func_file.read_text(encoding="utf-8")) - - # If both the old section markers and the new section markers are - # present, raise an error because we don't know which ones to follow. - old_markers_present = any( - self.backward_re.search(code) for code in self.func_codes - ) - new_markers_present = any( - self.section_re.search(code) for code in self.func_codes - ) - - if old_markers_present and new_markers_present: - raise ValueError( - "Both the new and the old syntax for " - "identifying code sections are present in the " - "provided C code. These two syntaxes should not " - "be used at the same time." - ) - - for func_file, code in zip(func_files, self.func_codes, strict=True): - if self.backward_re.search(code): - # This is backward compat code that will go away in a while - - # Separate the code into the proper sections - split = self.backward_re.split(code) - n = 1 - while n < len(split): - if split[n] == "APPLY": - self.code_sections["support_code_apply"] = split[n + 1] - elif split[n] == "SUPPORT": - self.code_sections["support_code"] = split[n + 1] - n += 2 - continue - - elif self.section_re.search(code): - # Check for code outside of the supported sections - split = self.section_re.split(code) - if split[0].strip() != "": - raise ValueError( - "Stray code before first #section " - f"statement (in file {func_file}): {split[0]}" - ) - - # Separate the code into the proper sections - n = 1 - while n < len(split): - if split[n] not in self.SECTIONS: - raise ValueError( - f"Unknown section type (in file {func_file}): {split[n]}" - ) - if split[n] not in self.code_sections: - self.code_sections[split[n]] = "" - self.code_sections[split[n]] += split[n + 1] - n += 2 - - else: - raise ValueError( - f"No valid section marker was found in file {func_file}" - ) - - def __get_op_params(self) -> list[tuple[str, Any]]: - """Construct name, value pairs that will be turned into macros for use within the `Op`'s code. - - The names must be strings that are not a C keyword and the - values must be strings of literal C representations. - - If op uses a :class:`pytensor.graph.params_type.ParamsType` as ``params_type``, - it returns: - - a default macro ``PARAMS_TYPE`` which defines the class name of the - corresponding C struct. - - a macro ``DTYPE_PARAM_key`` for every ``key`` in the :class:`ParamsType` for which associated - type implements the method :func:`pytensor.graph.type.CLinkerType.c_element_type`. - ``DTYPE_PARAM_key`` defines the primitive C type name of an item in a variable - associated to ``key``. - - """ - params: list[tuple[str, Any]] = [] - if isinstance(self.params_type, ParamsType): - wrapper = self.params_type - params.append(("PARAMS_TYPE", wrapper.name)) - for i in range(wrapper.length): - c_type = wrapper.types[i].c_element_type() - if c_type: - # NB (reminder): These macros are currently used only in ParamsType example test - # (`pytensor/graph/tests/test_quadratic_function.c`), to demonstrate how we can - # access params dtypes when dtypes may change (e.g. if based on config.floatX). - # But in practice, params types generally have fixed types per op. - params.append( - ( - "DTYPE_PARAM_" + wrapper.fields[i], - c_type, - ) - ) - return params - - def c_code_cache_version(self): - version = (hash_from_code("\n".join(self.func_codes)),) - if self.params_type is not None: - version += (self.params_type.c_code_cache_version(),) - return version - - def c_init_code(self, **kwargs): - if "init_code" in self.code_sections: - return [self.code_sections["init_code"]] - else: - return super().c_init_code(**kwargs) - - def c_support_code(self, **kwargs): - if "support_code" in self.code_sections: - return self.code_sections["support_code"] - else: - return super().c_support_code(**kwargs) - - def c_init_code_apply(self, node, name): - if "init_code_apply" in self.code_sections: - code = self.code_sections["init_code_apply"] - - define_macros, undef_macros = self.get_c_macros(node, name) - return f"\n{define_macros}\n{code}\n{undef_macros}" - else: - return super().c_init_code_apply(node, name) - - def c_support_code_apply(self, node, name): - if "support_code_apply" in self.code_sections: - code = self.code_sections["support_code_apply"] - - define_macros, undef_macros = self.get_c_macros(node, name) - return f"\n{define_macros}\n{code}\n{undef_macros}" - else: - return super().c_support_code_apply(node, name) - - def c_support_code_struct(self, node, name): - if "support_code_struct" in self.code_sections: - code = self.code_sections["support_code_struct"] - - define_macros, undef_macros = self.get_c_macros(node, name) - return f"\n{define_macros}\n{code}\n{undef_macros}" - else: - return super().c_support_code_struct(node, name) - - def c_cleanup_code_struct(self, node, name): - if "cleanup_code_struct" in self.code_sections: - code = self.code_sections["cleanup_code_struct"] - - define_macros, undef_macros = self.get_c_macros(node, name) - return f"\n{define_macros}\n{code}\n{undef_macros}" - else: - return super().c_cleanup_code_struct(node, name) - - def format_c_function_args(self, inp: list[str], out: list[str]) -> str: - """Generate a string containing the arguments sent to the external C function. - - The result will have the format: ``"input0, input1, input2, &output0, &output1"``. - - """ - inp = list(inp) - if self._cop_num_inputs is not None: - numi = self._cop_num_inputs - else: - numi = len(inp) - - while len(inp) < numi: - inp.append("NULL") - - out = [f"&{o}" for o in out] - - if self._cop_num_outputs is not None: - numo = self._cop_num_outputs - else: - numo = len(out) - - while len(out) < numo: - out.append("NULL") - - return ", ".join(inp + out) - - def get_c_macros( - self, node: Apply, name: str, check_input: bool | None = None - ) -> tuple[str, str]: - "Construct a pair of C ``#define`` and ``#undef`` code strings." - define_macros = [] - undef_macros = [] - - if check_input is None: - check_input = getattr(self, "check_input", True) - - if check_input: - # Extract the various properties of the input and output variables - variables = node.inputs + node.outputs - variable_names = [f"INPUT_{i}" for i in range(len(node.inputs))] + [ - f"OUTPUT_{i}" for i in range(len(node.outputs)) - ] - - # Generate dtype macros - for i, v in enumerate(variables): - if not isinstance(v.type, HasDataType): - continue - - vname = variable_names[i] - - define_macros.append(f"#define DTYPE_{vname} npy_{v.type.dtype}") - undef_macros.append(f"#undef DTYPE_{vname}") - - d = np.dtype(v.type.dtype) - - define_macros.append(f"#define TYPENUM_{vname} {d.num}") - undef_macros.append(f"#undef TYPENUM_{vname}") - - define_macros.append(f"#define ITEMSIZE_{vname} {d.itemsize}") - undef_macros.append(f"#undef ITEMSIZE_{vname}") - - # Generate a macro to mark code as being apply-specific - define_macros.append(f"#define APPLY_SPECIFIC(str) str##_{name}") - undef_macros.append("#undef APPLY_SPECIFIC") - - define_macros.extend(f"#define {n} {v}" for n, v in self.__get_op_params()) - undef_macros.extend(f"#undef {n}" for n, _ in self.__get_op_params()) - - return "\n".join(define_macros), "\n".join(undef_macros) - - def c_init_code_struct(self, node, name, sub): - r"""Stitches all the macros and ``init_code_*``\s together.""" - if "init_code_struct" in self.code_sections: - op_code = self.code_sections["init_code_struct"] - - def_macros, undef_macros = self.get_c_macros(node, name) - def_sub, undef_sub = get_sub_macros(sub) - - return f"\n{def_macros}\n{def_sub}\n{op_code}\n{undef_sub}\n{undef_macros}" - else: - return super().c_init_code_struct(node, name, sub) - - def c_code(self, node, name, inp, out, sub): - if self.func_name is not None: - assert "code" not in self.code_sections - - define_macros, undef_macros = self.get_c_macros( - node, name, check_input=False - ) - - params = "" - if "params" in sub: - params = f", {sub['params']}" - - # Generate the C code - return f""" - {define_macros} - {{ - if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{ - {sub["fail"]} - }} - }} - {undef_macros} - """ - else: - if "code" in self.code_sections: - op_code = self.code_sections["code"] - - def_macros, undef_macros = self.get_c_macros(node, name) - def_sub, undef_sub = get_sub_macros(sub) - def_io, undef_io = get_io_macros(inp, out) - - return ( - f"{def_macros}\n{def_sub}\n{def_io}\n{op_code}" - f"\n{undef_io}\n{undef_sub}\n{undef_macros}" - ) - else: - raise NotImplementedError() - - def c_code_cleanup(self, node, name, inputs, outputs, sub): - r"""Stitches all the macros and ``code_cleanup``\s together.""" - if "code_cleanup" in self.code_sections: - op_code = self.code_sections["code_cleanup"] - - def_macros, undef_macros = self.get_c_macros(node, name) - def_sub, undef_sub = get_sub_macros(sub) - def_io, undef_io = get_io_macros(inputs, outputs) - - return ( - f"{def_macros}\n{def_sub}\n{def_io}\n{op_code}" - f"\n{undef_io}\n{undef_sub}\n{undef_macros}" - ) - else: - return super().c_code_cleanup(node, name, inputs, outputs, sub) - - class _NoPythonCOp(COp): """A class used to indicate that a `COp` does not provide a Python implementation. @@ -634,16 +224,3 @@ class _NoPythonCOp(COp): def perform(self, node, inputs, output_storage): raise NotImplementedError("No Python implementation is provided by this COp.") - - -class _NoPythonExternalCOp(ExternalCOp): - """A class used to indicate that an `ExternalCOp` does not provide a Python implementation. - - XXX: Do not use this class; it's only for tracking bad implementations internally. - - """ - - def perform(self, node, inputs, output_storage): - raise NotImplementedError( - "No Python implementation is provided by this ExternalCOp." - ) diff --git a/tests/link/c/c_code/test_quadratic_function.c b/tests/link/c/c_code/test_quadratic_function.c deleted file mode 100644 index cea63832a3..0000000000 --- a/tests/link/c/c_code/test_quadratic_function.c +++ /dev/null @@ -1,44 +0,0 @@ -#section support_code_apply -int APPLY_SPECIFIC(quadratic_function)(PyArrayObject* tensor, DTYPE_INPUT_0 a, DTYPE_INPUT_0 b, DTYPE_INPUT_0 c) { - NpyIter* iterator = NpyIter_New(tensor, - NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, - NPY_KEEPORDER, NPY_NO_CASTING, NULL); - if(iterator == NULL) { - PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a tensor for an elemwise operation."); - return -1; - } - NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL); - char** data_ptr = NpyIter_GetDataPtrArray(iterator); - npy_intp* stride_ptr = NpyIter_GetInnerStrideArray(iterator); - npy_intp* innersize_ptr = NpyIter_GetInnerLoopSizePtr(iterator); - do { - char* data = *data_ptr; - npy_intp stride = *stride_ptr; - npy_intp count = *innersize_ptr; - while(count) { - DTYPE_INPUT_0 x = *((DTYPE_INPUT_0*)data); - *((DTYPE_INPUT_0*)data) = a*x*x + b*x + c; - data += stride; - --count; - } - } while(get_next(iterator)); - NpyIter_Deallocate(iterator); - return 0; -} - -int APPLY_SPECIFIC(compute_quadratic)(PyArrayObject* X, PyArrayObject** Y, PARAMS_TYPE* coeff) { - DTYPE_INPUT_0 a = (DTYPE_INPUT_0) (*(DTYPE_PARAM_a*) PyArray_GETPTR1(coeff->a, 0)); // 0-D TensorType. - DTYPE_INPUT_0 b = coeff->b; // ScalarType. - DTYPE_INPUT_0 c = (DTYPE_INPUT_0) PyFloat_AsDouble(coeff->c); // Generic. - Py_XDECREF(*Y); - *Y = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(X), PyArray_DIMS(X), TYPENUM_INPUT_0, PyArray_IS_F_CONTIGUOUS(X)); - if (PyArray_CopyInto(*Y, X) != 0) { - PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output."); - return 1; - }; - if (APPLY_SPECIFIC(quadratic_function)(*Y, a, b, c) != 0) { - PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function."); - return 1; - } - return 0; -} diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index 4cf6058a78..be731c7687 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -1,9 +1,3 @@ -import os -import string -import subprocess -import sys -from pathlib import Path - import numpy as np import pytest @@ -15,52 +9,6 @@ from pytensor.link.c.op import COp -test_dir = Path(__file__).parent.absolute() - -externalcop_test_code = f""" -from pytensor import tensor as pt -from pytensor.graph.basic import Apply -from pytensor.link.c.params_type import ParamsType -from pytensor.link.c.op import ExternalCOp -from pytensor.scalar import ScalarType -from pytensor.link.c.type import Generic -from pytensor.tensor.type import TensorType - -tensor_type_0d = TensorType("float64", tuple()) -scalar_type = ScalarType("float64") -generic_type = Generic() - - -class QuadraticCOpFunc(ExternalCOp): - __props__ = ("a", "b", "c") - params_type = ParamsType(a=tensor_type_0d, b=scalar_type, c=generic_type) - - def __init__(self, a, b, c): - super().__init__( - "{str(test_dir).replace(os.sep, "/")}/c_code/test_quadratic_function.c", "APPLY_SPECIFIC(compute_quadratic)" - ) - self.a = a - self.b = b - self.c = c - - def make_node(self, x): - x = pt.as_tensor_variable(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inputs, output_storage, coefficients): - x = inputs[0] - y = output_storage[0] - y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c - - -if __name__ == "__main__": - qcop = QuadraticCOpFunc(1, 2, 3) - - print(qcop.c_code_cache_version()) - print("__success__") -""" - - class StructOp(COp): __props__ = () @@ -189,46 +137,3 @@ def perform(self, *args, **kwargs): else: with pytest.raises((NotImplementedError, MethodNotDefined)): thunk() - - -def get_hash(modname, seed=None): - """From https://hg.python.org/cpython/file/5e8fa1b13516/Lib/test/test_hash.py#l145""" - env = os.environ.copy() - if seed is not None: - env["PYTHONHASHSEED"] = str(seed) - else: - env.pop("PYTHONHASHSEED", None) - cmd_line = [sys.executable, modname] - p = subprocess.Popen( - cmd_line, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - out, err = p.communicate() - return out, err, p.returncode - - -def test_ExternalCOp_c_code_cache_version(): - """Make sure the C cache versions produced by `ExternalCOp` don't depend on `hash` seeding.""" - - tmp = Path() / ("".join(np.random.choice(list(string.ascii_letters), 8)) + ".py") - tmp.write_bytes(externalcop_test_code.encode()) - - try: - modname = tmp.name - out_1, err1, returncode1 = get_hash(modname, seed=428) - out_2, err2, returncode2 = get_hash(modname, seed=3849) - assert returncode1 == 0 - assert returncode2 == 0 - assert err1 == err2 - - hash_1, msg, _ = out_1.decode().split(os.linesep) - assert msg == "__success__" - hash_2, msg, _ = out_2.decode().split(os.linesep) - assert msg == "__success__" - - assert hash_1 == hash_2 - finally: - tmp.unlink() diff --git a/tests/link/c/test_params_type.py b/tests/link/c/test_params_type.py index d8bd2b754a..ceb69e11b2 100644 --- a/tests/link/c/test_params_type.py +++ b/tests/link/c/test_params_type.py @@ -4,7 +4,7 @@ import pytensor from pytensor import tensor as pt from pytensor.graph.basic import Apply -from pytensor.link.c.op import COp, ExternalCOp +from pytensor.link.c.op import COp from pytensor.link.c.params_type import Params, ParamsType from pytensor.link.c.type import EnumList, Generic from pytensor.scalar import ScalarType @@ -93,31 +93,6 @@ def c_code(self, node, name, inputs, outputs, sub): """ -# Same op as above, but implemented as a ExternalCOp (with C code in an -# external file). -class QuadraticCOpFunc(ExternalCOp): - __props__ = ("a", "b", "c") - params_type = ParamsType(a=tensor_type_0d, b=scalar_type, c=generic_type) - - def __init__(self, a, b, c): - super().__init__( - "c_code/test_quadratic_function.c", "APPLY_SPECIFIC(compute_quadratic)" - ) - self.a = a - self.b = b - self.c = c - - def make_node(self, x): - x = pt.as_tensor_variable(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inputs, output_storage): - coefficients = self.params_type.filter(self.get_params(node)) - x = inputs[0] - y = output_storage[0] - y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c - - class TestParamsType: def test_hash_and_eq_params(self): wp1 = ParamsType( @@ -337,16 +312,12 @@ def test_params_type_with_enums(self): def test_op_params(self): a, b, c = 2, 3, -7 x = matrix(dtype="float64") - y1 = QuadraticOpFunc(a, b, c)(x) - y2 = QuadraticCOpFunc(a, b, c)(x) - f1 = pytensor.function([x], y1, mode="CVM") - f2 = pytensor.function([x], y2, mode="CVM") + y = QuadraticOpFunc(a, b, c)(x) + f = pytensor.function([x], y, mode="CVM") shape = (100, 100) vx = ( np.random.normal(size=shape[0] * shape[1]).astype("float64").reshape(*shape) ) - vy1 = f1(vx) - vy2 = f2(vx) + vy = f(vx) ref = a * (vx**2) + b * vx + c - utt.assert_allclose(vy1, vy2) - utt.assert_allclose(ref, vy1) + utt.assert_allclose(ref, vy) From 15a612ae4cf8f1b8192bf0ef7a514de9328ca573 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 23:38:42 -0500 Subject: [PATCH 5/8] Make OpenMP support probe pure, never mutating global config Replace the class-cached update_self_openmp (which set config.openmp = False process-wide when the compiler lacked OpenMP) with a pure, memoized openmp_supported() probe. self.openmp is now the request; effective use is request AND compiler support, resolved lazily at codegen. --- pytensor/link/c/op.py | 121 +++++++++++++++--------------------- pytensor/tensor/elemwise.py | 11 ++-- tests/link/c/test_op.py | 49 ++++++++++++++- 3 files changed, 104 insertions(+), 77 deletions(-) diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index b0956cf21f..fc3f072cc5 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -1,5 +1,6 @@ import warnings from collections.abc import Callable, Collection +from functools import cache from typing import TYPE_CHECKING, cast from pytensor.configdefaults import config @@ -124,95 +125,73 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): ) -class OpenMPOp(COp): - r"""Base class for `Op`\s using OpenMP. - - This `Op` will check that the compiler support correctly OpenMP code. - If not, it will print a warning and disable OpenMP for this `Op`, then it - will generate the not OpenMP code. +@cache +def openmp_supported() -> bool: + """Return whether the active C compiler can build OpenMP code. - This is needed, as EPD on the Windows version of ``g++`` says it supports - OpenMP, but does not include the OpenMP files. + Memoized; the probe runs at most once per process. It is pure — it never + mutates ``config`` or op state, so the result reflects only the compiler, + not call order. Return ``False`` when there is no C compiler. + """ + if not config.cxx: + return False - We also add the correct compiler flags in ``c_compile_args``. + from pytensor.link.c.cmodule import GCC_compiler + code = """ + #include +int main( int argc, const char* argv[] ) +{ + int res[10]; + for(int i=0; i < 10; i++){ + res[i] = i; + } +} """ + supported = bool( + GCC_compiler.try_compile_tmp( + src_code=code, tmp_prefix="test_omp_", flags=["-fopenmp"], try_run=False + ) + ) + if not supported: + warnings.warn( + "Your C compiler fails to compile OpenMP code; PyTensor will run " + "Elemwise operations single-threaded. Set the `openmp` flag to False " + "to silence this warning.", + stacklevel=2, + ) + return supported - gxx_support_openmp: bool | None = None - """ - ``True``/``False`` after we tested this. +class OpenMPOp(COp): + r"""Base class for `Op`\s using OpenMP. + + A subclass requests OpenMP through the ``openmp`` constructor flag (defaulting + to ``config.openmp``). The request is honored only when `openmp_supported` + confirms the compiler can build it, checked lazily at C-code generation time. """ def __init__(self, openmp: bool | None = None): - if openmp is None: - openmp = config.openmp - self.openmp = openmp + self.openmp = config.openmp if openmp is None else openmp def __setstate__(self, d: dict): self.__dict__.update(d) - # If we unpickle old op + # If we unpickle an old op missing the attribute. if not hasattr(self, "openmp"): self.openmp = False + def _use_openmp(self) -> bool: + """Return whether to emit OpenMP code. + + True when this op requests OpenMP and the compiler supports it. + """ + return self.openmp and openmp_supported() + def c_compile_args(self, **kwargs): - """Return the compilation argument ``"-fopenmp"`` if OpenMP is supported.""" - self.update_self_openmp() - if self.openmp: - return ["-fopenmp"] - return [] + return ["-fopenmp"] if self._use_openmp() else [] def c_headers(self, **kwargs): - """Return the header file name ``"omp.h"`` if OpenMP is supported.""" - self.update_self_openmp() - if self.openmp: - return ["omp.h"] - return [] - - @staticmethod - def test_gxx_support(): - """Check if OpenMP is supported.""" - from pytensor.link.c.cmodule import GCC_compiler - - code = """ - #include -int main( int argc, const char* argv[] ) -{ - int res[10]; - - for(int i=0; i < 10; i++){ - res[i] = i; - } -} - """ - default_openmp = GCC_compiler.try_compile_tmp( - src_code=code, tmp_prefix="test_omp_", flags=["-fopenmp"], try_run=False - ) - return default_openmp - - def update_self_openmp(self) -> None: - """Make sure ``self.openmp`` is not ``True`` if there is no OpenMP support in ``gxx``.""" - if self.openmp: - if OpenMPOp.gxx_support_openmp is None: - OpenMPOp.gxx_support_openmp = OpenMPOp.test_gxx_support() - if not OpenMPOp.gxx_support_openmp: - # We want to warn only once. - warnings.warn( - "Your g++ compiler fails to compile OpenMP code. We" - " know this happen with some version of the EPD mingw" - " compiler and LLVM compiler on Mac OS X." - " We disable openmp everywhere in PyTensor." - " To remove this warning set the pytensor flags `openmp`" - " to False.", - stacklevel=3, - ) - if OpenMPOp.gxx_support_openmp is False: - self.openmp = False - config.openmp = False - - def prepare_node(self, node, storage_map, compute_map, impl): - if impl == "c": - self.update_self_openmp() + return ["omp.h"] if self._use_openmp() else [] class _NoPythonCOp(COp): diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 799fce7eca..670dcf5f0e 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -763,6 +763,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # There is no harm if it get called multiple times. if not hasattr(node.tag, "fake_node"): self.prepare_node(node, None, None, "c") + use_openmp = self._use_openmp() _inames = inames _onames = onames @@ -902,7 +903,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # the index of the last of these aliased outputs. # We generate the C code of the inner loop using the scalar op - if self.openmp: + if use_openmp: # If we are using openmp, we need to get rid of the "goto" # statement in sub['fail']. For now we recreate it here. fail = failure_code(sub, use_goto=False) @@ -980,7 +981,7 @@ def _c_all(self, node, nodename, inames, onames, sub): dtypes=dtypes, loop_tasks=all_code, sub=sub, - openmp=self.openmp, + openmp=use_openmp, ) else: loop = cgen.make_reordered_loop( @@ -989,7 +990,7 @@ def _c_all(self, node, nodename, inames, onames, sub): dtypes=dtypes, inner_task=code, sub=sub, - openmp=self.openmp, + openmp=use_openmp, ) # If all inputs and outputs are contiguous @@ -1047,7 +1048,7 @@ def _c_all(self, node, nodename, inames, onames, sub): contig += f""" dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; """ - if self.openmp: + if use_openmp: contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) """ contig += f""" @@ -1124,7 +1125,7 @@ def c_code_cache_version_apply(self, node): get_scalar_type(dtype=i.type.dtype).c_code_cache_version() for i in node.inputs + node.outputs ) - version.append(("openmp", self.openmp)) + version.append(("openmp", self._use_openmp())) version.append(("openmp_elemwise_minsize", config.openmp_elemwise_minsize)) if all(version): return tuple(version) diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index be731c7687..11d300dbbd 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import pytest @@ -6,7 +8,8 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.graph.utils import MethodNotDefined -from pytensor.link.c.op import COp +from pytensor.link.c.cmodule import GCC_compiler +from pytensor.link.c.op import COp, OpenMPOp, openmp_supported class StructOp(COp): @@ -137,3 +140,47 @@ def perform(self, *args, **kwargs): else: with pytest.raises((NotImplementedError, MethodNotDefined)): thunk() + + +class _OpenMPProbeOp(OpenMPOp): + __props__ = () + + def make_node(self, x): + x = ps.as_scalar(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, outputs): + raise NotImplementedError + + +@pytest.fixture +def fresh_openmp_probe(): + openmp_supported.cache_clear() + yield + openmp_supported.cache_clear() + + +@pytest.mark.skipif( + not config.cxx, reason="Requires a C compiler to probe OpenMP support." +) +@pytest.mark.parametrize( + "compiler_supports_openmp", [True, False], ids=["supported", "unsupported"] +) +def test_openmp_resolution_does_not_mutate_global_config( + fresh_openmp_probe, monkeypatch, compiler_supports_openmp +): + monkeypatch.setattr(config, "openmp", True) + monkeypatch.setattr( + GCC_compiler, + "try_compile_tmp", + lambda *args, **kwargs: compiler_supports_openmp, + ) + + op = _OpenMPProbeOp(openmp=True) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + compile_args = op.c_compile_args() + + assert compile_args == (["-fopenmp"] if compiler_supports_openmp else []) + assert op.openmp is True # the op's request survives the compiler's capability + assert config.openmp is True # resolving OpenMP must not flip the global flag From 36025dcb444d91e2412ef4e4b3fb8e8e81386a27 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 13 Jun 2026 23:55:31 -0500 Subject: [PATCH 6/8] Remove OpenMPOp, inlining its remainder into Elemwise Elemwise was the only consumer. It now derives from COp directly and carries the openmp request plus _use_openmp/c_compile_args itself; the dead omp.h header path (always shadowed by Elemwise.c_headers) is dropped. --- pytensor/link/c/op.py | 31 ------------------------------- pytensor/tensor/elemwise.py | 20 ++++++++++++++++---- tests/link/c/test_op.py | 16 +++------------- 3 files changed, 19 insertions(+), 48 deletions(-) diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index fc3f072cc5..98121df20b 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -163,37 +163,6 @@ def openmp_supported() -> bool: return supported -class OpenMPOp(COp): - r"""Base class for `Op`\s using OpenMP. - - A subclass requests OpenMP through the ``openmp`` constructor flag (defaulting - to ``config.openmp``). The request is honored only when `openmp_supported` - confirms the compiler can build it, checked lazily at C-code generation time. - """ - - def __init__(self, openmp: bool | None = None): - self.openmp = config.openmp if openmp is None else openmp - - def __setstate__(self, d: dict): - self.__dict__.update(d) - # If we unpickle an old op missing the attribute. - if not hasattr(self, "openmp"): - self.openmp = False - - def _use_openmp(self) -> bool: - """Return whether to emit OpenMP code. - - True when this op requests OpenMP and the compiler supports it. - """ - return self.openmp and openmp_supported() - - def c_compile_args(self, **kwargs): - return ["-fopenmp"] if self._use_openmp() else [] - - def c_headers(self, **kwargs): - return ["omp.h"] if self._use_openmp() else [] - - class _NoPythonCOp(COp): """A class used to indicate that a `COp` does not provide a Python implementation. diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 670dcf5f0e..7a834ca6e3 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -15,7 +15,7 @@ from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code -from pytensor.link.c.op import COp, OpenMPOp +from pytensor.link.c.op import COp, openmp_supported from pytensor.misc.frozendict import frozendict from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type @@ -298,7 +298,7 @@ def process(self, r, pstate): pprint.assign(DimShuffle, DimShufflePrinter()) -class Elemwise(OpenMPOp): +class Elemwise(COp): """Generalizes a scalar `Op` to tensors. All the inputs must have the same number of dimensions. When the @@ -368,7 +368,7 @@ def __init__( nfunc_spec = getattr(scalar_op, "nfunc_spec", None) self.nfunc_spec = nfunc_spec self.__setstate__(self.__dict__) - super().__init__(openmp=openmp) + self.openmp = config.openmp if openmp is None else openmp def __getstate__(self): d = copy(self.__dict__) @@ -378,11 +378,23 @@ def __getstate__(self): return d def __setstate__(self, d): - super().__setstate__(d) + self.__dict__.update(d) + if not hasattr(self, "openmp"): + self.openmp = False self.ufunc = None self.nfunc = None self.inplace_pattern = frozendict(self.inplace_pattern) + def _use_openmp(self) -> bool: + """Return whether to emit OpenMP code. + + True when this op requests OpenMP and the compiler supports it. + """ + return self.openmp and openmp_supported() + + def c_compile_args(self, **kwargs): + return ["-fopenmp"] if self._use_openmp() else [] + def make_scalar_node(self, *inputs): """Create a scalar Apply node matching the dtypes of tensor inputs. diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index 11d300dbbd..16aacd88c1 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -9,7 +9,8 @@ from pytensor.graph.basic import Apply from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.cmodule import GCC_compiler -from pytensor.link.c.op import COp, OpenMPOp, openmp_supported +from pytensor.link.c.op import COp, openmp_supported +from pytensor.tensor.elemwise import Elemwise class StructOp(COp): @@ -142,17 +143,6 @@ def perform(self, *args, **kwargs): thunk() -class _OpenMPProbeOp(OpenMPOp): - __props__ = () - - def make_node(self, x): - x = ps.as_scalar(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inputs, outputs): - raise NotImplementedError - - @pytest.fixture def fresh_openmp_probe(): openmp_supported.cache_clear() @@ -176,7 +166,7 @@ def test_openmp_resolution_does_not_mutate_global_config( lambda *args, **kwargs: compiler_supports_openmp, ) - op = _OpenMPProbeOp(openmp=True) + op = Elemwise(ps.add, openmp=True) with warnings.catch_warnings(): warnings.simplefilter("ignore") compile_args = op.c_compile_args() From 6bdc7ce2a3acea11b16d09e882dbe06e66d9ccfd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 14 Jun 2026 00:27:49 -0500 Subject: [PATCH 7/8] Migrate Elemwise to the C dispatch registry Extract _c_all into a module-level _elemwise_c_all generator, register ElemwiseImpl, and make Elemwise a plain Op with no C methods. CAReduce calls the generator directly; the openmp decision moves from the op to the impl. --- pytensor/link/c/dispatch/elemwise.py | 81 ++- pytensor/tensor/elemwise.py | 710 ++++++++++++--------------- tests/link/c/test_dispatch.py | 13 +- tests/link/c/test_op.py | 4 +- 4 files changed, 413 insertions(+), 395 deletions(-) diff --git a/pytensor/link/c/dispatch/elemwise.py b/pytensor/link/c/dispatch/elemwise.py index 1fb112bcdb..b949e7d95e 100644 --- a/pytensor/link/c/dispatch/elemwise.py +++ b/pytensor/link/c/dispatch/elemwise.py @@ -1,8 +1,11 @@ from collections.abc import Hashable +from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.link.c.dispatch.basic import CImpl, c_funcify -from pytensor.tensor.elemwise import DimShuffle +from pytensor.link.c.op import openmp_supported +from pytensor.scalar import get_scalar_type +from pytensor.tensor.elemwise import DimShuffle, Elemwise, _elemwise_c_all class DimShuffleImpl(CImpl): @@ -112,3 +115,79 @@ def c_code( @c_funcify.register(DimShuffle) def c_funcify_dimshuffle(op, node=None, **kwargs) -> DimShuffleImpl: return DimShuffleImpl(op) + + +class ElemwiseImpl(CImpl): + """C implementation of `Elemwise`. + + Emits the broadcasting loop over the per-element scalar code via + `_elemwise_c_all`, delegating support code and headers to the scalar op. + """ + + op: Elemwise + + def _use_openmp(self) -> bool: + """Return whether to emit OpenMP: requested by the op and compiler-supported.""" + return self.op.openmp and openmp_supported() + + def c_support_code(self, **kwargs) -> str: + return self.op.scalar_op.c_support_code(**kwargs) + + def c_support_code_apply(self, node: Apply, name: str) -> str: + return self.op.scalar_op.c_support_code_apply(node, name + "_scalar_") + + def c_headers(self, **kwargs) -> list[str]: + return ["", ""] + + def c_header_dirs(self, **kwargs) -> list[str]: + return self.op.scalar_op.c_header_dirs(**kwargs) + + def c_compile_args(self, **kwargs) -> list[str]: + return ["-fopenmp"] if self._use_openmp() else [] + + def c_code( + self, + node: Apply, + name: str, + inputs: list[str], + outputs: list[str], + sub: dict[str, str], + ) -> str: + scalar_op = self.op.scalar_op + if ( + any(i.dtype == "float16" for i in node.inputs) + or any(o.dtype == "float16" for o in node.outputs) + # This is for Composite + or getattr(scalar_op, "inner_float16", False) + ): + # No float16 C support; fall back to perform. + raise NotImplementedError() + return "\n".join( + _elemwise_c_all( + self.op, node, name, inputs, outputs, sub, self._use_openmp() + ) + ) + + def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: + scalar_op = self.op.scalar_op + version = [17] # bump when the emitted C changes + scalar_node = Apply( + scalar_op, + [get_scalar_type(dtype=i.type.dtype).make_variable() for i in node.inputs], + [get_scalar_type(dtype=o.type.dtype).make_variable() for o in node.outputs], + ) + version.append(scalar_op.c_code_cache_version_apply(scalar_node)) + version.extend( + get_scalar_type(dtype=i.type.dtype).c_code_cache_version() + for i in node.inputs + node.outputs + ) + version.append(("openmp", self._use_openmp())) + version.append(("openmp_elemwise_minsize", config.openmp_elemwise_minsize)) + if all(version): + return tuple(version) + return () + + +@c_funcify.register(Elemwise) +def c_funcify_elemwise(op, node=None, **kwargs) -> ElemwiseImpl: + return ElemwiseImpl(op) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 7a834ca6e3..9259539dae 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -298,7 +298,319 @@ def process(self, r, pstate): pprint.assign(DimShuffle, DimShufflePrinter()) -class Elemwise(COp): +def _elemwise_c_all(op, node, nodename, inames, onames, sub, use_openmp): + # Some `Op`s directly call this generator on a fresh Elemwise. + # To not request all of them to call prepare_node(), do it here. + # There is no harm if it get called multiple times. + if not hasattr(node.tag, "fake_node"): + op.prepare_node(node, None, None, "c") + _inames = inames + _onames = onames + + inames = uniq(inames) + inputs = uniq(node.inputs) + # assert that inames and inputs order stay consistent. + # This is to protect again futur change of uniq. + assert len(inames) == len(inputs) + ii, iii = unzip( + uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True + ) + assert all(x == y for x, y in zip(ii, inames, strict=True)) + assert all(x == y for x, y in zip(iii, inputs, strict=True)) + + defines = "" + undefs = "" + + # The destroy map is a map of output indices to input indices + # that overwrite them. We just convert them to the actual + # Variables. + dmap = {node.outputs[o]: [node.inputs[i]] for o, i in op.inplace_pattern.items()} + + # dtypes of the inputs + idtypes = [input.type.dtype_specs()[1] for input in inputs] + + # These are the outputs that we will need to allocate + # (output, name, name of the c type), transposed + real = list( + zip( + *[ + (r, s, r.type.dtype_specs()[1]) + for r, s in zip(node.outputs, onames, strict=True) + if r not in dmap + ], + strict=True, + ) + ) + if real: + real_outputs, real_onames, real_odtypes = real + else: + real_outputs, real_onames, real_odtypes = [], [], [] + + # Outputs that are aliased with an input (inplace) + # (output, name), transposed (c type name not needed since we don't + # need to allocate. + aliased = list( + zip( + *[(r, s) for (r, s) in zip(node.outputs, onames, strict=True) if r in dmap], + strict=True, + ) + ) + if aliased: + aliased_outputs, aliased_onames = aliased + else: + aliased_outputs, aliased_onames = [], [] + + # for each input: + # same as range(ndim), but with 'x' at all broadcastable positions + orders = [ + [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)] + for input in inputs + ] + + # number of nested loops we will need (all inputs have same + # dimensionality) + nnested = len(orders[0]) + sub = dict(sub) + for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): + # the c generators will substitute the input names for + # references to loop variables lv0, lv1, ... + sub[f"lv{i}"] = iname + + decl = cgen.make_declare(orders, idtypes, sub) + checks = cgen.make_checks(orders, idtypes, sub) + + # Check if all inputs (except broadcasted scalar) are fortran. + # In that case, create a fortran output ndarray. + z = list(zip(inames, inputs, strict=True)) + alloc_fortran = " && ".join( + f"PyArray_ISFORTRAN({arr})" + for arr, var in z + if not all(s == 1 for s in var.type.shape) + ) + # If it is a scalar, make it c contig to prevent problem with + # NumPy C and F contig not always set as both of them. + if len(alloc_fortran) == 0: + alloc_fortran = "0" + + alloc = "" + # We loop over the "real" outputs, i.e., those that are not + # inplace (must be allocated) and we declare/allocate/check + # them + for output, oname, odtype in zip( + real_outputs, real_onames, real_odtypes, strict=True + ): + i += 1 # before this loop, i = number of inputs + sub[f"lv{i}"] = oname + sub["olv"] = oname + alloc += cgen.make_declare( + [list(range(nnested))], [odtype], dict(sub, lv0=oname) + ) + alloc += cgen.make_alloc(orders, odtype, sub, fortran=alloc_fortran) + alloc += cgen.make_checks( + [list(range(nnested))], [odtype], dict(sub, lv0=oname) + ) + olv_index = i # index of the last output + + # We loop over the "aliased" outputs, i.e., those that are + # inplace (overwrite the contents of one of the inputs) and + # make the output pointers point to their corresponding input + # pointers. + for output, oname in zip(aliased_outputs, aliased_onames, strict=True): + olv_index = inputs.index(dmap[output][0]) + iname = inames[olv_index] + # We make the output point to the corresponding input and + # decrease the reference of whatever the output contained + # prior to this + alloc += f""" + if ({oname}) {{ + Py_XDECREF({oname}); + }} + {oname} = {iname}; + Py_XINCREF({oname}); + """ + # We alias the scalar variables + defines += f"#define {oname}_i {iname}_i\n" + undefs += f"#undef {oname}_i\n" + + # Note: here, olv_index is either the index of the last output + # which is allocated, OR, if there are any aliased outputs, + # the index of the last of these aliased outputs. + + # We generate the C code of the inner loop using the scalar op + if use_openmp: + # If we are using openmp, we need to get rid of the "goto" + # statement in sub['fail']. For now we recreate it here. + fail = failure_code(sub, use_goto=False) + else: + fail = sub["fail"] + task_code = op.scalar_op.c_code( + node.tag.fake_node, + nodename + "_scalar_", + [f"{s}_i" for s in _inames], + [f"{s}_i" for s in onames], + dict(sub, fail=fail), + ) + code = f""" + {{ + {defines} + {task_code} + {undefs} + }} + """ + + loop_orders = orders + [list(range(nnested))] * len(real_onames) + dtypes = idtypes + list(real_odtypes) + if all( + [o.ndim <= 1 for o in node.outputs] + or + # Use simpler code when output ndim == 0 or 1 + # or for broadcated scalar. + all(s == 1 for s in node.outputs[0].type.shape) + ): + if nnested: + all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] + else: + all_code = [code] + if len(all_code) == 1: + # No loops + task_decl = "".join( + f"{dtype}& {name}_i = *{name}_iter;\n" + for name, dtype in zip( + inames + list(real_onames), + idtypes + list(real_odtypes), + strict=True, + ) + ) + + preloops = {} + for i, (loop_order, dtype) in enumerate( + zip(loop_orders, dtypes, strict=True) + ): + for j, index in enumerate(loop_order): + if index != "x": + preloops.setdefault(j, "") + preloops[j] += ( + f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" + ) % sub + break + else: # all broadcastable + preloops.setdefault(0, "") + preloops[0] += ( + f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" + ) % sub + + init_array = preloops.get(0, " ") + loop = f""" + {{ + {defines} + {init_array} + {task_decl} + {task_code} + {undefs} + }} + """ + else: + loop = cgen.make_loop( + loop_orders=loop_orders, + dtypes=dtypes, + loop_tasks=all_code, + sub=sub, + openmp=use_openmp, + ) + else: + loop = cgen.make_reordered_loop( + init_loop_orders=loop_orders, + olv_index=olv_index, + dtypes=dtypes, + inner_task=code, + sub=sub, + openmp=use_openmp, + ) + + # If all inputs and outputs are contiguous + # and the scalar op define optimized code for that case + # use it! The scalar_op needs to check the type-level shapes itself. + if ( + all(o.ndim >= 1 for o in node.outputs) + and + # Don't use the contig code for broadcasted scalar. + not all(s == 1 for s in node.outputs[0].type.shape) + ): + contig = None + try: + contig = op.scalar_op.c_code_contiguous( + node, nodename + "_scalar_contig_", _inames, onames, sub + ) + except MethodNotDefined: + # Try to make one generic version, this will help the + # compiler to vectorize the code as their won't be as + # many ptr and the stride will be hard coded. + if all( + # io.type.shape == node.outputs[1].type.shape + # Elemwise does not specify non-broadcastable static/type-levelshape + # information for its outputs yet + node.outputs[0].type.is_super(io.type) + for io in node.inputs + node.outputs + ) and ( + len(node.inputs) <= 1 + # If either one of the inputs has a `None` shape, we cannot + # assume they will have the same size + or all( + len(set(inp_shape)) == 1 and None not in inp_shape + for inp_shape in zip( + *(inp.type.shape for inp in node.inputs), strict=True + ) + ) + ): + z = onames[0] + contig = f""" + // All output have the same size + npy_intp n = PyArray_SIZE({z}); + """ + index = "" + for x, var in zip(inames + onames, inputs + node.outputs, strict=True): + if not all(s == 1 for s in var.type.shape): + contig += f""" + dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); + """ + index += f""" + dtype_{x}& {x}_i = {x}_ptr[i]; + """ + else: + contig += f""" + dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; + """ + if use_openmp: + contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) + """ + contig += f""" + for(int i=0; i bool: - """Return whether to emit OpenMP code. - - True when this op requests OpenMP and the compiler supports it. - """ - return self.openmp and openmp_supported() - - def c_compile_args(self, **kwargs): - return ["-fopenmp"] if self._use_openmp() else [] - def make_scalar_node(self, *inputs): """Create a scalar Apply node matching the dtypes of tensor inputs. @@ -769,381 +1071,6 @@ def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...] out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs) - def _c_all(self, node, nodename, inames, onames, sub): - # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` - # To not request all of them to call prepare_node(), do it here. - # There is no harm if it get called multiple times. - if not hasattr(node.tag, "fake_node"): - self.prepare_node(node, None, None, "c") - use_openmp = self._use_openmp() - _inames = inames - _onames = onames - - inames = uniq(inames) - inputs = uniq(node.inputs) - # assert that inames and inputs order stay consistent. - # This is to protect again futur change of uniq. - assert len(inames) == len(inputs) - ii, iii = unzip( - uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True - ) - assert all(x == y for x, y in zip(ii, inames, strict=True)) - assert all(x == y for x, y in zip(iii, inputs, strict=True)) - - defines = "" - undefs = "" - - # The destroy map is a map of output indices to input indices - # that overwrite them. We just convert them to the actual - # Variables. - dmap = { - node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern.items() - } - - # dtypes of the inputs - idtypes = [input.type.dtype_specs()[1] for input in inputs] - - # These are the outputs that we will need to allocate - # (output, name, name of the c type), transposed - real = list( - zip( - *[ - (r, s, r.type.dtype_specs()[1]) - for r, s in zip(node.outputs, onames, strict=True) - if r not in dmap - ], - strict=True, - ) - ) - if real: - real_outputs, real_onames, real_odtypes = real - else: - real_outputs, real_onames, real_odtypes = [], [], [] - - # Outputs that are aliased with an input (inplace) - # (output, name), transposed (c type name not needed since we don't - # need to allocate. - aliased = list( - zip( - *[ - (r, s) - for (r, s) in zip(node.outputs, onames, strict=True) - if r in dmap - ], - strict=True, - ) - ) - if aliased: - aliased_outputs, aliased_onames = aliased - else: - aliased_outputs, aliased_onames = [], [] - - # for each input: - # same as range(ndim), but with 'x' at all broadcastable positions - orders = [ - [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)] - for input in inputs - ] - - # number of nested loops we will need (all inputs have same - # dimensionality) - nnested = len(orders[0]) - sub = dict(sub) - for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): - # the c generators will substitute the input names for - # references to loop variables lv0, lv1, ... - sub[f"lv{i}"] = iname - - decl = cgen.make_declare(orders, idtypes, sub) - checks = cgen.make_checks(orders, idtypes, sub) - - # Check if all inputs (except broadcasted scalar) are fortran. - # In that case, create a fortran output ndarray. - z = list(zip(inames, inputs, strict=True)) - alloc_fortran = " && ".join( - f"PyArray_ISFORTRAN({arr})" - for arr, var in z - if not all(s == 1 for s in var.type.shape) - ) - # If it is a scalar, make it c contig to prevent problem with - # NumPy C and F contig not always set as both of them. - if len(alloc_fortran) == 0: - alloc_fortran = "0" - - alloc = "" - # We loop over the "real" outputs, i.e., those that are not - # inplace (must be allocated) and we declare/allocate/check - # them - for output, oname, odtype in zip( - real_outputs, real_onames, real_odtypes, strict=True - ): - i += 1 # before this loop, i = number of inputs - sub[f"lv{i}"] = oname - sub["olv"] = oname - alloc += cgen.make_declare( - [list(range(nnested))], [odtype], dict(sub, lv0=oname) - ) - alloc += cgen.make_alloc(orders, odtype, sub, fortran=alloc_fortran) - alloc += cgen.make_checks( - [list(range(nnested))], [odtype], dict(sub, lv0=oname) - ) - olv_index = i # index of the last output - - # We loop over the "aliased" outputs, i.e., those that are - # inplace (overwrite the contents of one of the inputs) and - # make the output pointers point to their corresponding input - # pointers. - for output, oname in zip(aliased_outputs, aliased_onames, strict=True): - olv_index = inputs.index(dmap[output][0]) - iname = inames[olv_index] - # We make the output point to the corresponding input and - # decrease the reference of whatever the output contained - # prior to this - alloc += f""" - if ({oname}) {{ - Py_XDECREF({oname}); - }} - {oname} = {iname}; - Py_XINCREF({oname}); - """ - # We alias the scalar variables - defines += f"#define {oname}_i {iname}_i\n" - undefs += f"#undef {oname}_i\n" - - # Note: here, olv_index is either the index of the last output - # which is allocated, OR, if there are any aliased outputs, - # the index of the last of these aliased outputs. - - # We generate the C code of the inner loop using the scalar op - if use_openmp: - # If we are using openmp, we need to get rid of the "goto" - # statement in sub['fail']. For now we recreate it here. - fail = failure_code(sub, use_goto=False) - else: - fail = sub["fail"] - task_code = self.scalar_op.c_code( - node.tag.fake_node, - nodename + "_scalar_", - [f"{s}_i" for s in _inames], - [f"{s}_i" for s in onames], - dict(sub, fail=fail), - ) - code = f""" - {{ - {defines} - {task_code} - {undefs} - }} - """ - - loop_orders = orders + [list(range(nnested))] * len(real_onames) - dtypes = idtypes + list(real_odtypes) - if all( - [o.ndim <= 1 for o in node.outputs] - or - # Use simpler code when output ndim == 0 or 1 - # or for broadcated scalar. - all(s == 1 for s in node.outputs[0].type.shape) - ): - if nnested: - all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] - else: - all_code = [code] - if len(all_code) == 1: - # No loops - task_decl = "".join( - f"{dtype}& {name}_i = *{name}_iter;\n" - for name, dtype in zip( - inames + list(real_onames), - idtypes + list(real_odtypes), - strict=True, - ) - ) - - preloops = {} - for i, (loop_order, dtype) in enumerate( - zip(loop_orders, dtypes, strict=True) - ): - for j, index in enumerate(loop_order): - if index != "x": - preloops.setdefault(j, "") - preloops[j] += ( - f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" - ) % sub - break - else: # all broadcastable - preloops.setdefault(0, "") - preloops[0] += ( - f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" - ) % sub - - init_array = preloops.get(0, " ") - loop = f""" - {{ - {defines} - {init_array} - {task_decl} - {task_code} - {undefs} - }} - """ - else: - loop = cgen.make_loop( - loop_orders=loop_orders, - dtypes=dtypes, - loop_tasks=all_code, - sub=sub, - openmp=use_openmp, - ) - else: - loop = cgen.make_reordered_loop( - init_loop_orders=loop_orders, - olv_index=olv_index, - dtypes=dtypes, - inner_task=code, - sub=sub, - openmp=use_openmp, - ) - - # If all inputs and outputs are contiguous - # and the scalar op define optimized code for that case - # use it! The scalar_op needs to check the type-level shapes itself. - if ( - all(o.ndim >= 1 for o in node.outputs) - and - # Don't use the contig code for broadcasted scalar. - not all(s == 1 for s in node.outputs[0].type.shape) - ): - contig = None - try: - contig = self.scalar_op.c_code_contiguous( - node, nodename + "_scalar_contig_", _inames, onames, sub - ) - except MethodNotDefined: - # Try to make one generic version, this will help the - # compiler to vectorize the code as their won't be as - # many ptr and the stride will be hard coded. - if all( - # io.type.shape == node.outputs[1].type.shape - # Elemwise does not specify non-broadcastable static/type-levelshape - # information for its outputs yet - node.outputs[0].type.is_super(io.type) - for io in node.inputs + node.outputs - ) and ( - len(node.inputs) <= 1 - # If either one of the inputs has a `None` shape, we cannot - # assume they will have the same size - or all( - len(set(inp_shape)) == 1 and None not in inp_shape - for inp_shape in zip( - *(inp.type.shape for inp in node.inputs), strict=True - ) - ) - ): - z = onames[0] - contig = f""" - // All output have the same size - npy_intp n = PyArray_SIZE({z}); - """ - index = "" - for x, var in zip( - inames + onames, inputs + node.outputs, strict=True - ): - if not all(s == 1 for s in var.type.shape): - contig += f""" - dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); - """ - index += f""" - dtype_{x}& {x}_i = {x}_ptr[i]; - """ - else: - contig += f""" - dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; - """ - if use_openmp: - contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) - """ - contig += f""" - for(int i=0; i", ""] - - def c_header_dirs(self, **kwargs): - return self.scalar_op.c_header_dirs(**kwargs) - - def c_support_code(self, **kwargs): - return self.scalar_op.c_support_code(**kwargs) - - def c_support_code_apply(self, node, nodename): - support_code = self.scalar_op.c_support_code_apply(node, nodename + "_scalar_") - return support_code - - def c_code_cache_version_apply(self, node): - version = [17] # the version corresponding to the c code in this Op - - # now we insert versions for the ops on which we depend... - scalar_node = Apply( - self.scalar_op, - [ - get_scalar_type(dtype=input.type.dtype).make_variable() - for input in node.inputs - ], - [ - get_scalar_type(dtype=output.type.dtype).make_variable() - for output in node.outputs - ], - ) - version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) - version.extend( - get_scalar_type(dtype=i.type.dtype).c_code_cache_version() - for i in node.inputs + node.outputs - ) - version.append(("openmp", self._use_openmp())) - version.append(("openmp_elemwise_minsize", config.openmp_elemwise_minsize)) - if all(version): - return tuple(version) - else: - return () - def outer(self, x, y): from pytensor.tensor.basic import expand_dims @@ -1475,7 +1402,16 @@ def _c_all(self, node, name, input_names, output_names, sub): if var is inp: var = Elemwise(scalar_identity)(inp) assert var.dtype == node.outputs[0].dtype - return var.owner.op._c_all(var.owner, name, input_names, output_names, sub) + inner_op = var.owner.op + return _elemwise_c_all( + inner_op, + var.owner, + name, + input_names, + output_names, + sub, + use_openmp=inner_op.openmp and openmp_supported(), + ) inp_dims = list(range(ndim)) non_reduced_dims = [i for i in inp_dims if i not in axis] diff --git a/tests/link/c/test_dispatch.py b/tests/link/c/test_dispatch.py index 8f78551c3f..ed2b808a6b 100644 --- a/tests/link/c/test_dispatch.py +++ b/tests/link/c/test_dispatch.py @@ -235,12 +235,10 @@ def c_funcify_wrong(op, node=None, **kwargs): def test_cop_graph_resolves_to_identity(): - # The parity guarantee: every COp node resolves to itself, so CLinker calls - # the op's own c_code/cache-version methods and produces byte-identical - # source and cache keys. A registered op (DimShuffle) resolves to its - # detached impl instead. - from pytensor.link.c.dispatch.elemwise import DimShuffleImpl - from pytensor.tensor.elemwise import DimShuffle + # An unregistered COp node resolves to itself, so CLinker calls the op's own + # c_code/cache-version methods. A registered op resolves to its detached impl. + from pytensor.link.c.dispatch.elemwise import DimShuffleImpl, ElemwiseImpl + from pytensor.tensor.elemwise import DimShuffle, Elemwise x = pt.matrix("x") out = (x.T + 1.0).sum(axis=0) @@ -251,7 +249,10 @@ def test_cop_graph_resolves_to_identity(): impl = cl._impl_for(node) if isinstance(node.op, DimShuffle): assert isinstance(impl, DimShuffleImpl) + elif isinstance(node.op, Elemwise): + assert isinstance(impl, ElemwiseImpl) else: + # A genuine COp (e.g. the CAReduce sum) resolves to itself. assert impl is node.op # Source generation works and the module is versioned (cacheable). diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index 16aacd88c1..bc6b2a5516 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -9,6 +9,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.cmodule import GCC_compiler +from pytensor.link.c.dispatch.basic import c_funcify from pytensor.link.c.op import COp, openmp_supported from pytensor.tensor.elemwise import Elemwise @@ -167,9 +168,10 @@ def test_openmp_resolution_does_not_mutate_global_config( ) op = Elemwise(ps.add, openmp=True) + impl = c_funcify(op) with warnings.catch_warnings(): warnings.simplefilter("ignore") - compile_args = op.c_compile_args() + compile_args = impl.c_compile_args() assert compile_args == (["-fopenmp"] if compiler_supports_openmp else []) assert op.openmp is True # the op's request survives the compiler's capability From db46dd83045fb8c80f834757625760e8d4e68ed6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 14 Jun 2026 00:56:51 -0500 Subject: [PATCH 8/8] Migrate CAReduce to the C dispatch registry CAReduce becomes a plain Op; NonZeroDimsCAReduce's _c_all override becomes an error_on_empty_reduce_axis flag the generator reads. The elemwise and careduce C generators move into link.c.dispatch.elemwise beside the impls, leaving tensor/elemwise.py free of any C codegen. --- pytensor/link/c/dispatch/elemwise.py | 575 ++++++++++++++++++++++++++- pytensor/tensor/elemwise.py | 528 +----------------------- pytensor/tensor/math.py | 34 +- tests/link/c/test_dispatch.py | 18 +- 4 files changed, 588 insertions(+), 567 deletions(-) diff --git a/pytensor/link/c/dispatch/elemwise.py b/pytensor/link/c/dispatch/elemwise.py index b949e7d95e..0c2d1139fd 100644 --- a/pytensor/link/c/dispatch/elemwise.py +++ b/pytensor/link/c/dispatch/elemwise.py @@ -1,11 +1,524 @@ from collections.abc import Hashable +from textwrap import dedent +from typing import cast +import numpy as np + +import pytensor.tensor.basic from pytensor.configdefaults import config from pytensor.graph.basic import Apply +from pytensor.graph.utils import MethodNotDefined +from pytensor.link.c.basic import failure_code from pytensor.link.c.dispatch.basic import CImpl, c_funcify from pytensor.link.c.op import openmp_supported from pytensor.scalar import get_scalar_type -from pytensor.tensor.elemwise import DimShuffle, Elemwise, _elemwise_c_all +from pytensor.scalar.basic import identity as scalar_identity +from pytensor.tensor import elemwise_cgen as cgen +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.type import TensorType +from pytensor.utils import uniq, unzip + + +def _elemwise_c_all(op, node, nodename, inames, onames, sub, use_openmp): + # Some `Op`s directly call this generator on a fresh Elemwise. + # To not request all of them to call prepare_node(), do it here. + # There is no harm if it get called multiple times. + if not hasattr(node.tag, "fake_node"): + op.prepare_node(node, None, None, "c") + _inames = inames + _onames = onames + + inames = uniq(inames) + inputs = uniq(node.inputs) + # assert that inames and inputs order stay consistent. + # This is to protect again futur change of uniq. + assert len(inames) == len(inputs) + ii, iii = unzip( + uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True + ) + assert all(x == y for x, y in zip(ii, inames, strict=True)) + assert all(x == y for x, y in zip(iii, inputs, strict=True)) + + defines = "" + undefs = "" + + # The destroy map is a map of output indices to input indices + # that overwrite them. We just convert them to the actual + # Variables. + dmap = {node.outputs[o]: [node.inputs[i]] for o, i in op.inplace_pattern.items()} + + # dtypes of the inputs + idtypes = [input.type.dtype_specs()[1] for input in inputs] + + # These are the outputs that we will need to allocate + # (output, name, name of the c type), transposed + real = list( + zip( + *[ + (r, s, r.type.dtype_specs()[1]) + for r, s in zip(node.outputs, onames, strict=True) + if r not in dmap + ], + strict=True, + ) + ) + if real: + real_outputs, real_onames, real_odtypes = real + else: + real_outputs, real_onames, real_odtypes = [], [], [] + + # Outputs that are aliased with an input (inplace) + # (output, name), transposed (c type name not needed since we don't + # need to allocate. + aliased = list( + zip( + *[(r, s) for (r, s) in zip(node.outputs, onames, strict=True) if r in dmap], + strict=True, + ) + ) + if aliased: + aliased_outputs, aliased_onames = aliased + else: + aliased_outputs, aliased_onames = [], [] + + # for each input: + # same as range(ndim), but with 'x' at all broadcastable positions + orders = [ + [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)] + for input in inputs + ] + + # number of nested loops we will need (all inputs have same + # dimensionality) + nnested = len(orders[0]) + sub = dict(sub) + for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): + # the c generators will substitute the input names for + # references to loop variables lv0, lv1, ... + sub[f"lv{i}"] = iname + + decl = cgen.make_declare(orders, idtypes, sub) + checks = cgen.make_checks(orders, idtypes, sub) + + # Check if all inputs (except broadcasted scalar) are fortran. + # In that case, create a fortran output ndarray. + z = list(zip(inames, inputs, strict=True)) + alloc_fortran = " && ".join( + f"PyArray_ISFORTRAN({arr})" + for arr, var in z + if not all(s == 1 for s in var.type.shape) + ) + # If it is a scalar, make it c contig to prevent problem with + # NumPy C and F contig not always set as both of them. + if len(alloc_fortran) == 0: + alloc_fortran = "0" + + alloc = "" + # We loop over the "real" outputs, i.e., those that are not + # inplace (must be allocated) and we declare/allocate/check + # them + for output, oname, odtype in zip( + real_outputs, real_onames, real_odtypes, strict=True + ): + i += 1 # before this loop, i = number of inputs + sub[f"lv{i}"] = oname + sub["olv"] = oname + alloc += cgen.make_declare( + [list(range(nnested))], [odtype], dict(sub, lv0=oname) + ) + alloc += cgen.make_alloc(orders, odtype, sub, fortran=alloc_fortran) + alloc += cgen.make_checks( + [list(range(nnested))], [odtype], dict(sub, lv0=oname) + ) + olv_index = i # index of the last output + + # We loop over the "aliased" outputs, i.e., those that are + # inplace (overwrite the contents of one of the inputs) and + # make the output pointers point to their corresponding input + # pointers. + for output, oname in zip(aliased_outputs, aliased_onames, strict=True): + olv_index = inputs.index(dmap[output][0]) + iname = inames[olv_index] + # We make the output point to the corresponding input and + # decrease the reference of whatever the output contained + # prior to this + alloc += f""" + if ({oname}) {{ + Py_XDECREF({oname}); + }} + {oname} = {iname}; + Py_XINCREF({oname}); + """ + # We alias the scalar variables + defines += f"#define {oname}_i {iname}_i\n" + undefs += f"#undef {oname}_i\n" + + # Note: here, olv_index is either the index of the last output + # which is allocated, OR, if there are any aliased outputs, + # the index of the last of these aliased outputs. + + # We generate the C code of the inner loop using the scalar op + if use_openmp: + # If we are using openmp, we need to get rid of the "goto" + # statement in sub['fail']. For now we recreate it here. + fail = failure_code(sub, use_goto=False) + else: + fail = sub["fail"] + task_code = op.scalar_op.c_code( + node.tag.fake_node, + nodename + "_scalar_", + [f"{s}_i" for s in _inames], + [f"{s}_i" for s in onames], + dict(sub, fail=fail), + ) + code = f""" + {{ + {defines} + {task_code} + {undefs} + }} + """ + + loop_orders = orders + [list(range(nnested))] * len(real_onames) + dtypes = idtypes + list(real_odtypes) + if all( + [o.ndim <= 1 for o in node.outputs] + or + # Use simpler code when output ndim == 0 or 1 + # or for broadcated scalar. + all(s == 1 for s in node.outputs[0].type.shape) + ): + if nnested: + all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] + else: + all_code = [code] + if len(all_code) == 1: + # No loops + task_decl = "".join( + f"{dtype}& {name}_i = *{name}_iter;\n" + for name, dtype in zip( + inames + list(real_onames), + idtypes + list(real_odtypes), + strict=True, + ) + ) + + preloops = {} + for i, (loop_order, dtype) in enumerate( + zip(loop_orders, dtypes, strict=True) + ): + for j, index in enumerate(loop_order): + if index != "x": + preloops.setdefault(j, "") + preloops[j] += ( + f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" + ) % sub + break + else: # all broadcastable + preloops.setdefault(0, "") + preloops[0] += ( + f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" + ) % sub + + init_array = preloops.get(0, " ") + loop = f""" + {{ + {defines} + {init_array} + {task_decl} + {task_code} + {undefs} + }} + """ + else: + loop = cgen.make_loop( + loop_orders=loop_orders, + dtypes=dtypes, + loop_tasks=all_code, + sub=sub, + openmp=use_openmp, + ) + else: + loop = cgen.make_reordered_loop( + init_loop_orders=loop_orders, + olv_index=olv_index, + dtypes=dtypes, + inner_task=code, + sub=sub, + openmp=use_openmp, + ) + + # If all inputs and outputs are contiguous + # and the scalar op define optimized code for that case + # use it! The scalar_op needs to check the type-level shapes itself. + if ( + all(o.ndim >= 1 for o in node.outputs) + and + # Don't use the contig code for broadcasted scalar. + not all(s == 1 for s in node.outputs[0].type.shape) + ): + contig = None + try: + contig = op.scalar_op.c_code_contiguous( + node, nodename + "_scalar_contig_", _inames, onames, sub + ) + except MethodNotDefined: + # Try to make one generic version, this will help the + # compiler to vectorize the code as their won't be as + # many ptr and the stride will be hard coded. + if all( + # io.type.shape == node.outputs[1].type.shape + # Elemwise does not specify non-broadcastable static/type-levelshape + # information for its outputs yet + node.outputs[0].type.is_super(io.type) + for io in node.inputs + node.outputs + ) and ( + len(node.inputs) <= 1 + # If either one of the inputs has a `None` shape, we cannot + # assume they will have the same size + or all( + len(set(inp_shape)) == 1 and None not in inp_shape + for inp_shape in zip( + *(inp.type.shape for inp in node.inputs), strict=True + ) + ) + ): + z = onames[0] + contig = f""" + // All output have the same size + npy_intp n = PyArray_SIZE({z}); + """ + index = "" + for x, var in zip(inames + onames, inputs + node.outputs, strict=True): + if not all(s == 1 for s in var.type.shape): + contig += f""" + dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); + """ + index += f""" + dtype_{x}& {x}_i = {x}_ptr[i]; + """ + else: + contig += f""" + dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; + """ + if use_openmp: + contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) + """ + contig += f""" + for(int i=0; i bool: return self.op.openmp and openmp_supported() def c_support_code(self, **kwargs) -> str: - return self.op.scalar_op.c_support_code(**kwargs) + return cast(str, self.op.scalar_op.c_support_code(**kwargs)) def c_support_code_apply(self, node: Apply, name: str) -> str: - return self.op.scalar_op.c_support_code_apply(node, name + "_scalar_") + return cast( + str, self.op.scalar_op.c_support_code_apply(node, name + "_scalar_") + ) def c_headers(self, **kwargs) -> list[str]: return ["", ""] def c_header_dirs(self, **kwargs) -> list[str]: - return self.op.scalar_op.c_header_dirs(**kwargs) + return cast(list[str], self.op.scalar_op.c_header_dirs(**kwargs)) def c_compile_args(self, **kwargs) -> list[str]: return ["-fopenmp"] if self._use_openmp() else [] @@ -155,8 +670,8 @@ def c_code( ) -> str: scalar_op = self.op.scalar_op if ( - any(i.dtype == "float16" for i in node.inputs) - or any(o.dtype == "float16" for o in node.outputs) + any(i.type.dtype == "float16" for i in node.inputs) + or any(o.type.dtype == "float16" for o in node.outputs) # This is for Composite or getattr(scalar_op, "inner_float16", False) ): @@ -170,7 +685,7 @@ def c_code( def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: scalar_op = self.op.scalar_op - version = [17] # bump when the emitted C changes + version: list[Hashable] = [17] # bump when the emitted C changes scalar_node = Apply( scalar_op, [get_scalar_type(dtype=i.type.dtype).make_variable() for i in node.inputs], @@ -191,3 +706,49 @@ def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: @c_funcify.register(Elemwise) def c_funcify_elemwise(op, node=None, **kwargs) -> ElemwiseImpl: return ElemwiseImpl(op) + + +class CAReduceImpl(CImpl): + """C implementation of `CAReduce` (and its subclasses). + + Emits the reduction loop over the per-element scalar code via + `_careduce_c_all`; the loop inlines the scalar op's `c_code` directly, so no + support code is delegated. + """ + + op: CAReduce + + def c_headers(self, **kwargs) -> list[str]: + return ["", ""] + + def c_code( + self, + node: Apply, + name: str, + inputs: list[str], + outputs: list[str], + sub: dict[str, str], + ) -> str: + return "\n".join(_careduce_c_all(self.op, node, name, inputs, outputs, sub)) + + def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: + scalar_op = self.op.scalar_op + version = [11] # bump when the emitted C changes + scalar_node = Apply( + scalar_op, + [get_scalar_type(dtype=i.type.dtype).make_variable() for i in node.inputs], + [get_scalar_type(dtype=o.type.dtype).make_variable() for o in node.outputs], + ) + version.append(scalar_op.c_code_cache_version_apply(scalar_node)) + version.extend( + get_scalar_type(dtype=i.type.dtype).c_code_cache_version() + for i in node.inputs + node.outputs + ) + if all(version): + return tuple(version) + return () + + +@c_funcify.register(CAReduce) +def c_funcify_careduce(op, node=None, **kwargs) -> CAReduceImpl: + return CAReduceImpl(op) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 9259539dae..87f5c80ea4 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from copy import copy -from textwrap import dedent from typing import Literal import numpy as np @@ -13,15 +12,10 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed -from pytensor.graph.utils import MethodNotDefined -from pytensor.link.c.basic import failure_code -from pytensor.link.c.op import COp, openmp_supported from pytensor.misc.frozendict import frozendict from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type -from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import upcast -from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.type import ( @@ -36,7 +30,6 @@ normalize_reduce_axis, ) from pytensor.tensor.variable import TensorVariable -from pytensor.utils import uniq, unzip class DimShuffle(Op): @@ -298,318 +291,6 @@ def process(self, r, pstate): pprint.assign(DimShuffle, DimShufflePrinter()) -def _elemwise_c_all(op, node, nodename, inames, onames, sub, use_openmp): - # Some `Op`s directly call this generator on a fresh Elemwise. - # To not request all of them to call prepare_node(), do it here. - # There is no harm if it get called multiple times. - if not hasattr(node.tag, "fake_node"): - op.prepare_node(node, None, None, "c") - _inames = inames - _onames = onames - - inames = uniq(inames) - inputs = uniq(node.inputs) - # assert that inames and inputs order stay consistent. - # This is to protect again futur change of uniq. - assert len(inames) == len(inputs) - ii, iii = unzip( - uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True - ) - assert all(x == y for x, y in zip(ii, inames, strict=True)) - assert all(x == y for x, y in zip(iii, inputs, strict=True)) - - defines = "" - undefs = "" - - # The destroy map is a map of output indices to input indices - # that overwrite them. We just convert them to the actual - # Variables. - dmap = {node.outputs[o]: [node.inputs[i]] for o, i in op.inplace_pattern.items()} - - # dtypes of the inputs - idtypes = [input.type.dtype_specs()[1] for input in inputs] - - # These are the outputs that we will need to allocate - # (output, name, name of the c type), transposed - real = list( - zip( - *[ - (r, s, r.type.dtype_specs()[1]) - for r, s in zip(node.outputs, onames, strict=True) - if r not in dmap - ], - strict=True, - ) - ) - if real: - real_outputs, real_onames, real_odtypes = real - else: - real_outputs, real_onames, real_odtypes = [], [], [] - - # Outputs that are aliased with an input (inplace) - # (output, name), transposed (c type name not needed since we don't - # need to allocate. - aliased = list( - zip( - *[(r, s) for (r, s) in zip(node.outputs, onames, strict=True) if r in dmap], - strict=True, - ) - ) - if aliased: - aliased_outputs, aliased_onames = aliased - else: - aliased_outputs, aliased_onames = [], [] - - # for each input: - # same as range(ndim), but with 'x' at all broadcastable positions - orders = [ - [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)] - for input in inputs - ] - - # number of nested loops we will need (all inputs have same - # dimensionality) - nnested = len(orders[0]) - sub = dict(sub) - for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): - # the c generators will substitute the input names for - # references to loop variables lv0, lv1, ... - sub[f"lv{i}"] = iname - - decl = cgen.make_declare(orders, idtypes, sub) - checks = cgen.make_checks(orders, idtypes, sub) - - # Check if all inputs (except broadcasted scalar) are fortran. - # In that case, create a fortran output ndarray. - z = list(zip(inames, inputs, strict=True)) - alloc_fortran = " && ".join( - f"PyArray_ISFORTRAN({arr})" - for arr, var in z - if not all(s == 1 for s in var.type.shape) - ) - # If it is a scalar, make it c contig to prevent problem with - # NumPy C and F contig not always set as both of them. - if len(alloc_fortran) == 0: - alloc_fortran = "0" - - alloc = "" - # We loop over the "real" outputs, i.e., those that are not - # inplace (must be allocated) and we declare/allocate/check - # them - for output, oname, odtype in zip( - real_outputs, real_onames, real_odtypes, strict=True - ): - i += 1 # before this loop, i = number of inputs - sub[f"lv{i}"] = oname - sub["olv"] = oname - alloc += cgen.make_declare( - [list(range(nnested))], [odtype], dict(sub, lv0=oname) - ) - alloc += cgen.make_alloc(orders, odtype, sub, fortran=alloc_fortran) - alloc += cgen.make_checks( - [list(range(nnested))], [odtype], dict(sub, lv0=oname) - ) - olv_index = i # index of the last output - - # We loop over the "aliased" outputs, i.e., those that are - # inplace (overwrite the contents of one of the inputs) and - # make the output pointers point to their corresponding input - # pointers. - for output, oname in zip(aliased_outputs, aliased_onames, strict=True): - olv_index = inputs.index(dmap[output][0]) - iname = inames[olv_index] - # We make the output point to the corresponding input and - # decrease the reference of whatever the output contained - # prior to this - alloc += f""" - if ({oname}) {{ - Py_XDECREF({oname}); - }} - {oname} = {iname}; - Py_XINCREF({oname}); - """ - # We alias the scalar variables - defines += f"#define {oname}_i {iname}_i\n" - undefs += f"#undef {oname}_i\n" - - # Note: here, olv_index is either the index of the last output - # which is allocated, OR, if there are any aliased outputs, - # the index of the last of these aliased outputs. - - # We generate the C code of the inner loop using the scalar op - if use_openmp: - # If we are using openmp, we need to get rid of the "goto" - # statement in sub['fail']. For now we recreate it here. - fail = failure_code(sub, use_goto=False) - else: - fail = sub["fail"] - task_code = op.scalar_op.c_code( - node.tag.fake_node, - nodename + "_scalar_", - [f"{s}_i" for s in _inames], - [f"{s}_i" for s in onames], - dict(sub, fail=fail), - ) - code = f""" - {{ - {defines} - {task_code} - {undefs} - }} - """ - - loop_orders = orders + [list(range(nnested))] * len(real_onames) - dtypes = idtypes + list(real_odtypes) - if all( - [o.ndim <= 1 for o in node.outputs] - or - # Use simpler code when output ndim == 0 or 1 - # or for broadcated scalar. - all(s == 1 for s in node.outputs[0].type.shape) - ): - if nnested: - all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] - else: - all_code = [code] - if len(all_code) == 1: - # No loops - task_decl = "".join( - f"{dtype}& {name}_i = *{name}_iter;\n" - for name, dtype in zip( - inames + list(real_onames), - idtypes + list(real_odtypes), - strict=True, - ) - ) - - preloops = {} - for i, (loop_order, dtype) in enumerate( - zip(loop_orders, dtypes, strict=True) - ): - for j, index in enumerate(loop_order): - if index != "x": - preloops.setdefault(j, "") - preloops[j] += ( - f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" - ) % sub - break - else: # all broadcastable - preloops.setdefault(0, "") - preloops[0] += ( - f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n" - ) % sub - - init_array = preloops.get(0, " ") - loop = f""" - {{ - {defines} - {init_array} - {task_decl} - {task_code} - {undefs} - }} - """ - else: - loop = cgen.make_loop( - loop_orders=loop_orders, - dtypes=dtypes, - loop_tasks=all_code, - sub=sub, - openmp=use_openmp, - ) - else: - loop = cgen.make_reordered_loop( - init_loop_orders=loop_orders, - olv_index=olv_index, - dtypes=dtypes, - inner_task=code, - sub=sub, - openmp=use_openmp, - ) - - # If all inputs and outputs are contiguous - # and the scalar op define optimized code for that case - # use it! The scalar_op needs to check the type-level shapes itself. - if ( - all(o.ndim >= 1 for o in node.outputs) - and - # Don't use the contig code for broadcasted scalar. - not all(s == 1 for s in node.outputs[0].type.shape) - ): - contig = None - try: - contig = op.scalar_op.c_code_contiguous( - node, nodename + "_scalar_contig_", _inames, onames, sub - ) - except MethodNotDefined: - # Try to make one generic version, this will help the - # compiler to vectorize the code as their won't be as - # many ptr and the stride will be hard coded. - if all( - # io.type.shape == node.outputs[1].type.shape - # Elemwise does not specify non-broadcastable static/type-levelshape - # information for its outputs yet - node.outputs[0].type.is_super(io.type) - for io in node.inputs + node.outputs - ) and ( - len(node.inputs) <= 1 - # If either one of the inputs has a `None` shape, we cannot - # assume they will have the same size - or all( - len(set(inp_shape)) == 1 and None not in inp_shape - for inp_shape in zip( - *(inp.type.shape for inp in node.inputs), strict=True - ) - ) - ): - z = onames[0] - contig = f""" - // All output have the same size - npy_intp n = PyArray_SIZE({z}); - """ - index = "" - for x, var in zip(inames + onames, inputs + node.outputs, strict=True): - if not all(s == 1 for s in var.type.shape): - contig += f""" - dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); - """ - index += f""" - dtype_{x}& {x}_i = {x}_ptr[i]; - """ - else: - contig += f""" - dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; - """ - if use_openmp: - contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) - """ - contig += f""" - for(int i=0; i", ""] - - def c_code_cache_version_apply(self, node): - # the version corresponding to the c code in this Op - version = [11] - - # now we insert versions for the ops on which we depend... - scalar_node = Apply( - self.scalar_op, - [ - get_scalar_type(dtype=input.type.dtype).make_variable() - for input in node.inputs - ], - [ - get_scalar_type(dtype=output.type.dtype).make_variable() - for output in node.outputs - ], - ) - version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) - version.extend( - get_scalar_type(dtype=i.type.dtype).c_code_cache_version() - for i in node.inputs + node.outputs - ) - if all(version): - return tuple(version) - else: - return () - def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): """Replace a symbol definition with an `Elemwise`-wrapped version of the corresponding scalar `Op`. diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..75e11090cf 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1,7 +1,6 @@ import builtins import warnings from collections.abc import Sequence -from textwrap import dedent from typing import TYPE_CHECKING, Optional import numpy as np @@ -366,38 +365,7 @@ def __str__(self): class NonZeroDimsCAReduce(FixedOpCAReduce): - def _c_all(self, node, name, input_names, output_names, sub): - setup, alloc, loop, cast = super()._c_all( - node, name, input_names, output_names, sub - ) - - # We add an additional check for zero-sized dimensions (This seems like - # something that could enabled in `elemwise_cgen.make_checks`.) - [iname] = input_names - - axis = self.axis - if axis is None: - axis = list(range(len(node.inputs[0].type.broadcastable))) - - pattern = [0] * len(node.inputs[0].broadcastable) - for i in axis: - pattern[i] = 1 - - pattern_ = str(pattern)[1:-1] - - setup = f"int tosum[]={{{pattern_}}};" + setup - alloc += dedent( - f""" - for(int i=0;i