Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/llvm-dialects/Dialect/OpDescription.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,14 @@ class OpDescription {
return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads;
}

// Only supported for concrete ops, not for op classes.
template <typename OpT> static const OpDescription &get();

// For concrete ops, returns a 1-element array containing the result of get().
// For op classes, returns an array containing the descriptions of all
// concrete ops that belong to this op class.
template <typename OpT> static llvm::ArrayRef<OpDescription> getAll();

Kind getKind() const { return m_kind; }

unsigned getOpcode() const;
Expand Down
13 changes: 11 additions & 2 deletions include/llvm-dialects/Dialect/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,18 @@ class VisitorKey {
friend class VisitorTemplate;

public:
// OpT may be a concrete dialect op, or an op class.
template <typename OpT> static VisitorKey op() {
VisitorKey key{Kind::OpDescription};
key.m_description = &OpDescription::get<OpT>();
auto const descriptions = OpDescription::getAll<OpT>();
if (descriptions.size() == 1) {
VisitorKey key{Kind::OpDescription};
key.m_description = &descriptions[0];
return key;
}
// OpT is an op class. Resolve it by all concrete sub ops.
static const OpSet set = OpSet::fromOpDescriptions(descriptions);
VisitorKey key{Kind::OpSet};
key.m_set = &set;
return key;
}

Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/OpDescription.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
return desc;
}

template <> ArrayRef<OpDescription> OpDescription::getAll<UnaryInstruction>() {
return get<UnaryInstruction>();
}

template <> const OpDescription &OpDescription::get<BinaryOperator>() {
static unsigned opcodes[] = {
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
Expand All @@ -121,6 +125,10 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
return desc;
}

template <> ArrayRef<OpDescription> OpDescription::getAll<BinaryOperator>() {
return get<BinaryOperator>();
}

// Generate OpDescription for all dedicate instruction classes.
#define HANDLE_USER_INST(...)
#define HANDLE_UNARY_INST(...)
Expand All @@ -129,19 +137,28 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
template <> const OpDescription &OpDescription::get<Class>() { \
static const OpDescription desc{Kind::Core, Instruction::opcode}; \
return desc; \
} \
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
return get<Class>(); \
}
#include "llvm/IR/Instruction.def"

#define HANDLE_INTRINSIC_DESC(Class, opcode) \
template <> const OpDescription &OpDescription::get<Class>() { \
static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \
return desc; \
} \
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
return get<Class>(); \
}
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
template <> const OpDescription &OpDescription::get<Class>() { \
static unsigned opcodes[] = {__VA_ARGS__}; \
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
return desc; \
} \
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
return get<Class>(); \
}

// ============================================================================
Expand Down
49 changes: 46 additions & 3 deletions lib/TableGen/GenDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {

out << R"(
#include "llvm/Support/raw_ostream.h"
#include <array>
#endif // GET_INCLUDES

#ifdef GET_DIALECT_DEFS
Expand Down Expand Up @@ -326,10 +327,10 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
}

bool $Dialect::isDialectOp(::llvm::CallInst& op) {
::llvm::Function *calledFunc = op.getCalledFunction();
::llvm::Function *calledFunc = op.getCalledFunction();
if (!calledFunc)
return false;

return isDialectOp(calledFunc->getName());
}

Expand Down Expand Up @@ -448,14 +449,18 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
if (!dialect->cppNamespace.empty())
out << tgfmt("} // namespace $namespace\n", &fmt);

// Define specializations of OpDescription::get for reflection
// Define specializations of OpDescription::{get, getAll} for reflection
for (const auto &opPtr : dialect->operations) {
Operation &op = *opPtr;

FmtContextScope scope{fmt};
fmt.withOp(op.name);
fmt.addSubst("mnemonic", op.mnemonic);

// We'd prefer fully qualifying the llvm_dialects namespace below in getAll
// with leading "::", but this is parsed as oart of the preceding ArrayRef
// type as there are just spaces in between. (gcc/clang/MSVC reject this)
// The get() variant does not have this problem due to the `&` token.
out << tgfmt(R"(
template <>
const ::llvm_dialects::OpDescription &
Expand All @@ -464,10 +469,48 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
return desc;
}

template <>
::llvm::ArrayRef<::llvm_dialects::OpDescription>
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
return get<$namespace::$_op>();
}

)",
&fmt, op.haveResultOverloads() ? "true" : "false");
}

// Define specializations of OpDescription::getAll for op classes
for (const auto &opClassPtr : dialect->opClasses) {
OpClass &opClass = *opClassPtr;

FmtContextScope scope{fmt};
fmt.withOp(opClass.name);

// We'd prefer fully qualifying the llvm_dialects namespace below with
// leading "::", but gcc/clang/MSVC reject this as they interpret the
// ::llvm_dialects identifier than within the preceding ArrayRef type. The
// get() variant does not have this problem as the `&` token separates the
// two.
out << tgfmt(R"(
template <>
::llvm::ArrayRef<::llvm_dialects::OpDescription>
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
static const std::array<::llvm_dialects::OpDescription, $0> desc{)",
&fmt, opClass.operations.size());
for (const auto &op : opClass.operations) {
fmt.addSubst("mnemonic", op->mnemonic);
out << tgfmt(R"(
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)",
&fmt, op->haveResultOverloads() ? "true" : "false");
}
out << tgfmt(R"(
};
return desc;
}
)",
&fmt, opClass.operations.size());
}

out << R"(
#endif // GET_DIALECT_DEFS
)";
Expand Down
Loading
Loading