Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions pytensor/compile/debug/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,10 @@ def printstuff(self):

# List of the default versions of `make_thunk`.
# This is needed to know whether the user overrode it.
default_make_thunk = [get_unbound_function(COp.make_thunk)]
default_make_thunk = [
get_unbound_function(Op.make_thunk),
get_unbound_function(COp.make_thunk),
]


# Debug mode cheats and initializes the linker in a different way in
Expand Down Expand Up @@ -1352,6 +1355,7 @@ def make_all(
# can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the
# filter_checks_isfinite
from pytensor.link.c.dispatch.basic import c_thunk_from_dispatch

fgraph = self.fgraph
input_storage_ = input_storage
Expand Down Expand Up @@ -1404,15 +1408,10 @@ def make_all(
debug = hasattr(node.op, "debug_perform")

try:
if (
not self.maker.mode.check_c_code
or debug
or not isinstance(node.op, COp)
):
if not self.maker.mode.check_c_code or debug:
raise MethodNotDefined()

node.op.prepare_node(node, storage_map, compute_map, "c")
thunk = node.op.make_c_thunk(
thunk = c_thunk_from_dispatch(
node, storage_map, compute_map, no_recycling
)
thunks_c.append(thunk)
Expand All @@ -1434,19 +1433,18 @@ def make_all(
else:
thunks_py.append(None)

if (
not self.maker.mode.check_c_code
and thunks_py[-1] is None
and isinstance(node.op, COp)
):
_logger.warning(
f"Op {node.op} doesn't have a perform, forcing check of the C code"
)
node.op.prepare_node(node, storage_map, compute_map, "c")
thunk = node.op.make_c_thunk(
node, storage_map, compute_map, no_recycling
)
thunks_c[-1] = thunk
if not self.maker.mode.check_c_code and thunks_py[-1] is None:
try:
thunk = c_thunk_from_dispatch(
node, storage_map, compute_map, no_recycling
)
except (NotImplementedError, MethodNotDefined):
pass
else:
_logger.warning(
f"Op {node.op} doesn't have a perform, forcing check of the C code"
)
thunks_c[-1] = thunk

# If the op defined its own make_thunk, use the generated thunk
if thunk_other is not None:
Expand Down
97 changes: 75 additions & 22 deletions pytensor/link/c/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config
from pytensor.graph.basic import (
Apply,
AtomicVariable,
Constant,
)
Expand Down Expand Up @@ -553,6 +554,7 @@ class CLinker(Linker):

def __init__(self, schedule=None):
    """Initialize the linker with no function graph and an empty node cache.

    Parameters
    ----------
    schedule
        Forwarded to the parent initializer as its ``scheduler`` keyword
        argument.
    """
    # Maps an `Apply` node to its C implementation; starts out empty.
    self._node_impls: dict[Apply, CLinkerOp] = {}
    # No function graph is set at construction time.
    self.fgraph = None
    super().__init__(scheduler=schedule)

def accept(
Expand All @@ -573,6 +575,23 @@ def accept(
self.no_recycling = no_recycling
return self

def _impl_for(self, node) -> CLinkerOp:
    """Look up, or resolve and cache, the C implementation of `node`.

    On a cache miss the implementation is obtained from
    ``c_funcify(node.op, node=node)`` and stored in ``self._node_impls``
    under `node`, so later calls for the same node return the stored object.

    Raises
    ------
    NotImplementedError
        If `c_funcify` raises `NotImplementedError` for this node's op; the
        original exception is chained as the cause.
    """
    cache = self._node_impls
    if node in cache:
        return cache[node]

    # Deferred import: keeps `import pytensor` from loading the dispatch
    # registrations; they get loaded by this import on first compilation.
    from pytensor.link.c.dispatch.basic import c_funcify

    try:
        resolved = c_funcify(node.op, node=node)
    except NotImplementedError as exc:
        raise NotImplementedError(f"{node.op} cannot produce C code") from exc

    cache[node] = resolved
    return resolved

def fetch_variables(self):
"""Fills the inputs, outputs, variables, orphans, temps and node_order fields."""
fgraph = self.fgraph
Expand All @@ -591,21 +610,25 @@ def fetch_variables(self):
# that needs it
self.node_params = dict()
for node in self.node_order:
if not isinstance(node.op, CLinkerOp):
try:
impl = self._impl_for(node)
except NotImplementedError:
# No C implementation; code_gen will raise if this node is
# actually compiled.
continue
try:
params = node.op.get_params(node)
params = impl.get_params(node)
except MethodNotDefined:
params = NoParams
if params is not NoParams:
# try to avoid creating more than one variable for the
# same params.
if params in self.node_params:
var = self.node_params[params]
assert var.type == node.params_type
assert var.type == impl.params_type
fgraph.clients[var].append((node, "params"))
else:
var = Constant(node.params_type, params)
var = Constant(impl.params_type, params)
fgraph.clients[var] = [(node, "params")]
self.node_params[params] = var
self.variables.append(var)
Expand Down Expand Up @@ -775,10 +798,7 @@ def code_gen(self):
id += 2

for node_num, node in enumerate(self.node_order):
op = node.op

if not isinstance(op, CLinkerOp):
raise NotImplementedError(f"{op} cannot produce C code")
op = self._impl_for(node)

sub = dict(failure_var=failure_var)

Expand Down Expand Up @@ -905,7 +925,9 @@ def support_code(self):
"""
)
# generic support code
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
support_code = x.c_support_code()
if isinstance(support_code, list):
ret.extend(support_code)
Expand Down Expand Up @@ -941,15 +963,19 @@ def compile_args(self):

c_compiler = self.c_compiler()

for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_compile_args(c_compiler=c_compiler)

ret = uniq(ret) # to remove duplicate
# The args set by the compiler include the user flags. We do not want
# to reorder them
ret += c_compiler.compile_args()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
no_comp = x.c_no_compile_args(c_compiler=c_compiler)

Expand All @@ -970,7 +996,9 @@ def headers(self):
"""
ret = []
c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_headers(c_compiler=c_compiler)
return uniq(ret)
Expand All @@ -984,14 +1012,18 @@ def init_code(self):

"""
ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_init_code()
return uniq(ret)

def c_compiler(self):
c_compiler = None
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
# FIXME: Why would a `Type` have a `c_compiler` field?!
if hasattr(x, "c_compiler"):
x_compiler = x.c_compiler()
Expand Down Expand Up @@ -1021,7 +1053,9 @@ def header_dirs(self):
"""
ret = []
c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_header_dirs(c_compiler=c_compiler)
# filter out empty strings/None
Expand All @@ -1037,7 +1071,9 @@ def libraries(self):
"""
ret = []
c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_libraries(c_compiler=c_compiler)
return uniq(ret)
Expand All @@ -1052,7 +1088,9 @@ def lib_dirs(self):
"""
ret = []
c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
for x in [y.type for y in self.variables] + [
self._impl_for(y) for y in self.node_order
]:
if isinstance(x, CLinkerObject):
ret += x.c_lib_dirs(c_compiler=c_compiler)
# filter out empty strings/None
Expand Down Expand Up @@ -1433,8 +1471,14 @@ def in_sig(i, topological_pos, i_idx):

version = []
for node_pos, node in enumerate(order):
if hasattr(node.op, "c_code_cache_version_apply"):
version.append(node.op.c_code_cache_version_apply(node))
try:
impl = self._impl_for(node)
except NotImplementedError:
# No C implementation: contributes no version entry (code_gen
# raises later if this graph is compiled).
pass
else:
version.append(impl.c_code_cache_version_apply(node))

props = getattr(node.op, "__props__", None)

Expand Down Expand Up @@ -1819,6 +1863,8 @@ def accept(self, fgraph, no_recycling=None, profile=None):
def make_all(
self, profiler=None, input_storage=None, output_storage=None, storage_map=None
):
from pytensor.link.c.dispatch.basic import make_node_thunk_with_c_dispatch

fgraph = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
Expand All @@ -1838,9 +1884,16 @@ def make_all(

thunks = []
for node in order:
# make_thunk will try by default C code, otherwise
# it fall back to python.
thunks += [node.op.make_thunk(node, storage_map, compute_map, no_recycling)]
# Try the C dispatch first, otherwise fall back to Python.
thunks += [
make_node_thunk_with_c_dispatch(
node,
storage_map,
compute_map,
no_recycling,
try_c=bool(config.cxx),
)
]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]

Expand Down
8 changes: 8 additions & 0 deletions pytensor/link/c/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# isort: off
from pytensor.link.c.dispatch.basic import c_funcify

# Load dispatch specializations
import pytensor.link.c.dispatch.elemwise
import pytensor.link.c.dispatch.raise_op

# isort: on
Loading
Loading