diff --git a/include/llvm-dialects/Dialect/OpDescription.h b/include/llvm-dialects/Dialect/OpDescription.h index a1ecc49..14df9ca 100644 --- a/include/llvm-dialects/Dialect/OpDescription.h +++ b/include/llvm-dialects/Dialect/OpDescription.h @@ -65,8 +65,14 @@ class OpDescription { return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads; } + // Only supported for concrete ops, not for op classes. template static const OpDescription &get(); + // For concrete ops, returns a 1-element array containing the result of get(). + // For op classes, returns an array containing the descriptions of all + // concrete ops that belong to this op class. + template static llvm::ArrayRef getAll(); + Kind getKind() const { return m_kind; } unsigned getOpcode() const; diff --git a/include/llvm-dialects/Dialect/Visitor.h b/include/llvm-dialects/Dialect/Visitor.h index 41568ac..21fc900 100644 --- a/include/llvm-dialects/Dialect/Visitor.h +++ b/include/llvm-dialects/Dialect/Visitor.h @@ -120,9 +120,18 @@ class VisitorKey { friend class VisitorTemplate; public: + // OpT may be a concrete dialect op, or an op class. template static VisitorKey op() { - VisitorKey key{Kind::OpDescription}; - key.m_description = &OpDescription::get(); + auto const descriptions = OpDescription::getAll(); + if (descriptions.size() == 1) { + VisitorKey key{Kind::OpDescription}; + key.m_description = &descriptions[0]; + return key; + } + // OpT is an op class. Resolve it by all concrete sub ops. + static const OpSet set = OpSet::fromOpDescriptions(descriptions); + VisitorKey key{Kind::OpSet}; + key.m_set = &set; return key; } diff --git a/lib/Dialect/OpDescription.cpp b/lib/Dialect/OpDescription.cpp index 600acf0..4b6485d 100644 --- a/lib/Dialect/OpDescription.cpp +++ b/lib/Dialect/OpDescription.cpp @@ -112,6 +112,10 @@ template <> const OpDescription &OpDescription::get() { return desc; } +template <> ArrayRef OpDescription::getAll() { + return get(); +} + template <> const OpDescription &OpDescription::get() { static unsigned opcodes[] = { #define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode, @@ -121,6 +125,10 @@ template <> const OpDescription &OpDescription::get() { return desc; } +template <> ArrayRef OpDescription::getAll() { + return get(); +} + // Generate OpDescription for all dedicate instruction classes. #define HANDLE_USER_INST(...) #define HANDLE_UNARY_INST(...) @@ -129,6 +137,9 @@ template <> const OpDescription &OpDescription::get() { template <> const OpDescription &OpDescription::get() { \ static const OpDescription desc{Kind::Core, Instruction::opcode}; \ return desc; \ + } \ + template <> ArrayRef OpDescription::getAll() { \ + return get(); \ } #include "llvm/IR/Instruction.def" @@ -136,12 +147,18 @@ template <> const OpDescription &OpDescription::get() { template <> const OpDescription &OpDescription::get() { \ static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \ return desc; \ + } \ + template <> ArrayRef OpDescription::getAll() { \ + return get(); \ } #define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \ template <> const OpDescription &OpDescription::get() { \ static unsigned opcodes[] = {__VA_ARGS__}; \ static const OpDescription desc{Kind::Intrinsic, opcodes}; \ return desc; \ + } \ + template <> ArrayRef OpDescription::getAll() { \ + return get(); \ } // ============================================================================ diff --git a/lib/TableGen/GenDialect.cpp b/lib/TableGen/GenDialect.cpp index ba655a9..6d48bf7 100644 --- a/lib/TableGen/GenDialect.cpp +++ b/lib/TableGen/GenDialect.cpp @@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { out << R"( #include "llvm/Support/raw_ostream.h" +#include #endif // GET_INCLUDES #ifdef GET_DIALECT_DEFS @@ -326,10 +327,10 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { } bool $Dialect::isDialectOp(::llvm::CallInst& op) { - ::llvm::Function *calledFunc = op.getCalledFunction(); + ::llvm::Function *calledFunc = op.getCalledFunction(); if (!calledFunc) return false; - + return isDialectOp(calledFunc->getName()); } @@ -448,7 +449,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { if (!dialect->cppNamespace.empty()) out << tgfmt("} // namespace $namespace\n", &fmt); - // Define specializations of OpDescription::get for reflection + // Define specializations of OpDescription::{get, getAll} for reflection for (const auto &opPtr : dialect->operations) { Operation &op = *opPtr; @@ -456,6 +457,10 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { fmt.withOp(op.name); fmt.addSubst("mnemonic", op.mnemonic); + // We'd prefer fully qualifying the llvm_dialects namespace below in getAll + // with leading "::", but this is parsed as oart of the preceding ArrayRef + // type as there are just spaces in between. (gcc/clang/MSVC reject this) + // The get() variant does not have this problem due to the `&` token. out << tgfmt(R"( template <> const ::llvm_dialects::OpDescription & @@ -464,10 +469,48 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) { return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll<$namespace::$_op>() { + return get<$namespace::$_op>(); + } + )", &fmt, op.haveResultOverloads() ? "true" : "false"); } + // Define specializations of OpDescription::getAll for op classes + for (const auto &opClassPtr : dialect->opClasses) { + OpClass &opClass = *opClassPtr; + + FmtContextScope scope{fmt}; + fmt.withOp(opClass.name); + + // We'd prefer fully qualifying the llvm_dialects namespace below with + // leading "::", but gcc/clang/MSVC reject this as they interpret the + // ::llvm_dialects identifier than within the preceding ArrayRef type. The + // get() variant does not have this problem as the `&` token separates the + // two. + out << tgfmt(R"( + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll<$namespace::$_op>() { + static const std::array<::llvm_dialects::OpDescription, $0> desc{)", + &fmt, opClass.operations.size()); + for (const auto &op : opClass.operations) { + fmt.addSubst("mnemonic", op->mnemonic); + out << tgfmt(R"( + ::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)", + &fmt, op->haveResultOverloads() ? "true" : "false"); + } + out << tgfmt(R"( + }; + return desc; + } + )", + &fmt, opClass.operations.size()); + } + out << R"( #endif // GET_DIALECT_DEFS )"; diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 4c9fbe5..81c5da7 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -19,6 +19,7 @@ #include "llvm/Support/ModRef.h" #include "llvm/Support/raw_ostream.h" +#include #endif // GET_INCLUDES #ifdef GET_DIALECT_DEFS @@ -33,10 +34,10 @@ namespace xd::cpp { } bool ExampleDialect::isDialectOp(::llvm::CallInst& op) { - ::llvm::Function *calledFunc = op.getCalledFunction(); + ::llvm::Function *calledFunc = op.getCalledFunction(); if (!calledFunc) return false; - + return isDialectOp(calledFunc->getName()); } @@ -2733,6 +2734,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2741,6 +2748,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2749,6 +2762,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2757,6 +2776,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2765,6 +2790,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2773,6 +2804,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2781,6 +2818,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2789,6 +2832,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2797,6 +2846,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2805,6 +2860,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2813,6 +2874,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2821,6 +2888,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2829,6 +2902,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2837,6 +2916,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2845,6 +2930,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2853,6 +2944,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2861,6 +2958,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2869,6 +2972,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2877,6 +2986,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2885,6 +3000,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2893,6 +3014,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2901,6 +3028,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2909,6 +3042,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2917,6 +3056,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2925,6 +3070,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2933,6 +3084,12 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + template <> const ::llvm_dialects::OpDescription & @@ -2941,5 +3098,22 @@ data return desc; } + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + return get(); + } + + + template <> + ::llvm::ArrayRef<::llvm_dialects::OpDescription> + llvm_dialects::OpDescription::getAll() { + static const std::array<::llvm_dialects::OpDescription, 3> desc{ + ::llvm_dialects::OpDescription{true, "xd.ir.stream.add"}, + ::llvm_dialects::OpDescription{true, "xd.ir.stream.max"}, + ::llvm_dialects::OpDescription{true, "xd.ir.stream.min"}, + }; + return desc; + } #endif // GET_DIALECT_DEFS diff --git a/test/unit/dialect/TestDialect.td b/test/unit/dialect/TestDialect.td index af3ea6f..bb32253 100644 --- a/test/unit/dialect/TestDialect.td +++ b/test/unit/dialect/TestDialect.td @@ -135,3 +135,26 @@ def DefaultNameTargetExtType : DialectType { Type name should be "test.default.target". }]; } + +def SomeBaseOpClass : OpClass { + let arguments = (ins); + let summary = "family of operations"; + let description = [{ + Illustrate the use of the OpClass feature. + }]; +} + +class ConcreteOpClassMemberOp + : TestOp<"some.op.class.subop." # op, []> { + let superclass = SomeBaseOpClass; + let results = (outs); + let arguments = (ins superclass); + + let summary = "perform the " # op # " operation"; + let description = [{ + Illustrate the use of the OpClass feature. + }]; +} + +def AddOp : ConcreteOpClassMemberOp<"add">; +def MulOp : ConcreteOpClassMemberOp<"mul">; diff --git a/test/unit/interface/CMakeLists.txt b/test/unit/interface/CMakeLists.txt index b35ed3e..17bb98a 100644 --- a/test/unit/interface/CMakeLists.txt +++ b/test/unit/interface/CMakeLists.txt @@ -21,6 +21,7 @@ add_dialects_unit_test(DialectsADTTests OpSetTests.cpp OpMapTests.cpp OpMapIRTests.cpp - DialectTypeTests.cpp) + DialectTypeTests.cpp + VisitorIRTests.cpp) add_dependencies(DialectsADTTests TestDialectTableGen) diff --git a/test/unit/interface/VisitorIRTests.cpp b/test/unit/interface/VisitorIRTests.cpp new file mode 100644 index 0000000..50c2e2e --- /dev/null +++ b/test/unit/interface/VisitorIRTests.cpp @@ -0,0 +1,110 @@ +/* + *********************************************************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc., or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *********************************************************************************************************************** + */ + +#include "TestDialect.h" +#include "llvm-dialects/Dialect/Builder.h" +#include "llvm-dialects/Dialect/Dialect.h" +#include "llvm-dialects/Dialect/Visitor.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "gtest/gtest.h" + +#include + +using namespace llvm; +using namespace llvm_dialects; + +class VisitorIRTestFixture : public testing::Test { +protected: + void SetUp() override { + setupDialectsContext(); + makeModule(); + } + + LLVMContext Context; + std::unique_ptr DC; + std::unique_ptr Mod; + Function *EP = nullptr; + + BasicBlock *getEntryBlock() { return EntryBlock; } + +private: + BasicBlock *EntryBlock = nullptr; + + void makeModule() { + Mod = std::make_unique("dialects_test", Context); + const std::array Args = {Type::getInt32Ty(Mod->getContext())}; + FunctionCallee FC = Mod->getOrInsertFunction( + "main", + FunctionType::get(Type::getVoidTy(Mod->getContext()), Args, false)); + EP = cast(FC.getCallee()); + EntryBlock = BasicBlock::Create(Mod->getContext(), "entry", EP); + } + + void setupDialectsContext() { + DC = DialectContext::make(Context); + } +}; + +TEST_F(VisitorIRTestFixture, VisitOp) { + llvm_dialects::Builder Builder{Context}; + Builder.SetInsertPoint(getEntryBlock()); + + auto *Mul1 = Builder.create(); + auto *Mul2 = Builder.create(); + auto *Add1 = Builder.create(); + auto *Add2 = Builder.create(); + auto *DialectOp1 = Builder.create(); + + DenseSet Ops; + static const auto Visitor = + llvm_dialects::VisitorBuilder>() + .add([](auto &Ops, test::MulOp &Op) { Ops.insert(&Op); }) + .build(); + Visitor.visit(Ops, *Mod); + EXPECT_EQ(Ops.size(), 2); + for (const auto *Op : Ops) { + EXPECT_TRUE(Op == Mul1 || Op == Mul2); + } +} + +TEST_F(VisitorIRTestFixture, VisitOpClass) { + llvm_dialects::Builder Builder{Context}; + Builder.SetInsertPoint(getEntryBlock()); + + auto *Mul1 = Builder.create(); + auto *Mul2 = Builder.create(); + auto *Add1 = Builder.create(); + auto *Add2 = Builder.create(); + auto *DialectOp1 = Builder.create(); + + DenseSet Ops; + static const auto Visitor = + llvm_dialects::VisitorBuilder>() + .add( + [](auto &Ops, test::SomeBaseOpClass &Op) { Ops.insert(&Op); }) + .build(); + Visitor.visit(Ops, *Mod); + EXPECT_EQ(Ops.size(), 4); + for (const auto *Op : Ops) { + EXPECT_TRUE(Op == Mul1 || Op == Mul2 || Op == Add1 || Op == Add2); + } +}