diff --git a/gen/semantic-dcompute.cpp b/gen/semantic-dcompute.cpp index 1cf3fb4843..acc67af1c3 100644 --- a/gen/semantic-dcompute.cpp +++ b/gen/semantic-dcompute.cpp @@ -212,6 +212,11 @@ struct DComputeSemanticAnalyser : public StoppableVisitor { } } void visit(CallExp *e) override { + // Indirect calls via function pointers / delegates have no associated + // FuncDeclaration, so there is no module to check. + if (!e->f) + return; + // SynchronizedStatement is lowered to // Critsec __critsec105; // 105 == line number // _d_criticalenter(& __critsec105); <-- @@ -229,9 +234,15 @@ struct DComputeSemanticAnalyser : public StoppableVisitor { stop = true; return; } - + Module *m = e->f->getModule(); - if ((m == nullptr || (hasComputeAttr(m) == DComputeCompileFor::hostOnly)) && + // Template-instantiated functions are cross-module by nature: the template + // declaration and the instantiated function live in different modules. + // getModule() returns the *declaration* module, which says nothing about + // whether the generated code can run on GPU. Skip the module check for them. + const bool isTemplateFunc = e->f->isInstantiated() != nullptr; + if (!isTemplateFunc && + (m == nullptr || (hasComputeAttr(m) == DComputeCompileFor::hostOnly)) && !isNonComputeCallExpVaild(e)) { error(e->loc, "can only call functions from other `@compute` modules in " "`@compute` code"); diff --git a/tests/compilable/issue5116.d b/tests/compilable/issue5116.d new file mode 100644 index 0000000000..f59b57a245 --- /dev/null +++ b/tests/compilable/issue5116.d @@ -0,0 +1,26 @@ +// Without this patch, template instantiations from non-@compute modules (e.g. +// __equals from core.internal.array.equality) were walked, producing spurious errors. +// And indirect calls via function pointers inside those bodies additionally caused a +// null-pointer dereference crash. 
+
+// REQUIRES: target_NVPTX
+// RUN: %ldc -mdcompute-targets=cuda-350 %s
+
+@compute(CompileFor.deviceOnly) module tests.compilable.issue5116;
+import ldc.dcompute;
+
+private enum N = 16u; // element count of every fixed-size float array in this file
+
+struct S { // NOTE(review): not referenced by either kernel below - confirm it is still needed
+    float[N] data;
+}
+
+@kernel void testEqualExp() { // case 1: static-array equality written with the == operator
+    float[N] a, b;
+    bool c = (a == b); // presumably lowered to the __equals template named in the header comment - confirm
+}
+
+@kernel void testExplicitEquals() { // case 2: same comparison, calling __equals by name
+    float[N] a, b;
+    bool c = __equals(a, b); // explicit call to the template rather than going through ==
+}